diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 2f05b2f12..8654fdd18 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -39,9 +39,9 @@ func betaRequestOptions(flags []string) []option.RequestOption { } // buildRequestOptions constructs the common request options shared -// by Generate and Stream: user-agent, raw tool injection, and any -// beta API flags. -func buildRequestOptions(call fantasy.Call, rawTools []json.RawMessage, betaFlags []string) []option.RequestOption { +// by Generate and Stream: user-agent, raw tool injection, any +// beta API flags, and Bedrock-specific overrides. +func buildRequestOptions(call fantasy.Call, rawTools []json.RawMessage, betaFlags []string, bedrockSystem string) []option.RequestOption { reqOpts := callUARequestOptions(call) if len(rawTools) > 0 { // Tools are injected as raw JSON rather than via params.Tools @@ -53,6 +53,9 @@ func buildRequestOptions(call fantasy.Call, rawTools []json.RawMessage, betaFlag if len(betaFlags) > 0 { reqOpts = append(reqOpts, betaRequestOptions(betaFlags)...) } + if bedrockSystem != "" { + reqOpts = append(reqOpts, option.WithJSONSet("system", bedrockSystem)) + } return reqOpts } @@ -272,6 +275,7 @@ func (a languageModel) prepareParams(call fantasy.Call) ( rawTools []json.RawMessage, warnings []fantasy.CallWarning, betaFlags []string, + bedrockSystem string, err error, ) { params = &anthropic.MessageNewParams{} @@ -279,7 +283,7 @@ func (a languageModel) prepareParams(call fantasy.Call) ( if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"} + return nil, nil, nil, nil, "", &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"} } } sendReasoning := true @@ -288,6 +292,12 @@ func (a languageModel) prepareParams(call fantasy.Call) ( } systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning) + if a.options.useBedrock { + if flattened, ok := flattenSystemForBedrock(systemBlocks); ok { + bedrockSystem = flattened + } + } + if call.FrequencyPenalty != nil { warnings = append(warnings, fantasy.CallWarning{ Type: fantasy.CallWarningTypeUnsupportedSetting, @@ -302,6 +312,9 @@ func (a languageModel) prepareParams(call fantasy.Call) ( } params.System = systemBlocks + if bedrockSystem != "" { + params.System = nil + } params.Messages = messages params.Model = anthropic.Model(a.modelID) params.MaxTokens = 4096 @@ -330,7 +343,7 @@ func (a languageModel) prepareParams(call fantasy.Call) ( params.Thinking.OfAdaptive = &adaptive case providerOptions.Thinking != nil: if providerOptions.Thinking.BudgetTokens == 0 { - return nil, nil, nil, nil, &fantasy.Error{Title: "no budget", Message: "thinking requires budget"} + return nil, nil, nil, nil, "", &fantasy.Error{Title: "no budget", Message: "thinking requires budget"} } params.Thinking = anthropic.ThinkingConfigParamOfEnabled(providerOptions.Thinking.BudgetTokens) if call.Temperature != nil { @@ -373,7 +386,7 @@ func (a languageModel) prepareParams(call fantasy.Call) ( warnings = append(warnings, toolWarnings...) } - return params, rawTools, warnings, betaFlags, nil + return params, rawTools, warnings, betaFlags, bedrockSystem, nil } func (a *provider) Name() string { @@ -1106,11 +1119,11 @@ func mapFinishReason(finishReason string) fantasy.FinishReason { // Generate implements fantasy.LanguageModel. func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { - params, rawTools, warnings, betaFlags, err := a.prepareParams(call) + params, rawTools, warnings, betaFlags, bedrockSystem, err := a.prepareParams(call) if err != nil { return nil, err } - reqOpts := buildRequestOptions(call, rawTools, betaFlags) + reqOpts := buildRequestOptions(call, rawTools, betaFlags, bedrockSystem) response, err := a.client.Messages.New(ctx, *params, reqOpts...) if err != nil { @@ -1235,12 +1248,12 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas // Stream implements fantasy.LanguageModel. func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - params, rawTools, warnings, betaFlags, err := a.prepareParams(call) + params, rawTools, warnings, betaFlags, bedrockSystem, err := a.prepareParams(call) if err != nil { return nil, err } - reqOpts := buildRequestOptions(call, rawTools, betaFlags) + reqOpts := buildRequestOptions(call, rawTools, betaFlags, bedrockSystem) stream := a.client.Messages.NewStreaming(ctx, *params, reqOpts...) acc := anthropic.Message{} diff --git a/providers/anthropic/anthropic_test.go b/providers/anthropic/anthropic_test.go index 4387a34fa..15cbf8d0d 100644 --- a/providers/anthropic/anthropic_test.go +++ b/providers/anthropic/anthropic_test.go @@ -504,6 +504,121 @@ func TestStream_SendsOutputConfigEffort(t *testing.T) { requireAnthropicEffort(t, call.body, EffortHigh) } +func TestBedrockSystemPromptWireShape(t *testing.T) { + t.Parallel() + + singleSystemPrompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "you are helpful"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hello"}, + }, + }, + } + promptCases := []struct { + name string + prompt fantasy.Prompt + wantSystem string + }{ + { + name: "single system message", + prompt: singleSystemPrompt, + wantSystem: "you are helpful", + }, + { + name: "three consecutive system messages", + prompt: fantasy.Prompt{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "first instruction"}, + }, + }, + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "second instruction"}, + }, + }, + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "third instruction"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hello"}, + }, + }, + }, + wantSystem: "first instruction\n\nsecond instruction\n\nthird instruction", + }, + { + name: "non contiguous system prompt keeps first block only", + prompt: fantasy.Prompt{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "first instruction"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hello"}, + }, + }, + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "ignored instruction"}, + }, + }, + }, + wantSystem: "first instruction", + }, + } + callers := []struct { + name string + capture func(t *testing.T, prompt fantasy.Prompt, opts ...Option) map[string]any + }{ + {name: "generate", capture: captureAnthropicGenerateRequestBody}, + {name: "stream", capture: captureAnthropicStreamRequestBody}, + } + + for _, promptCase := range promptCases { + promptCase := promptCase + t.Run(promptCase.name, func(t *testing.T) { + t.Parallel() + + for _, caller := range callers { + caller := caller + t.Run(caller.name, func(t *testing.T) { + t.Parallel() + + body := caller.capture(t, promptCase.prompt, WithBedrock(), WithSkipAuth(true)) + requireAnthropicSystemString(t, body, promptCase.wantSystem) + }) + } + }) + } + + t.Run("direct anthropic keeps system blocks", func(t *testing.T) { + t.Parallel() + + body := captureAnthropicGenerateRequestBody(t, singleSystemPrompt, WithAPIKey("test-api-key")) + requireAnthropicSystemBlocks(t, body, []string{"you are helpful"}) + }) +} + type anthropicCall struct { method string path string @@ -595,6 +710,75 @@ func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) require.Equal(t, "adaptive", thinking["type"]) } +func captureAnthropicGenerateRequestBody(t *testing.T, prompt fantasy.Prompt, opts ...Option) map[string]any { + t.Helper() + + server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse()) + defer server.Close() + + providerOptions := append([]Option{}, opts...) + providerOptions = append(providerOptions, WithBaseURL(server.URL)) + provider, err := New(providerOptions...) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + _, err = model.Generate(context.Background(), fantasy.Call{Prompt: prompt}) + require.NoError(t, err) + + return awaitAnthropicCall(t, calls).body +} + +func captureAnthropicStreamRequestBody(t *testing.T, prompt fantasy.Prompt, opts ...Option) map[string]any { + t.Helper() + + server, calls := newAnthropicStreamingServer([]string{ + "event: message_start\n", + "data: {\"type\":\"message_start\",\"message\":{}}\n\n", + "event: message_stop\n", + "data: {\"type\":\"message_stop\"}\n\n", + }) + defer server.Close() + + providerOptions := append([]Option{}, opts...) + providerOptions = append(providerOptions, WithBaseURL(server.URL)) + provider, err := New(providerOptions...) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + stream, err := model.Stream(context.Background(), fantasy.Call{Prompt: prompt}) + require.NoError(t, err) + + stream(func(fantasy.StreamPart) bool { return true }) + + return awaitAnthropicCall(t, calls).body +} + +func requireAnthropicSystemString(t *testing.T, body map[string]any, expected string) { + t.Helper() + + system, ok := body["system"].(string) + require.Truef(t, ok, "expected system to be a JSON string, got %T (%#v)", body["system"], body["system"]) + require.Equal(t, expected, system) +} + +func requireAnthropicSystemBlocks(t *testing.T, body map[string]any, expected []string) { + t.Helper() + + system, ok := body["system"].([]any) + require.Truef(t, ok, "expected system to be a JSON array, got %T (%#v)", body["system"], body["system"]) + require.Len(t, system, len(expected)) + for i, want := range expected { + block, ok := system[i].(map[string]any) + require.Truef(t, ok, "expected system[%d] to be an object, got %T (%#v)", i, system[i], system[i]) + require.Equal(t, want, block["text"]) + require.Equal(t, "text", block["type"]) + } +} + func testPrompt() fantasy.Prompt { return fantasy.Prompt{ { @@ -1574,7 +1758,8 @@ func TestComputerUseToolJSON(t *testing.T) { } _, err := computerUseToolJSON(pdt) require.Error(t, err) - require.Contains(t, err.Error(), "tool_version arg is missing") }) + require.Contains(t, err.Error(), "tool_version arg is missing") + }) t.Run("returns error for unsupported version", func(t *testing.T) { t.Parallel() diff --git a/providers/anthropic/bedrock.go b/providers/anthropic/bedrock.go index 8d9b94959..59a10a6dc 100644 --- a/providers/anthropic/bedrock.go +++ b/providers/anthropic/bedrock.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/smithy-go/auth/bearer" + "github.com/charmbracelet/anthropic-sdk-go" ) func bedrockBasicAuthConfig(apiKey string) aws.Config { @@ -16,6 +17,24 @@ func bedrockBasicAuthConfig(apiKey string) aws.Config { } } +// flattenSystemForBedrock converts system text blocks into the string +// form required by Bedrock's Anthropic Messages schema. Direct +// Anthropic accepts an array of text blocks, so only Bedrock calls use +// this helper. +func flattenSystemForBedrock(blocks []anthropic.TextBlockParam) (string, bool) { + parts := make([]string, 0, len(blocks)) + for _, block := range blocks { + if block.Text == "" { + continue + } + parts = append(parts, block.Text) + } + if len(parts) == 0 { + return "", false + } + return strings.Join(parts, "\n\n"), true +} + func bedrockPrefixModelWithRegion(modelID string) string { region := os.Getenv("AWS_REGION") if len(region) < 2 {