Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ func betaRequestOptions(flags []string) []option.RequestOption {
}

// buildRequestOptions constructs the common request options shared
// by Generate and Stream: user-agent, raw tool injection, and any
// beta API flags.
func buildRequestOptions(call fantasy.Call, rawTools []json.RawMessage, betaFlags []string) []option.RequestOption {
// by Generate and Stream: user-agent, raw tool injection, any
// beta API flags, and Bedrock-specific overrides.
func buildRequestOptions(call fantasy.Call, rawTools []json.RawMessage, betaFlags []string, bedrockSystem string) []option.RequestOption {
reqOpts := callUARequestOptions(call)
if len(rawTools) > 0 {
// Tools are injected as raw JSON rather than via params.Tools
Expand All @@ -53,6 +53,9 @@ func buildRequestOptions(call fantasy.Call, rawTools []json.RawMessage, betaFlag
if len(betaFlags) > 0 {
reqOpts = append(reqOpts, betaRequestOptions(betaFlags)...)
}
if bedrockSystem != "" {
reqOpts = append(reqOpts, option.WithJSONSet("system", bedrockSystem))
}
return reqOpts
}

Expand Down Expand Up @@ -272,14 +275,15 @@ func (a languageModel) prepareParams(call fantasy.Call) (
rawTools []json.RawMessage,
warnings []fantasy.CallWarning,
betaFlags []string,
bedrockSystem string,
err error,
) {
params = &anthropic.MessageNewParams{}
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
return nil, nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"}
return nil, nil, nil, nil, "", &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"}
}
}
sendReasoning := true
Expand All @@ -288,6 +292,12 @@ func (a languageModel) prepareParams(call fantasy.Call) (
}
systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)

if a.options.useBedrock {
if flattened, ok := flattenSystemForBedrock(systemBlocks); ok {
bedrockSystem = flattened
}
}

if call.FrequencyPenalty != nil {
warnings = append(warnings, fantasy.CallWarning{
Type: fantasy.CallWarningTypeUnsupportedSetting,
Expand All @@ -302,6 +312,9 @@ func (a languageModel) prepareParams(call fantasy.Call) (
}

params.System = systemBlocks
if bedrockSystem != "" {
params.System = nil
}
params.Messages = messages
params.Model = anthropic.Model(a.modelID)
params.MaxTokens = 4096
Expand Down Expand Up @@ -330,7 +343,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (
params.Thinking.OfAdaptive = &adaptive
case providerOptions.Thinking != nil:
if providerOptions.Thinking.BudgetTokens == 0 {
return nil, nil, nil, nil, &fantasy.Error{Title: "no budget", Message: "thinking requires budget"}
return nil, nil, nil, nil, "", &fantasy.Error{Title: "no budget", Message: "thinking requires budget"}
}
params.Thinking = anthropic.ThinkingConfigParamOfEnabled(providerOptions.Thinking.BudgetTokens)
if call.Temperature != nil {
Expand Down Expand Up @@ -373,7 +386,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (
warnings = append(warnings, toolWarnings...)
}

return params, rawTools, warnings, betaFlags, nil
return params, rawTools, warnings, betaFlags, bedrockSystem, nil
}

func (a *provider) Name() string {
Expand Down Expand Up @@ -1106,11 +1119,11 @@ func mapFinishReason(finishReason string) fantasy.FinishReason {

// Generate implements fantasy.LanguageModel.
func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
params, rawTools, warnings, betaFlags, err := a.prepareParams(call)
params, rawTools, warnings, betaFlags, bedrockSystem, err := a.prepareParams(call)
if err != nil {
return nil, err
}
reqOpts := buildRequestOptions(call, rawTools, betaFlags)
reqOpts := buildRequestOptions(call, rawTools, betaFlags, bedrockSystem)

response, err := a.client.Messages.New(ctx, *params, reqOpts...)
if err != nil {
Expand Down Expand Up @@ -1235,12 +1248,12 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas

// Stream implements fantasy.LanguageModel.
func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
params, rawTools, warnings, betaFlags, err := a.prepareParams(call)
params, rawTools, warnings, betaFlags, bedrockSystem, err := a.prepareParams(call)
if err != nil {
return nil, err
}

reqOpts := buildRequestOptions(call, rawTools, betaFlags)
reqOpts := buildRequestOptions(call, rawTools, betaFlags, bedrockSystem)

stream := a.client.Messages.NewStreaming(ctx, *params, reqOpts...)
acc := anthropic.Message{}
Expand Down
187 changes: 186 additions & 1 deletion providers/anthropic/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,121 @@ func TestStream_SendsOutputConfigEffort(t *testing.T) {
requireAnthropicEffort(t, call.body, EffortHigh)
}

func TestBedrockSystemPromptWireShape(t *testing.T) {
t.Parallel()

singleSystemPrompt := fantasy.Prompt{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "you are helpful"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "hello"},
},
},
}
promptCases := []struct {
name string
prompt fantasy.Prompt
wantSystem string
}{
{
name: "single system message",
prompt: singleSystemPrompt,
wantSystem: "you are helpful",
},
{
name: "three consecutive system messages",
prompt: fantasy.Prompt{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "first instruction"},
},
},
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "second instruction"},
},
},
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "third instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "hello"},
},
},
},
wantSystem: "first instruction\n\nsecond instruction\n\nthird instruction",
},
{
name: "non contiguous system prompt keeps first block only",
prompt: fantasy.Prompt{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "first instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "hello"},
},
},
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "ignored instruction"},
},
},
},
wantSystem: "first instruction",
},
}
callers := []struct {
name string
capture func(t *testing.T, prompt fantasy.Prompt, opts ...Option) map[string]any
}{
{name: "generate", capture: captureAnthropicGenerateRequestBody},
{name: "stream", capture: captureAnthropicStreamRequestBody},
}

for _, promptCase := range promptCases {
promptCase := promptCase
t.Run(promptCase.name, func(t *testing.T) {
t.Parallel()

for _, caller := range callers {
caller := caller
t.Run(caller.name, func(t *testing.T) {
t.Parallel()

body := caller.capture(t, promptCase.prompt, WithBedrock(), WithSkipAuth(true))
requireAnthropicSystemString(t, body, promptCase.wantSystem)
})
}
})
}

t.Run("direct anthropic keeps system blocks", func(t *testing.T) {
t.Parallel()

body := captureAnthropicGenerateRequestBody(t, singleSystemPrompt, WithAPIKey("test-api-key"))
requireAnthropicSystemBlocks(t, body, []string{"you are helpful"})
})
}

type anthropicCall struct {
method string
path string
Expand Down Expand Up @@ -595,6 +710,75 @@ func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort)
require.Equal(t, "adaptive", thinking["type"])
}

func captureAnthropicGenerateRequestBody(t *testing.T, prompt fantasy.Prompt, opts ...Option) map[string]any {
t.Helper()

server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
defer server.Close()

providerOptions := append([]Option{}, opts...)
providerOptions = append(providerOptions, WithBaseURL(server.URL))
provider, err := New(providerOptions...)
require.NoError(t, err)

model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
require.NoError(t, err)

_, err = model.Generate(context.Background(), fantasy.Call{Prompt: prompt})
require.NoError(t, err)

return awaitAnthropicCall(t, calls).body
}

func captureAnthropicStreamRequestBody(t *testing.T, prompt fantasy.Prompt, opts ...Option) map[string]any {
t.Helper()

server, calls := newAnthropicStreamingServer([]string{
"event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
})
defer server.Close()

providerOptions := append([]Option{}, opts...)
providerOptions = append(providerOptions, WithBaseURL(server.URL))
provider, err := New(providerOptions...)
require.NoError(t, err)

model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
require.NoError(t, err)

stream, err := model.Stream(context.Background(), fantasy.Call{Prompt: prompt})
require.NoError(t, err)

stream(func(fantasy.StreamPart) bool { return true })

return awaitAnthropicCall(t, calls).body
}

func requireAnthropicSystemString(t *testing.T, body map[string]any, expected string) {
t.Helper()

system, ok := body["system"].(string)
require.Truef(t, ok, "expected system to be a JSON string, got %T (%#v)", body["system"], body["system"])
require.Equal(t, expected, system)
}

func requireAnthropicSystemBlocks(t *testing.T, body map[string]any, expected []string) {
t.Helper()

system, ok := body["system"].([]any)
require.Truef(t, ok, "expected system to be a JSON array, got %T (%#v)", body["system"], body["system"])
require.Len(t, system, len(expected))
for i, want := range expected {
block, ok := system[i].(map[string]any)
require.Truef(t, ok, "expected system[%d] to be an object, got %T (%#v)", i, system[i], system[i])
require.Equal(t, want, block["text"])
require.Equal(t, "text", block["type"])
}
}

func testPrompt() fantasy.Prompt {
return fantasy.Prompt{
{
Expand Down Expand Up @@ -1574,7 +1758,8 @@ func TestComputerUseToolJSON(t *testing.T) {
}
_, err := computerUseToolJSON(pdt)
require.Error(t, err)
require.Contains(t, err.Error(), "tool_version arg is missing") })
require.Contains(t, err.Error(), "tool_version arg is missing")
})

t.Run("returns error for unsupported version", func(t *testing.T) {
t.Parallel()
Expand Down
19 changes: 19 additions & 0 deletions providers/anthropic/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/auth/bearer"
"github.com/charmbracelet/anthropic-sdk-go"
)

func bedrockBasicAuthConfig(apiKey string) aws.Config {
Expand All @@ -16,6 +17,24 @@ func bedrockBasicAuthConfig(apiKey string) aws.Config {
}
}

// flattenSystemForBedrock converts system text blocks into the string
// form required by Bedrock's Anthropic Messages schema. Direct
// Anthropic accepts an array of text blocks, so only Bedrock calls use
// this helper.
func flattenSystemForBedrock(blocks []anthropic.TextBlockParam) (string, bool) {
parts := make([]string, 0, len(blocks))
for _, block := range blocks {
if block.Text == "" {
continue
}
parts = append(parts, block.Text)
}
if len(parts) == 0 {
return "", false
}
return strings.Join(parts, "\n\n"), true
}

func bedrockPrefixModelWithRegion(modelID string) string {
region := os.Getenv("AWS_REGION")
if len(region) < 2 {
Expand Down