From b7e7541ae86e83478cabc9e15b88c7decc222eba Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Thu, 30 Apr 2026 01:27:33 -0300 Subject: [PATCH 01/15] Enable named OpenAI-compatible provider profiles Claw could only target one generic OPENAI_BASE_URL at a time, which made Z.AI, MiniMax, Moonshot, and OpenAI switching depend on shell mutation. Add modelProviders config profiles, provider/model resolution for /model, and runtime construction from the selected profile while preserving existing built-in provider routing. Constraint: Keep existing env-based Anthropic, xAI, OpenAI, and DashScope routing intact. Rejected: Reuse aliases only | aliases cannot carry base URL or credential source to the runtime client. Confidence: medium Scope-risk: moderate Directive: Keep model provider credentials env-first; avoid requiring source-controlled apiKey values. Tested: cargo fmt; cargo test -p runtime parses_model_provider_profiles_from_settings -- --nocapture; cargo test -p rusty-claude-cli configured_model_provider -- --nocapture Not-tested: Live calls to every configured third-party provider. --- USAGE.md | 49 +++++ rust/crates/api/src/client.rs | 10 + rust/crates/runtime/src/config.rs | 174 +++++++++++++++++ rust/crates/runtime/src/config_validate.rs | 44 +++++ rust/crates/rusty-claude-cli/src/main.rs | 216 +++++++++++++++++++-- 5 files changed, 472 insertions(+), 21 deletions(-) diff --git a/USAGE.md b/USAGE.md index c8e7b09692..bca4cc5521 100644 --- a/USAGE.md +++ b/USAGE.md @@ -308,6 +308,55 @@ The OpenAI-compatible backend also serves as the gateway for **OpenRouter**, **O **Model-name prefix routing:** If a model name starts with `openai/`, `gpt-`, `qwen/`, or `qwen-`, the provider is selected by the prefix regardless of which env vars are set. This prevents accidental misrouting to Anthropic when multiple credentials exist in the environment. 
+### Configured OpenAI-compatible providers + +If you use several OpenAI-compatible providers, define named provider profiles in `settings.json` instead of changing `OPENAI_BASE_URL` before every run. Each profile gets its own base URL, credential env var, and model allow-list: + +```json +{ + "model": "zai/glm-5.1", + "modelProviders": { + "zai": { + "type": "openai-compatible", + "baseUrl": "https://api.z.ai/api/paas/v4", + "apiKeyEnv": "Z_AI_API_KEY", + "models": ["glm-5.1", "glm-4.6"], + "defaultModel": "glm-5.1" + }, + "minimax": { + "type": "openai-compatible", + "baseUrl": "https://api.minimax.io/v1", + "apiKeyEnv": "MINIMAX_API_KEY", + "models": ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"], + "defaultModel": "MiniMax-M2.7-highspeed" + }, + "moonshot": { + "type": "openai-compatible", + "baseUrl": "https://api.moonshot.ai/v1", + "apiKeyEnv": "MOONSHOT_API_KEY", + "models": ["kimi-k2.5"], + "defaultModel": "kimi-k2.5" + } + } +} +``` + +Use `/model provider/model` in the REPL to switch without restarting: + +```text +/model zai/glm-5.1 +/model minimax/MiniMax-M2.7-highspeed +/model moonshot/kimi-k2.5 +``` + +You can also use the provider name alone when it has `defaultModel` configured: + +```text +/model minimax +``` + +Prefer `apiKeyEnv` so secrets stay out of source-controlled project settings. `apiKey` is supported for local-only files when an environment variable is not practical. 
+ ### Tested models and aliases These are the models registered in the built-in alias table with known token limits: diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 6e68fd2e2c..89987b4a8d 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -46,6 +46,16 @@ impl ProviderClient { } } + #[must_use] + pub fn from_openai_compatible_profile( + api_key: impl Into, + base_url: impl Into, + ) -> Self { + Self::OpenAi( + OpenAiCompatClient::new(api_key, OpenAiCompatConfig::openai()).with_base_url(base_url), + ) + } + #[must_use] pub const fn provider_kind(&self) -> ProviderKind { match self { diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index 1566189282..e9777d2af0 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -64,9 +64,21 @@ pub struct RuntimeFeatureConfig { permission_rules: RuntimePermissionRuleConfig, sandbox: SandboxConfig, provider_fallbacks: ProviderFallbackConfig, + model_providers: BTreeMap, trusted_roots: Vec, } +/// User-configured model provider profile. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelProviderConfig { + provider_type: String, + base_url: String, + api_key_env: Option, + api_key: Option, + models: Vec, + default_model: Option, +} + /// Ordered chain of fallback model identifiers used when the primary /// provider returns a retryable failure (429/500/503/etc.). The chain is /// strict: each entry is tried in order until one succeeds. 
@@ -314,6 +326,7 @@ impl ConfigLoader { permission_rules: parse_optional_permission_rules(&merged_value)?, sandbox: parse_optional_sandbox_config(&merged_value)?, provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?, + model_providers: parse_optional_model_providers(&merged_value)?, trusted_roots: parse_optional_trusted_roots(&merged_value)?, }; @@ -410,6 +423,11 @@ impl RuntimeConfig { &self.feature_config.provider_fallbacks } + #[must_use] + pub fn model_providers(&self) -> &BTreeMap { + &self.feature_config.model_providers + } + #[must_use] pub fn trusted_roots(&self) -> &[String] { &self.feature_config.trusted_roots @@ -479,12 +497,68 @@ impl RuntimeFeatureConfig { &self.provider_fallbacks } + #[must_use] + pub fn model_providers(&self) -> &BTreeMap { + &self.model_providers + } + #[must_use] pub fn trusted_roots(&self) -> &[String] { &self.trusted_roots } } +impl ModelProviderConfig { + #[must_use] + pub fn new( + provider_type: String, + base_url: String, + api_key_env: Option, + api_key: Option, + models: Vec, + default_model: Option, + ) -> Self { + Self { + provider_type, + base_url, + api_key_env, + api_key, + models, + default_model, + } + } + + #[must_use] + pub fn provider_type(&self) -> &str { + &self.provider_type + } + + #[must_use] + pub fn base_url(&self) -> &str { + &self.base_url + } + + #[must_use] + pub fn api_key_env(&self) -> Option<&str> { + self.api_key_env.as_deref() + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + self.api_key.as_deref() + } + + #[must_use] + pub fn models(&self) -> &[String] { + &self.models + } + + #[must_use] + pub fn default_model(&self) -> Option<&str> { + self.default_model.as_deref() + } +} + impl ProviderFallbackConfig { #[must_use] pub fn new(primary: Option, fallbacks: Vec) -> Self { @@ -904,6 +978,61 @@ fn parse_optional_provider_fallbacks( Ok(ProviderFallbackConfig { primary, fallbacks }) } +fn parse_optional_model_providers( + root: &JsonValue, +) -> Result, ConfigError> { + 
let Some(object) = root.as_object() else { + return Ok(BTreeMap::new()); + }; + let Some(value) = object.get("modelProviders") else { + return Ok(BTreeMap::new()); + }; + let providers = expect_object(value, "merged settings.modelProviders")?; + let mut parsed = BTreeMap::new(); + for (name, value) in providers { + let context = format!("merged settings.modelProviders.{name}"); + let provider = expect_object(value, &context)?; + let provider_type = optional_string(provider, "type", &context)? + .unwrap_or("openai-compatible") + .to_string(); + if !matches!(provider_type.as_str(), "openai-compatible" | "openai") { + return Err(ConfigError::Parse(format!( + "{context}: unsupported provider type {provider_type}" + ))); + } + let base_url = expect_string(provider, "baseUrl", &context)?.to_string(); + let api_key_env = optional_string(provider, "apiKeyEnv", &context)?.map(str::to_string); + let api_key = optional_string(provider, "apiKey", &context)?.map(str::to_string); + let models = optional_string_array(provider, "models", &context)?.unwrap_or_default(); + let default_model = + optional_string(provider, "defaultModel", &context)?.map(str::to_string); + if models.is_empty() && default_model.is_none() { + return Err(ConfigError::Parse(format!( + "{context}: expected at least one model in models or defaultModel" + ))); + } + if let Some(default_model) = &default_model { + if !models.is_empty() && !models.iter().any(|model| model == default_model) { + return Err(ConfigError::Parse(format!( + "{context}: defaultModel must be listed in models" + ))); + } + } + parsed.insert( + name.clone(), + ModelProviderConfig::new( + provider_type, + base_url, + api_key_env, + api_key, + models, + default_model, + ), + ); + } + Ok(parsed) +} + fn parse_optional_trusted_roots(root: &JsonValue) -> Result, ConfigError> { let Some(object) = root.as_object() else { return Ok(Vec::new()); @@ -1812,6 +1941,51 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn 
parses_model_provider_profiles_from_settings() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + home.join("settings.json"), + r#"{ + "model": "zai/glm-5.1", + "modelProviders": { + "zai": { + "type": "openai-compatible", + "baseUrl": "https://api.z.ai/api/paas/v4", + "apiKeyEnv": "Z_AI_API_KEY", + "models": ["glm-5.1", "glm-4.6"], + "defaultModel": "glm-5.1" + } + } + }"#, + ) + .expect("write settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + let provider = loaded + .model_providers() + .get("zai") + .expect("zai provider should parse"); + + assert_eq!(loaded.model(), Some("zai/glm-5.1")); + assert_eq!(provider.provider_type(), "openai-compatible"); + assert_eq!(provider.base_url(), "https://api.z.ai/api/paas/v4"); + assert_eq!(provider.api_key_env(), Some("Z_AI_API_KEY")); + assert_eq!(provider.default_model(), Some("glm-5.1")); + assert_eq!( + provider.models(), + &["glm-5.1".to_string(), "glm-4.6".to_string()] + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn empty_settings_file_loads_defaults() { // given diff --git a/rust/crates/runtime/src/config_validate.rs b/rust/crates/runtime/src/config_validate.rs index 7a9c1c4adc..3a50bf5516 100644 --- a/rust/crates/runtime/src/config_validate.rs +++ b/rust/crates/runtime/src/config_validate.rs @@ -193,6 +193,10 @@ const TOP_LEVEL_FIELDS: &[FieldSpec] = &[ name: "providerFallbacks", expected: FieldType::Object, }, + FieldSpec { + name: "modelProviders", + expected: FieldType::Object, + }, FieldSpec { name: "trustedRoots", expected: FieldType::StringArray, @@ -310,6 +314,33 @@ const OAUTH_FIELDS: &[FieldSpec] = &[ }, ]; +const MODEL_PROVIDER_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "type", + expected: FieldType::String, + }, + FieldSpec { + name: "baseUrl", 
+ expected: FieldType::String, + }, + FieldSpec { + name: "apiKeyEnv", + expected: FieldType::String, + }, + FieldSpec { + name: "apiKey", + expected: FieldType::String, + }, + FieldSpec { + name: "models", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "defaultModel", + expected: FieldType::String, + }, +]; + const DEPRECATED_FIELDS: &[DeprecatedField] = &[ DeprecatedField { name: "permissionMode", @@ -501,6 +532,19 @@ pub fn validate_config_file( &path_display, )); } + if let Some(model_providers) = object.get("modelProviders").and_then(JsonValue::as_object) { + for (name, provider) in model_providers { + if let Some(provider) = provider.as_object() { + result.merge(validate_object_keys( + provider, + MODEL_PROVIDER_FIELDS, + &format!("modelProviders.{name}"), + source, + &path_display, + )); + } + } + } result } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index dbdbd07b64..55e8a91789 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -96,6 +96,13 @@ struct ModelProvenance { source: ModelSource, } +#[derive(Debug, Clone)] +struct ConfiguredModelProvider { + wire_model: String, + api_key: String, + base_url: String, +} + impl ModelProvenance { fn default_fallback() -> Self { Self { @@ -1443,9 +1450,85 @@ fn resolve_model_alias_with_config(model: &str) -> String { if let Some(resolved) = config_alias_for_current_dir(trimmed) { return resolve_model_alias(&resolved).to_string(); } + if let Some(resolved) = config_provider_default_model_for_current_dir(trimmed) { + return resolved; + } resolve_model_alias(trimmed).to_string() } +fn config_provider_default_model_for_current_dir(provider_name: &str) -> Option { + if provider_name.is_empty() || provider_name.contains('/') { + return None; + } + let cwd = env::current_dir().ok()?; + let loader = ConfigLoader::default_for(&cwd); + let config = loader.load().ok()?; + let provider = 
config.model_providers().get(provider_name)?; + let model = provider + .default_model() + .or_else(|| provider.models().first().map(String::as_str))?; + Some(format!("{provider_name}/{model}")) +} + +fn configured_provider_names_for_current_dir() -> Vec { + let Ok(cwd) = env::current_dir() else { + return Vec::new(); + }; + let loader = ConfigLoader::default_for(&cwd); + loader + .load() + .map(|config| config.model_providers().keys().cloned().collect()) + .unwrap_or_default() +} + +fn configured_provider_for_model( + model: &str, +) -> Result, Box> { + let Some((provider_name, requested_model)) = model.split_once('/') else { + return Ok(None); + }; + let cwd = env::current_dir()?; + let config = ConfigLoader::default_for(&cwd).load()?; + let Some(provider) = config.model_providers().get(provider_name) else { + return Ok(None); + }; + if !matches!(provider.provider_type(), "openai-compatible" | "openai") { + return Err(format!( + "model provider '{provider_name}' uses unsupported type '{}'", + provider.provider_type() + ) + .into()); + } + let wire_model = if requested_model.is_empty() { + provider.default_model().ok_or_else(|| { + format!("model provider '{provider_name}' does not define defaultModel") + })? + } else { + requested_model + }; + if !provider.models().is_empty() && !provider.models().iter().any(|model| model == wire_model) { + return Err(format!( + "model '{wire_model}' is not listed in modelProviders.{provider_name}.models" + ) + .into()); + } + let api_key = if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) { + api_key.to_string() + } else if let Some(env_name) = provider.api_key_env() { + env::var(env_name) + .map_err(|_| format!("model provider '{provider_name}' requires env var {env_name}"))? 
+ } else { + return Err( + format!("model provider '{provider_name}' requires apiKeyEnv or apiKey").into(), + ); + }; + Ok(Some(ConfiguredModelProvider { + wire_model: wire_model.to_string(), + api_key, + base_url: provider.base_url().to_string(), + })) +} + /// Validate model syntax at parse time. /// Accepts: known aliases (opus, sonnet, haiku) or provider/model pattern. /// Rejects: empty, whitespace-only, strings with spaces, or invalid chars. @@ -1459,6 +1542,12 @@ fn validate_model_syntax(model: &str) -> Result<(), String> { "opus" | "sonnet" | "haiku" => return Ok(()), _ => {} } + if configured_provider_names_for_current_dir() + .iter() + .any(|name| name == trimmed) + { + return Ok(()); + } // Check for spaces (malformed) if trimmed.contains(' ') { return Err(format!( @@ -7527,6 +7616,10 @@ fn build_runtime_with_plugin_state( plugin_registry, mcp_state, } = runtime_plugin_state; + let configured_provider = configured_provider_for_model(&model)?; + let request_model = configured_provider + .as_ref() + .map_or_else(|| model.clone(), |provider| provider.wire_model.clone()); plugin_registry.initialize()?; let policy = permission_policy(permission_mode, &feature_config, &tool_registry) .map_err(std::io::Error::other)?; @@ -7534,7 +7627,8 @@ fn build_runtime_with_plugin_state( session, AnthropicRuntimeClient::new( session_id, - model, + request_model, + configured_provider, enable_tools, emit_output, allowed_tools.clone(), @@ -7662,6 +7756,7 @@ impl AnthropicRuntimeClient { fn new( session_id: &str, model: String, + configured_provider: Option, enable_tools: bool, emit_output: bool, allowed_tools: Option, @@ -7688,26 +7783,30 @@ impl AnthropicRuntimeClient { // prompt cache is Anthropic-only so non-Anthropic variants // skip it. 
let resolved_model = api::resolve_model_alias(&model); - let client = match detect_provider_kind(&resolved_model) { - ProviderKind::Anthropic => { - let auth = resolve_cli_auth_source()?; - let inner = AnthropicClient::from_auth(auth) - .with_base_url(api::read_base_url()) - .with_prompt_cache(PromptCache::new(session_id)); - ApiProviderClient::Anthropic(inner) - } - ProviderKind::Xai | ProviderKind::OpenAi => { - // The api crate's `ProviderClient::from_model_with_anthropic_auth` - // with `None` for the anthropic auth routes via - // `detect_provider_kind` and builds an - // `OpenAiCompatClient::from_env` with the matching - // `OpenAiCompatConfig` (openai / xai / dashscope). - // That reads the correct API-key env var and BASE_URL - // override internally, so this one call covers OpenAI, - // OpenRouter, xAI, DashScope, Ollama, and any other - // OpenAI-compat endpoint users configure via - // `OPENAI_BASE_URL` / `XAI_BASE_URL` / `DASHSCOPE_BASE_URL`. - ApiProviderClient::from_model_with_anthropic_auth(&resolved_model, None)? + let client = if let Some(provider) = configured_provider { + ApiProviderClient::from_openai_compatible_profile(provider.api_key, provider.base_url) + } else { + match detect_provider_kind(&resolved_model) { + ProviderKind::Anthropic => { + let auth = resolve_cli_auth_source()?; + let inner = AnthropicClient::from_auth(auth) + .with_base_url(api::read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)); + ApiProviderClient::Anthropic(inner) + } + ProviderKind::Xai | ProviderKind::OpenAi => { + // The api crate's `ProviderClient::from_model_with_anthropic_auth` + // with `None` for the anthropic auth routes via + // `detect_provider_kind` and builds an + // `OpenAiCompatClient::from_env` with the matching + // `OpenAiCompatConfig` (openai / xai / dashscope). 
+ // That reads the correct API-key env var and BASE_URL + // override internally, so this one call covers OpenAI, + // OpenRouter, xAI, DashScope, Ollama, and any other + // OpenAI-compat endpoint users configure via + // `OPENAI_BASE_URL` / `XAI_BASE_URL` / `DASHSCOPE_BASE_URL`. + ApiProviderClient::from_model_with_anthropic_auth(&resolved_model, None)? + } } }; Ok(Self { @@ -9261,6 +9360,7 @@ mod tests { SlashCommand, StatusUsage, TmuxPaneSnapshot, DEFAULT_MODEL, LATEST_SESSION_REFERENCE, STUB_COMMANDS, }; + use crate::configured_provider_for_model; use api::{ApiError, MessageResponse, OutputContentBlock, Usage}; use plugins::{ PluginManager, PluginManagerConfig, PluginTool, PluginToolDefinition, PluginToolPermission, @@ -9919,6 +10019,80 @@ mod tests { assert_eq!(builtin, "claude-haiku-4-5-20251213"); } + #[test] + fn configured_model_provider_default_resolves_to_provider_model_ref() { + // given + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"modelProviders":{"zai":{"type":"openai-compatible","baseUrl":"https://api.z.ai/api/paas/v4","apiKeyEnv":"Z_AI_API_KEY","models":["glm-5.1","glm-4.6"],"defaultModel":"glm-5.1"}}}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + + // when + let resolved = with_current_dir(&cwd, || resolve_model_alias_with_config("zai")); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + // then + assert_eq!(resolved, 
"zai/glm-5.1"); + } + + #[test] + fn configured_model_provider_resolves_runtime_connection_details() { + // given + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"modelProviders":{"minimax":{"baseUrl":"https://api.minimax.io/v1","apiKeyEnv":"MINIMAX_API_KEY","models":["MiniMax-M2.7-highspeed"]}}}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_key = std::env::var("MINIMAX_API_KEY").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::set_var("MINIMAX_API_KEY", "test-minimax-key"); + + // when + let provider = with_current_dir(&cwd, || { + configured_provider_for_model("minimax/MiniMax-M2.7-highspeed") + }) + .expect("provider lookup should succeed") + .expect("provider should exist"); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_key { + Some(value) => std::env::set_var("MINIMAX_API_KEY", value), + None => std::env::remove_var("MINIMAX_API_KEY"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + // then + assert_eq!(provider.wire_model, "MiniMax-M2.7-highspeed"); + assert_eq!(provider.api_key, "test-minimax-key"); + assert_eq!(provider.base_url, "https://api.minimax.io/v1"); + } + #[test] fn parses_version_flags_without_initializing_prompt_mode() { assert_eq!( From 6790844505affe8abce54eac8792b10fbb685986 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Thu, 30 Apr 2026 02:20:07 -0300 Subject: [PATCH 02/15] Enable protocol-aware model provider login The provider selector now supports both 
OpenAI-compatible and Anthropic-compatible configured providers, so OpenCode-style MiniMax and Kimi coding endpoints can be selected through /model without pretending they use OpenAI chat completions. The /login wizard writes provider profiles with protocol, base URL, env var, models, and default model, using current OpenCode provider ids and model lists. Constraint: OpenCode reports minimax-coding-plan and kimi-for-coding as Anthropic-compatible endpoints Rejected: Put MiniMax and Kimi coding models under openai-compatible profiles | the selector would resolve but runtime calls would use the wrong wire protocol Confidence: high Scope-risk: moderate Directive: Keep login presets aligned with provider protocol, not only model names Tested: cargo fmt --check Tested: cargo check -p runtime -p api -p commands -p tools -p rusty-claude-cli Tested: cargo test -p rusty-claude-cli configured_model_provider -- --nocapture Tested: cargo test -p rusty-claude-cli login_subcommand_parses_and_logout_errors_helpfully -- --nocapture Tested: cargo run -p rusty-claude-cli -- --model minimax-coding-plan status --output-format json Tested: cargo run -p rusty-claude-cli -- --model kimi-for-coding status --output-format json Not-tested: live API requests to each external provider model --- USAGE.md | 114 ++++++- rust/crates/api/src/client.rs | 8 + rust/crates/runtime/src/config.rs | 5 +- rust/crates/rusty-claude-cli/src/main.rs | 405 ++++++++++++++++++++++- 4 files changed, 509 insertions(+), 23 deletions(-) diff --git a/USAGE.md b/USAGE.md index bca4cc5521..4d7d6e0c44 100644 --- a/USAGE.md +++ b/USAGE.md @@ -308,9 +308,11 @@ The OpenAI-compatible backend also serves as the gateway for **OpenRouter**, **O **Model-name prefix routing:** If a model name starts with `openai/`, `gpt-`, `qwen/`, or `qwen-`, the provider is selected by the prefix regardless of which env vars are set. This prevents accidental misrouting to Anthropic when multiple credentials exist in the environment. 
-### Configured OpenAI-compatible providers +### Configured compatible providers -If you use several OpenAI-compatible providers, define named provider profiles in `settings.json` instead of changing `OPENAI_BASE_URL` before every run. Each profile gets its own base URL, credential env var, and model allow-list: +If you use several compatible providers, define named provider profiles in `settings.json` instead of changing `OPENAI_BASE_URL` before every run. Each profile gets its own protocol, base URL, credential env var, and model allow-list. + +Run `claw login` or `/login` to create these profiles interactively for Z.AI, MiniMax, OpenAI, Kimi, Moonshot, or a custom compatible endpoint. The wizard can store a pasted local token in `~/.claw/settings.json`, but `apiKeyEnv` is preferred when you can keep the secret in your shell environment: ```json { @@ -320,22 +322,104 @@ If you use several OpenAI-compatible providers, define named provider profiles i "type": "openai-compatible", "baseUrl": "https://api.z.ai/api/paas/v4", "apiKeyEnv": "Z_AI_API_KEY", - "models": ["glm-5.1", "glm-4.6"], + "models": [ + "glm-5.1", + "glm-5", + "glm-5-turbo", + "glm-4.7", + "glm-4.7-flashx", + "glm-4.7-flash", + "glm-4.6", + "glm-4.5", + "glm-4.5-x", + "glm-4.5-air", + "glm-4.5-airx", + "glm-4.5-flash", + "glm-4-32b-0414-128k" + ], "defaultModel": "glm-5.1" }, - "minimax": { + "zai-coding-plan": { "type": "openai-compatible", - "baseUrl": "https://api.minimax.io/v1", + "baseUrl": "https://api.z.ai/api/coding/paas/v4", + "apiKeyEnv": "Z_AI_API_KEY", + "models": [ + "glm-4.5-air", + "glm-4.7", + "glm-5-turbo", + "glm-5.1", + "glm-5v-turbo" + ], + "defaultModel": "glm-5.1" + }, + "minimax-coding-plan": { + "type": "anthropic-compatible", + "baseUrl": "https://api.minimax.io/anthropic/v1", "apiKeyEnv": "MINIMAX_API_KEY", - "models": ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"], + "models": [ + "MiniMax-M2", + "MiniMax-M2.1", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.7", + 
"MiniMax-M2.7-highspeed" + ], "defaultModel": "MiniMax-M2.7-highspeed" }, + "openai": { + "type": "openai-compatible", + "baseUrl": "https://api.openai.com/v1", + "apiKeyEnv": "OPENAI_API_KEY", + "models": [ + "gpt-5-codex", + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.3-codex", + "gpt-5.3-codex-spark", + "gpt-5.4", + "gpt-5.4-fast", + "gpt-5.4-mini", + "gpt-5.4-mini-fast", + "gpt-5.5", + "gpt-5.5-fast", + "gpt-5.5-pro" + ], + "defaultModel": "gpt-5.5" + }, + "kimi-for-coding": { + "type": "anthropic-compatible", + "baseUrl": "https://api.kimi.com/coding/v1", + "apiKeyEnv": "KIMI_API_KEY", + "models": [ + "k2p5", + "k2p6", + "kimi-k2-thinking" + ], + "defaultModel": "k2p6" + }, "moonshot": { "type": "openai-compatible", "baseUrl": "https://api.moonshot.ai/v1", "apiKeyEnv": "MOONSHOT_API_KEY", - "models": ["kimi-k2.5"], - "defaultModel": "kimi-k2.5" + "models": [ + "kimi-k2.6", + "kimi-k2.5", + "kimi-k2-0905-preview", + "kimi-k2-0711-preview", + "kimi-k2-turbo-preview", + "kimi-k2-thinking", + "kimi-k2-thinking-turbo", + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "moonshot-v1-8k-vision-preview", + "moonshot-v1-32k-vision-preview", + "moonshot-v1-128k-vision-preview" + ], + "defaultModel": "kimi-k2.6" } } } @@ -345,14 +429,22 @@ Use `/model provider/model` in the REPL to switch without restarting: ```text /model zai/glm-5.1 -/model minimax/MiniMax-M2.7-highspeed -/model moonshot/kimi-k2.5 +/model zai-coding-plan/glm-5.1 +/model minimax-coding-plan/MiniMax-M2.7-highspeed +/model openai/gpt-5.5 +/model kimi-for-coding/k2p6 +/model moonshot/kimi-k2.6 ``` You can also use the provider name alone when it has `defaultModel` configured: ```text -/model minimax +/model zai +/model zai-coding-plan +/model minimax-coding-plan +/model openai +/model kimi-for-coding +/model moonshot ``` Prefer `apiKeyEnv` so secrets stay out of source-controlled project settings. 
`apiKey` is supported for local-only files when an environment variable is not practical. diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 89987b4a8d..c7863ff3cf 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -56,6 +56,14 @@ impl ProviderClient { ) } + #[must_use] + pub fn from_anthropic_compatible_profile( + api_key: impl Into, + base_url: impl Into, + ) -> Self { + Self::Anthropic(AnthropicClient::new(api_key).with_base_url(base_url)) + } + #[must_use] pub const fn provider_kind(&self) -> ProviderKind { match self { diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index e9777d2af0..46d871d54f 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -995,7 +995,10 @@ fn parse_optional_model_providers( let provider_type = optional_string(provider, "type", &context)? .unwrap_or("openai-compatible") .to_string(); - if !matches!(provider_type.as_str(), "openai-compatible" | "openai") { + if !matches!( + provider_type.as_str(), + "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic" + ) { return Err(ConfigError::Parse(format!( "{context}: unsupported provider type {provider_type}" ))); diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 55e8a91789..7b326da74f 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -99,6 +99,7 @@ struct ModelProvenance { #[derive(Debug, Clone)] struct ConfiguredModelProvider { wire_model: String, + provider_type: String, api_key: String, base_url: String, } @@ -371,6 +372,11 @@ fn run() -> Result<(), Box> { output_format, } => print_system_prompt(cwd, date, output_format)?, CliAction::Version { output_format } => print_version(output_format)?, + CliAction::Login => { + if let Some(model) = run_login_wizard()? { + println!("Configured provider. 
Use `claw --model {model}` or `/model {model}`."); + } + } CliAction::ResumeSession { session_path, commands, @@ -511,6 +517,7 @@ enum CliAction { Version { output_format: CliOutputFormat, }, + Login, ResumeSession { session_path: PathBuf, commands: Vec, @@ -955,7 +962,8 @@ fn parse_args(args: &[String]) -> Result { } "system-prompt" => parse_system_prompt_args(&rest[1..], output_format), "acp" => parse_acp_args(&rest[1..], output_format), - "login" | "logout" => Err(removed_auth_surface_error(rest[0].as_str())), + "login" => Ok(CliAction::Login), + "logout" => Err(removed_auth_surface_error(rest[0].as_str())), "init" => Ok(CliAction::Init { output_format }), "export" => parse_export_args(&rest[1..], output_format), "prompt" => { @@ -1492,7 +1500,10 @@ fn configured_provider_for_model( let Some(provider) = config.model_providers().get(provider_name) else { return Ok(None); }; - if !matches!(provider.provider_type(), "openai-compatible" | "openai") { + if !matches!( + provider.provider_type(), + "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic" + ) { return Err(format!( "model provider '{provider_name}' uses unsupported type '{}'", provider.provider_type() @@ -1524,11 +1535,355 @@ fn configured_provider_for_model( }; Ok(Some(ConfiguredModelProvider { wire_model: wire_model.to_string(), + provider_type: provider.provider_type().to_string(), api_key, base_url: provider.base_url().to_string(), })) } +struct LoginProviderTemplate { + id: &'static str, + label: &'static str, + provider_type: &'static str, + base_url: &'static str, + api_key_env: &'static str, + models: &'static [&'static str], + default_model: &'static str, +} + +const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ + LoginProviderTemplate { + id: "zai", + label: "Z.AI", + provider_type: "openai-compatible", + base_url: "https://api.z.ai/api/paas/v4", + api_key_env: "Z_AI_API_KEY", + models: &[ + "glm-5.1", + "glm-5", + "glm-5-turbo", + "glm-4.7", + "glm-4.7-flashx", + 
"glm-4.7-flash", + "glm-4.6", + "glm-4.5", + "glm-4.5-x", + "glm-4.5-air", + "glm-4.5-airx", + "glm-4.5-flash", + "glm-4-32b-0414-128k", + ], + default_model: "glm-5.1", + }, + LoginProviderTemplate { + id: "zai-coding-plan", + label: "Z.AI Coding Plan", + provider_type: "openai-compatible", + base_url: "https://api.z.ai/api/coding/paas/v4", + api_key_env: "Z_AI_API_KEY", + models: &[ + "glm-4.5-air", + "glm-4.7", + "glm-5-turbo", + "glm-5.1", + "glm-5v-turbo", + ], + default_model: "glm-5.1", + }, + LoginProviderTemplate { + id: "minimax-coding-plan", + label: "MiniMax Coding Plan", + provider_type: "anthropic-compatible", + base_url: "https://api.minimax.io/anthropic/v1", + api_key_env: "MINIMAX_API_KEY", + models: &[ + "MiniMax-M2", + "MiniMax-M2.1", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + ], + default_model: "MiniMax-M2.7-highspeed", + }, + LoginProviderTemplate { + id: "openai", + label: "OpenAI", + provider_type: "openai-compatible", + base_url: "https://api.openai.com/v1", + api_key_env: "OPENAI_API_KEY", + models: &[ + "gpt-5-codex", + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.3-codex", + "gpt-5.3-codex-spark", + "gpt-5.4", + "gpt-5.4-fast", + "gpt-5.4-mini", + "gpt-5.4-mini-fast", + "gpt-5.5", + "gpt-5.5-fast", + "gpt-5.5-pro", + ], + default_model: "gpt-5.5", + }, + LoginProviderTemplate { + id: "kimi-for-coding", + label: "Kimi For Coding", + provider_type: "anthropic-compatible", + base_url: "https://api.kimi.com/coding/v1", + api_key_env: "KIMI_API_KEY", + models: &["k2p5", "k2p6", "kimi-k2-thinking"], + default_model: "k2p6", + }, + LoginProviderTemplate { + id: "moonshot", + label: "Moonshot / Kimi", + provider_type: "openai-compatible", + base_url: "https://api.moonshot.ai/v1", + api_key_env: "MOONSHOT_API_KEY", + models: &[ + "kimi-k2.6", + "kimi-k2.5", + "kimi-k2-0905-preview", + "kimi-k2-0711-preview", + "kimi-k2-turbo-preview", + 
"kimi-k2-thinking", + "kimi-k2-thinking-turbo", + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "moonshot-v1-8k-vision-preview", + "moonshot-v1-32k-vision-preview", + "moonshot-v1-128k-vision-preview", + ], + default_model: "kimi-k2.6", + }, +]; + +fn run_login_wizard() -> Result, Box> { + if !io::stdin().is_terminal() { + return Err("login requires an interactive terminal".into()); + } + + println!(); + println!("Claw provider login"); + println!("Configure a model provider profile."); + println!("Press Enter to accept defaults."); + println!(); + + for (index, provider) in LOGIN_PROVIDER_TEMPLATES.iter().enumerate() { + println!(" [{}] {}", index + 1, provider.label); + } + println!( + " [{}] Custom compatible endpoint", + LOGIN_PROVIDER_TEMPLATES.len() + 1 + ); + + let choice = read_prompt("Select provider [1]: ")?; + let choice = if choice.trim().is_empty() { + 1 + } else { + choice.trim().parse::()? + }; + + let ( + provider_id, + label, + provider_type, + default_base_url, + default_api_key_env, + default_models, + default_model, + ) = if choice == LOGIN_PROVIDER_TEMPLATES.len() + 1 { + let id = read_required_prompt("Provider id (e.g. 
openrouter): ")?; + let provider_type = read_prompt( + "Provider type [openai-compatible, anthropic-compatible] [openai-compatible]: ", + )?; + let provider_type = if provider_type.trim().is_empty() { + "openai-compatible".to_string() + } else { + provider_type.trim().to_string() + }; + if !matches!( + provider_type.as_str(), + "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic" + ) { + return Err(format!("unsupported provider type: {provider_type}").into()); + } + let base_url = read_required_prompt("Base URL: ")?; + let api_key_env = read_prompt("API key env var [OPENAI_API_KEY]: ")?; + let model = read_required_prompt("Default model: ")?; + ( + id, + "Custom".to_string(), + provider_type, + base_url, + if api_key_env.trim().is_empty() { + "OPENAI_API_KEY".to_string() + } else { + api_key_env.trim().to_string() + }, + vec![model.clone()], + model, + ) + } else { + let template = LOGIN_PROVIDER_TEMPLATES + .get(choice.saturating_sub(1)) + .ok_or_else(|| format!("invalid provider choice: {choice}"))?; + ( + template.id.to_string(), + template.label.to_string(), + template.provider_type.to_string(), + template.base_url.to_string(), + template.api_key_env.to_string(), + template + .models + .iter() + .map(|model| (*model).to_string()) + .collect::>(), + template.default_model.to_string(), + ) + }; + + println!(); + println!("{label}"); + println!("Provider type: {provider_type}"); + let base_url = read_prompt(&format!("Base URL [{default_base_url}]: "))?; + let base_url = if base_url.trim().is_empty() { + default_base_url + } else { + base_url.trim().to_string() + }; + let api_key_env = read_prompt(&format!("API key env var [{default_api_key_env}]: "))?; + let api_key_env = if api_key_env.trim().is_empty() { + default_api_key_env + } else { + api_key_env.trim().to_string() + }; + + let token = read_prompt("Paste API key / bearer token now, or press Enter to use env var: ")?; + let api_key = (!token.trim().is_empty()).then(|| 
token.trim().to_string()); + + println!("Available models: {}", default_models.join(", ")); + let model = read_prompt(&format!("Default model [{default_model}]: "))?; + let model = if model.trim().is_empty() { + default_model + } else { + model.trim().to_string() + }; + let mut models = default_models; + if !models.iter().any(|known| known == &model) { + models.push(model.clone()); + } + + save_model_provider_profile( + &provider_id, + &provider_type, + &base_url, + &api_key_env, + api_key.as_deref(), + &models, + &model, + )?; + Ok(Some(format!("{provider_id}/{model}"))) +} + +fn read_prompt(prompt: &str) -> Result> { + print!("{prompt}"); + io::stdout().flush()?; + let mut buffer = String::new(); + io::stdin().read_line(&mut buffer)?; + Ok(buffer) +} + +fn read_required_prompt(prompt: &str) -> Result> { + let value = read_prompt(prompt)?; + let value = value.trim(); + if value.is_empty() { + return Err(format!("{prompt} is required").into()); + } + Ok(value.to_string()) +} + +fn save_model_provider_profile( + provider_id: &str, + provider_type: &str, + base_url: &str, + api_key_env: &str, + api_key: Option<&str>, + models: &[String], + default_model: &str, +) -> Result<(), Box> { + let cwd = env::current_dir()?; + let config_home = ConfigLoader::default_for(&cwd).config_home().to_path_buf(); + fs::create_dir_all(&config_home)?; + let settings_path = config_home.join("settings.json"); + let mut root = match fs::read_to_string(&settings_path) { + Ok(contents) if !contents.trim().is_empty() => serde_json::from_str::(&contents)?, + Ok(_) => Value::Object(Map::new()), + Err(error) if error.kind() == io::ErrorKind::NotFound => Value::Object(Map::new()), + Err(error) => return Err(error.into()), + }; + if !root.is_object() { + root = Value::Object(Map::new()); + } + let root_object = root.as_object_mut().expect("root object initialized"); + let providers = root_object + .entry("modelProviders") + .or_insert_with(|| Value::Object(Map::new())); + if !providers.is_object() 
{ + *providers = Value::Object(Map::new()); + } + let provider_map = providers + .as_object_mut() + .expect("modelProviders object initialized"); + + let mut provider = Map::new(); + provider.insert("type".to_string(), Value::String(provider_type.to_string())); + provider.insert("baseUrl".to_string(), Value::String(base_url.to_string())); + provider.insert( + "apiKeyEnv".to_string(), + Value::String(api_key_env.to_string()), + ); + if let Some(api_key) = api_key { + provider.insert("apiKey".to_string(), Value::String(api_key.to_string())); + } + provider.insert( + "models".to_string(), + Value::Array( + models + .iter() + .map(|model| Value::String(model.clone())) + .collect(), + ), + ); + provider.insert( + "defaultModel".to_string(), + Value::String(default_model.to_string()), + ); + provider_map.insert(provider_id.to_string(), Value::Object(provider)); + root_object.insert( + "model".to_string(), + Value::String(format!("{provider_id}/{default_model}")), + ); + + let serialized = format!("{}\n", serde_json::to_string_pretty(&root)?); + fs::write(&settings_path, serialized)?; + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut permissions = fs::metadata(&settings_path)?.permissions(); + permissions.set_mode(0o600); + fs::set_permissions(&settings_path, permissions)?; + } + Ok(()) +} + /// Validate model syntax at parse time. /// Accepts: known aliases (opus, sonnet, haiku) or provider/model pattern. /// Rejects: empty, whitespace-only, strings with spaces, or invalid chars. 
@@ -2206,7 +2561,7 @@ fn check_auth_health() -> DiagnosticCheck { token_set.scopes.join(",") } ), - "Suggested action set ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN; `claw login` is removed" + "Suggested action set ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN for Anthropic, or run `claw login` to configure a compatible provider" .to_string(), ]) .with_data(Map::from_iter([ @@ -4802,8 +5157,13 @@ impl LiveCli { println!("{}", format_cost_report(usage)); false } - SlashCommand::Login - | SlashCommand::Logout + SlashCommand::Login => { + if let Some(model) = run_login_wizard()? { + self.set_model(Some(model))?; + } + false + } + SlashCommand::Logout | SlashCommand::Vim | SlashCommand::Upgrade | SlashCommand::Share @@ -7784,7 +8144,24 @@ impl AnthropicRuntimeClient { // skip it. let resolved_model = api::resolve_model_alias(&model); let client = if let Some(provider) = configured_provider { - ApiProviderClient::from_openai_compatible_profile(provider.api_key, provider.base_url) + match provider.provider_type.as_str() { + "anthropic-compatible" | "anthropic" => { + ApiProviderClient::from_anthropic_compatible_profile( + provider.api_key, + provider.base_url, + ) + .with_prompt_cache(PromptCache::new(session_id)) + } + "openai-compatible" | "openai" => { + ApiProviderClient::from_openai_compatible_profile( + provider.api_key, + provider.base_url, + ) + } + other => { + return Err(format!("unsupported provider type: {other}").into()); + } + } } else { match detect_provider_kind(&resolved_model) { ProviderKind::Anthropic => { @@ -8233,7 +8610,6 @@ fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec io::Result<()> { out, " Diagnose local auth, config, workspace, and sandbox health" )?; + writeln!(out, " claw login")?; + writeln!( + out, + " Configure a compatible model provider in settings.json" + )?; writeln!(out, " claw acp [serve]")?; writeln!( out, @@ -10242,9 +10623,11 @@ mod tests { } #[test] - fn removed_login_and_logout_subcommands_error_helpfully() { 
- let login = parse_args(&["login".to_string()]).expect_err("login should be removed"); - assert!(login.contains("ANTHROPIC_API_KEY")); + fn login_subcommand_parses_and_logout_errors_helpfully() { + assert_eq!( + parse_args(&["login".to_string()]).expect("login should parse"), + CliAction::Login + ); let logout = parse_args(&["logout".to_string()]).expect_err("logout should be removed"); assert!(logout.contains("ANTHROPIC_AUTH_TOKEN")); assert_eq!( @@ -11964,7 +12347,7 @@ mod tests { assert!(help.contains("claw /skills")); assert!(help.contains("ultraworkers/claw-code")); assert!(help.contains("cargo install claw-code")); - assert!(!help.contains("claw login")); + assert!(help.contains("claw login")); assert!(!help.contains("claw logout")); } From 863bed1cad49cfea2518650abf89dd62b4906f31 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 01:23:57 -0300 Subject: [PATCH 03/15] feat(auth): add welcome screen and claw auth command for provider setup When opening claw without a configured provider API key, users now see an interactive welcome screen instead of a hard error. They can select a provider, enter their API key, and continue using claw immediately. 
- Automatically shown at REPL startup or prompt mode when auth is missing - Lists built-in providers (Anthropic, OpenAI, xAI) - Sets the env var for the current process so claw works immediately - Optionally saves the model choice to ~/.claw/settings.json - Supports cancellation (press Enter without entering a key) - Authenticate with a specific provider directly: `claw auth openai` - Without a provider argument, shows the interactive picker - Sets env vars for the current session - Added `BuiltinProvider` struct and constants for built-in providers - Added `check_model_auth_available()` to detect missing credentials - Added `run_provider_welcome()` for the interactive onboarding flow - Added `run_auth_command()` for the CLI subcommand - Hooked welcome screen into `run_repl()` and `Prompt` mode dispatch - Added `CliAction::Auth` variant and `parse_args` support - Updated help text and typo-suggestion list - Added unit tests for auth subcommand parsing - `cargo check --workspace` passes - New unit tests: `parse_args_auth_without_provider`, `parse_args_auth_with_provider` --- rust/crates/api/src/lib.rs | 4 +- rust/crates/rusty-claude-cli/src/main.rs | 255 ++++++++++++++++++++++- 2 files changed, 252 insertions(+), 7 deletions(-) diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 40da29f140..0f04905e7a 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -18,9 +18,9 @@ pub use prompt_cache::{ CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, PromptCacheStats, }; -pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource}; +pub use providers::anthropic::{has_auth_from_env_or_saved as anthropic_has_auth, AnthropicClient, AnthropicClient as ApiClient, AuthSource}; pub use providers::openai_compat::{ - build_chat_completion_request, flatten_tool_result_content, is_reasoning_model, + build_chat_completion_request, flatten_tool_result_content, has_api_key, 
is_reasoning_model, model_rejects_is_error_field, translate_message, OpenAiCompatClient, OpenAiCompatConfig, }; pub use providers::{ diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 7b326da74f..6e2e2edf9c 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -24,10 +24,11 @@ use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant, UNIX_EPOCH}; use api::{ - detect_provider_kind, resolve_startup_auth_source, AnthropicClient, AuthSource, - ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, MessageResponse, - OutputContentBlock, PromptCache, ProviderClient as ApiProviderClient, ProviderKind, - StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, + anthropic_has_auth, detect_provider_kind, has_api_key, resolve_startup_auth_source, + AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, PromptCache, + ProviderClient as ApiProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, + ToolDefinition, ToolResultContentBlock, }; use commands::{ @@ -420,7 +421,16 @@ fn run() -> Result<(), Box> { None }; let effective_prompt = merge_prompt_with_stdin(&prompt, stdin_context.as_deref()); - let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; + let resolved_model = resolve_model_alias_with_config(&model); + let final_model = if !check_model_auth_available(&resolved_model)? { + match run_provider_welcome(&resolved_model)? 
{ + Some(new_model) => new_model, + None => return Ok(()), + } + } else { + resolved_model + }; + let mut cli = LiveCli::new(final_model, true, allowed_tools, permission_mode)?; cli.set_reasoning_effort(reasoning_effort); cli.run_turn_with_output(&effective_prompt, output_format, compact)?; } @@ -477,6 +487,7 @@ fn run() -> Result<(), Box> { reasoning_effort, allow_broad_cwd, )?, + CliAction::Auth { provider } => run_auth_command(provider.as_deref())?, CliAction::HelpTopic(topic) => print_help_topic(topic), CliAction::Help { output_format } => print_help(output_format)?, } @@ -581,6 +592,9 @@ enum CliAction { reasoning_effort: Option, allow_broad_cwd: bool, }, + Auth { + provider: Option, + }, HelpTopic(LocalHelpTopic), // prompt-mode formatting is only supported for non-interactive runs Help { @@ -964,6 +978,20 @@ fn parse_args(args: &[String]) -> Result { "acp" => parse_acp_args(&rest[1..], output_format), "login" => Ok(CliAction::Login), "logout" => Err(removed_auth_surface_error(rest[0].as_str())), + "auth" => { + if rest.len() == 2 && is_help_flag(&rest[1]) { + return Ok(CliAction::Help { output_format }); + } + let provider = rest.get(1).cloned(); + if rest.len() > 2 { + return Err(format!( + "unexpected extra arguments after `claw auth {}`: {}", + provider.as_deref().unwrap_or(""), + rest[2..].join(" ") + )); + } + Ok(CliAction::Auth { provider }) + } "init" => Ok(CliAction::Init { output_format }), "export" => parse_export_args(&rest[1..], output_format), "prompt" => { @@ -1354,6 +1382,7 @@ fn suggest_similar_subcommand(input: &str) -> Option> { "init", "export", "prompt", + "auth", ]; let normalized_input = input.to_ascii_lowercase(); @@ -4216,6 +4245,21 @@ fn run_repl( enforce_broad_cwd_policy(allow_broad_cwd, CliOutputFormat::Text)?; run_stale_base_preflight(base_commit.as_deref()); let resolved_model = resolve_repl_model(model); + if !check_model_auth_available(&resolved_model)? { + match run_provider_welcome(&resolved_model)? 
{ + Some(new_model) => { + return run_repl( + new_model, + allowed_tools, + permission_mode, + base_commit, + reasoning_effort, + allow_broad_cwd, + ); + } + None => return Ok(()), + } + } let mut cli = LiveCli::new(resolved_model, true, allowed_tools, permission_mode)?; cli.set_reasoning_effort(reasoning_effort); let mut editor = @@ -8213,6 +8257,154 @@ fn resolve_cli_auth_source_for_cwd() -> Result { resolve_startup_auth_source(|| Ok(None)) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct BuiltinProvider { + id: &'static str, + label: &'static str, + env_var: &'static str, + default_model: &'static str, +} + +const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ + BuiltinProvider { + id: "anthropic", + label: "Anthropic (Claude)", + env_var: "ANTHROPIC_API_KEY", + default_model: "claude-opus-4-6", + }, + BuiltinProvider { + id: "openai", + label: "OpenAI", + env_var: "OPENAI_API_KEY", + default_model: "gpt-4o", + }, + BuiltinProvider { + id: "xai", + label: "xAI (Grok)", + env_var: "XAI_API_KEY", + default_model: "grok-3", + }, +]; + +fn check_model_auth_available(model: &str) -> Result> { + let resolved = api::resolve_model_alias(model); + let provider = detect_provider_kind(&resolved); + let available = match provider { + ProviderKind::Anthropic => anthropic_has_auth().unwrap_or(false), + ProviderKind::Xai => has_api_key("XAI_API_KEY"), + ProviderKind::OpenAi => has_api_key("OPENAI_API_KEY") || has_api_key("DASHSCOPE_API_KEY"), + }; + Ok(available) +} + +fn run_provider_welcome( + default_model: &str, +) -> Result, Box> { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err( + "Authentication required. 
Set one of these environment variables:\n\ + \n\ + ANTHROPIC_API_KEY= # For Anthropic (Claude)\n\ + OPENAI_API_KEY= # For OpenAI\n\ + XAI_API_KEY= # For xAI (Grok)\n\ + DASHSCOPE_API_KEY= # For DashScope/Qwen" + .into(), + ); + } + + println!( + "\n\x1b[1mWelcome to Claw Code\x1b[0m\n\n\ + No API key detected for the selected model.\n\ + Choose a provider to authenticate with:\n" + ); + + for (i, provider) in BUILTIN_PROVIDERS.iter().enumerate() { + println!(" {}. {} ({})", i + 1, provider.label, provider.env_var); + } + + print!("\nEnter number (1-{}): ", BUILTIN_PROVIDERS.len()); + std::io::stdout().flush()?; + let mut choice = String::new(); + std::io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + let index: usize = choice.parse().map_err(|_| "invalid selection")?; + let provider = BUILTIN_PROVIDERS + .get(index.wrapping_sub(1)) + .ok_or("invalid selection")?; + + print!("Enter {} (or press Enter to cancel): ", provider.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + + std::env::set_var(provider.env_var, key); + + // Optionally save to ~/.claw/settings.json as a simple model config. 
+ if let Some(home) = std::env::var_os("HOME") { + let claw_dir = std::path::PathBuf::from(home).join(".claw"); + if claw_dir.exists() || std::fs::create_dir_all(&claw_dir).is_ok() { + let settings_path = claw_dir.join("settings.json"); + let model_str = format!("{}/{}", provider.id, provider.default_model); + let content = format!("{{\"model\": \"{model_str}\"}}"); + let _ = std::fs::write(&settings_path, content); + } + } + + Ok(Some(format!("{}/{}", provider.id, provider.default_model))) +} + +fn run_auth_command(provider: Option<&str>) -> Result<(), Box> { + if let Some(provider_id) = provider { + let builtin = BUILTIN_PROVIDERS + .iter() + .find(|p| p.id == provider_id) + .ok_or_else(|| format!("unknown provider: {provider_id}"))?; + + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + builtin.env_var + ) + .into()); + } + + print!("Enter {} (or press Enter to cancel): ", builtin.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(builtin.env_var, key); + println!("Authentication set for {}.", builtin.label); + Ok(()) + } else { + match run_provider_welcome(DEFAULT_MODEL)? { + Some(model) => { + println!("Authentication configured. 
Model set to {model}."); + Ok(()) + } + None => { + println!("Cancelled."); + Ok(()) + } + } + } +} + impl ApiClient for AnthropicRuntimeClient { #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { @@ -9601,6 +9793,11 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { writeln!(out, " claw skills")?; writeln!(out, " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]")?; writeln!(out, " claw init")?; + writeln!(out, " claw auth [PROVIDER]")?; + writeln!( + out, + " Authenticate with a model provider (anthropic, openai, xai)" + )?; writeln!( out, " claw export [PATH] [--session SESSION] [--output PATH]" @@ -10188,6 +10385,28 @@ mod tests { ); } + #[test] + fn parse_args_auth_without_provider() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["auth".to_string()]).expect("args should parse"), + CliAction::Auth { provider: None } + ); + } + + #[test] + fn parse_args_auth_with_provider() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["auth".to_string(), "openai".to_string()]).expect("args should parse"), + CliAction::Auth { + provider: Some("openai".to_string()), + } + ); + } + #[test] fn merge_prompt_with_stdin_returns_prompt_unchanged_when_no_pipe() { // given @@ -14176,3 +14395,29 @@ mod dump_manifests_tests { let _ = fs::remove_dir_all(&root); } } + + +#[cfg(test)] +mod auth_tests { + use super::{parse_args, CliAction}; + + #[test] + fn parse_args_auth_without_provider() { + let args = vec!["auth".to_string()]; + let action = parse_args(&args).expect("should parse"); + match action { + CliAction::Auth { provider } => assert!(provider.is_none()), + other => panic!("expected CliAction::Auth, got {other:?}"), + } + } + + #[test] + fn parse_args_auth_with_provider() { + let args = vec!["auth".to_string(), "openai".to_string()]; + let action = parse_args(&args).expect("should 
parse"); + match action { + CliAction::Auth { provider } => assert_eq!(provider, Some("openai".to_string())), + other => panic!("expected CliAction::Auth, got {other:?}"), + } + } +} From d09b6e14e22aeb6612847f5ce4b0eb9e0ab77016 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 13:20:07 -0300 Subject: [PATCH 04/15] feat(oauth): add OAuth authentication support for OpenAI and Moonshot/Kimi Adds browser-based OAuth flows on top of the auth/provider infrastructure: - OpenAI: PKCE flow via auth.openai.com (ChatGPT/Codex accounts) - Moonshot / Kimi: Device Authorization Flow (RFC 8628) **New infrastructure:** - Per-provider OAuth token storage in ~/.claw/credentials.json - Local HTTP callback server for PKCE redirect handling - Browser launcher (open/xdg-open/start) - Device Authorization Flow polling **API client integration:** - OpenAiCompatClient falls back to saved OAuth tokens when env var unset - Bearer token authentication for OAuth providers **CLI integration:** - Welcome screen shows [OAuth] tag for supported providers - OAuth offered as recommended auth method when available - claw auth prompts to choose OAuth or API key --- rust/Cargo.lock | 2 + rust/crates/api/src/client.rs | 9 +- .../crates/api/src/providers/openai_compat.rs | 25 + rust/crates/runtime/Cargo.toml | 1 + rust/crates/runtime/src/lib.rs | 12 +- rust/crates/runtime/src/oauth.rs | 337 +++++++++++ rust/crates/rusty-claude-cli/Cargo.toml | 1 + rust/crates/rusty-claude-cli/src/main.rs | 554 ++++++++++++++++-- 8 files changed, 884 insertions(+), 57 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 740147e78e..78b9feeff4 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1364,6 +1364,7 @@ dependencies = [ "glob", "plugins", "regex", + "reqwest", "serde", "serde_json", "sha2", @@ -1456,6 +1457,7 @@ dependencies = [ "mock-anthropic-service", "plugins", "pulldown-cmark", + "reqwest", "runtime", "rustyline", "serde", diff --git a/rust/crates/api/src/client.rs 
b/rust/crates/api/src/client.rs index c7863ff3cf..c96165d4aa 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -41,7 +41,14 @@ impl ProviderClient { } _ => OpenAiCompatConfig::openai(), }; - Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + // Try OAuth for OpenAI if env var is not set + if config.provider_name == "OpenAI" { + Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( + config, "openai", + )?)) + } else { + Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + } } } } diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index a810502e66..42eaf7fb9e 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -139,6 +139,31 @@ impl OpenAiCompatClient { Ok(Self::new(api_key, config)) } + /// Create a client using an OAuth access token instead of an API key. + /// The token is sent as `Authorization: Bearer {token}`. + #[must_use] + pub fn from_oauth_token(token: impl Into, config: OpenAiCompatConfig) -> Self { + Self::new(token, config) + } + + /// Try env var first, then fall back to saved OAuth token for the provider. + /// `provider_id` is the key used in `~/.claw/credentials.json` under `oauth_providers`. + pub fn from_env_or_oauth( + config: OpenAiCompatConfig, + provider_id: &str, + ) -> Result { + if let Some(api_key) = read_env_non_empty(config.api_key_env)? 
{ + return Ok(Self::new(api_key, config)); + } + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + return Ok(Self::from_oauth_token(token_set.access_token, config)); + } + Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )) + } + #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index b1bd04f374..382e38fc29 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -10,6 +10,7 @@ sha2 = "0.10" glob = "0.3" plugins = { path = "../plugins" } regex = "1" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } serde = { version = "1", features = ["derive"] } serde_json.workspace = true telemetry = { path = "../telemetry" } diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index c7d87091fa..365e7ad6c8 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -112,11 +112,13 @@ pub use mcp_stdio::{ UnsupportedMcpServer, }; pub use oauth::{ - clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, - generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, - parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, - OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, - PkceChallengeMethod, PkceCodePair, + clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, + generate_pkce_pair, generate_state, load_oauth_credentials, load_provider_oauth, + loopback_redirect_uri, open_browser, parse_oauth_callback_query, + parse_oauth_callback_request_target, poll_device_token, run_oauth_callback_server, + save_oauth_credentials, save_provider_oauth, DeviceAuthRequest, DeviceAuthResponse, + OAuthAuthorizationRequest, 
OAuthCallbackParams, OAuthCallbackResult, OAuthRefreshRequest, + OAuthTokenExchangeRequest, OAuthTokenSet, PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index aa3ca158c7..87572a44c8 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -298,6 +298,343 @@ pub fn clear_oauth_credentials() -> io::Result<()> { write_credentials_root(&path, &root) } +// --------------------------------------------------------------------------- +// Per-provider OAuth token storage +// --------------------------------------------------------------------------- + +/// Load OAuth credentials for a specific provider from `credentials.json`. +/// Credentials are stored under the `oauth_providers.{provider_id}` key. +pub fn load_provider_oauth(provider_id: &str) -> io::Result> { + let path = credentials_path()?; + let root = read_credentials_root(&path)?; + let Some(oauth_providers) = root.get("oauth_providers") else { + return Ok(None); + }; + let Some(provider_value) = oauth_providers.get(provider_id) else { + return Ok(None); + }; + if provider_value.is_null() { + return Ok(None); + } + let stored = serde_json::from_value::(provider_value.clone()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + Ok(Some(stored.into())) +} + +/// Save OAuth credentials for a specific provider to `credentials.json`. +/// Preserves other providers and the legacy `"oauth"` key. 
+pub fn save_provider_oauth(provider_id: &str, token_set: &OAuthTokenSet) -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + let oauth_providers = root + .entry("oauth_providers") + .or_insert_with(|| Value::Object(Map::new())); + let provider_map = oauth_providers + .as_object_mut() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "oauth_providers must be an object"))?; + provider_map.insert( + provider_id.to_string(), + serde_json::to_value(StoredOAuthCredentials::from(token_set.clone())) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, + ); + write_credentials_root(&path, &root) +} + +/// Clear OAuth credentials for a specific provider. +pub fn clear_provider_oauth(provider_id: &str) -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + let Some(oauth_providers) = root.get_mut("oauth_providers") else { + return Ok(()); + }; + let Some(provider_map) = oauth_providers.as_object_mut() else { + return Ok(()); + }; + provider_map.remove(provider_id); + if provider_map.is_empty() { + root.remove("oauth_providers"); + } + write_credentials_root(&path, &root) +} + +// --------------------------------------------------------------------------- +// Browser launcher +// --------------------------------------------------------------------------- + +/// Open a URL in the user's default browser. +/// Falls back to printing the URL if the platform command fails. 
/// Open a URL in the user's default browser.
/// Falls back to printing the URL if the platform command fails.
pub fn open_browser(url: &str) -> io::Result<()> {
    // Select the per-platform launcher; unknown platforms go straight to the
    // print-the-URL fallback.
    let (cmd, args): (&str, Vec<&str>) = if cfg!(target_os = "macos") {
        ("open", vec![url])
    } else if cfg!(target_os = "linux") {
        ("xdg-open", vec![url])
    } else if cfg!(target_os = "windows") {
        // `start` is a cmd builtin; the empty string fills its window-title slot
        // so the URL is not mistaken for a title.
        ("cmd", vec!["/C", "start", "", url])
    } else {
        eprintln!("Please open this URL in your browser:");
        eprintln!("  {url}");
        return Ok(());
    };

    let launched = std::process::Command::new(cmd)
        .args(&args)
        .output()
        .map(|output| output.status.success())
        .unwrap_or(false);

    // A failed launch is non-fatal: print the URL and let the user open it.
    if !launched {
        eprintln!("Could not open browser automatically. Please open this URL:");
        eprintln!("  {url}");
    }
    Ok(())
}

// ---------------------------------------------------------------------------
// Local HTTP callback server (blocking, single-request)
// ---------------------------------------------------------------------------

/// Result of a successful OAuth callback: the authorization `code` plus the
/// anti-CSRF `state` value echoed back by the authorization server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackResult {
    pub code: String,
    pub state: String,
}
+pub fn run_oauth_callback_server( + port: u16, + timeout: std::time::Duration, +) -> io::Result { + use std::io::{BufRead, BufReader, Write}; + use std::net::{SocketAddr, TcpListener}; + + let addr: SocketAddr = format!("127.0.0.1:{port}").parse().map_err(|e| { + io::Error::new(io::ErrorKind::InvalidInput, format!("invalid address: {e}")) + })?; + let listener = TcpListener::bind(addr)?; + listener.set_nonblocking(false)?; + + let start = std::time::Instant::now(); + + loop { + if start.elapsed() > timeout { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "OAuth callback timed out waiting for browser redirect", + )); + } + + let (mut stream, _) = match listener.accept() { + Ok(conn) => conn, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + Err(e) => return Err(e), + }; + + let mut reader = BufReader::new(&mut stream); + let mut first_line = String::new(); + reader.read_line(&mut first_line)?; + + // Parse "GET /callback?code=...&state=... HTTP/1.1" + let target = first_line + .split_whitespace() + .nth(1) + .unwrap_or(""); + + // Consume remaining headers so browser doesn't reset connection + loop { + let mut line = String::new(); + if reader.read_line(&mut line)? == 0 { + break; + } + if line == "\r\n" || line == "\n" { + break; + } + } + + match parse_oauth_callback_request_target(target) { + Ok(params) => { + if let (Some(code), Some(state)) = (¶ms.code, ¶ms.state) { + // Success page + let body = r#" +Authentication Successful + +

✅ Authentication Successful

+

You can close this tab and return to the terminal.

+"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Ok(OAuthCallbackResult { + code: code.clone(), + state: state.clone(), + }); + } + if let Some(error) = ¶ms.error { + let err_desc = params.error_description.as_deref().unwrap_or(error); + let body = format!( + r#" +Authentication Failed + +

❌ Authentication Failed

+

{}

+

You can close this tab and return to the terminal.

+"#, + html_escape(err_desc) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Err(io::Error::new( + io::ErrorKind::Other, + format!("OAuth error: {error} - {err_desc}"), + )); + } + } + Err(e) => { + let body = format!( + r#" +Authentication Failed + +

❌ Invalid Callback

+

{}

+"#, + html_escape(&e) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Err(io::Error::new(io::ErrorKind::Other, e)); + } + } + } +} + +#[must_use] +fn html_escape(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) +} + +// --------------------------------------------------------------------------- +// Device Authorization Flow (RFC 8628) +// --------------------------------------------------------------------------- + +/// Request body for starting a device authorization flow. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DeviceAuthRequest { + pub client_id: String, + pub scope: String, +} + +/// Response from a device authorization endpoint. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct DeviceAuthResponse { + pub device_code: String, + pub user_code: String, + pub verification_uri: String, + #[serde(default)] + pub verification_uri_complete: Option, + pub expires_in: u64, + pub interval: u64, +} + +/// Poll a device token endpoint until the user authorizes or the flow expires. +/// Returns `Ok(None)` if the user hasn't authorized yet but we should keep polling. +/// Returns `Ok(Some(token_set))` on success. +/// Returns `Err` on fatal errors. 
+pub async fn poll_device_token( + client: &reqwest::Client, + device_code: &str, + client_id: &str, + token_url: &str, +) -> io::Result> { + let params = [ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("device_code", device_code), + ("client_id", client_id), + ]; + + let response = client + .post(token_url) + .form(¶ms) + .send() + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("HTTP request failed: {e}")))?; + let status = response.status(); + let body = response.text().await.map_err(|e| { + io::Error::new(io::ErrorKind::Other, format!("Failed to read response body: {e}")) + })?; + + if status.is_success() { + let token: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let access_token = token["access_token"] + .as_str() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "access_token missing from device token response", + ) + })? + .to_string(); + let refresh_token = token["refresh_token"].as_str().map(String::from); + let expires_at = token["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = token["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + return Ok(Some(OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + })); + } + + // Parse OAuth error response + let error_json: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); + let error = error_json["error"].as_str().unwrap_or("unknown"); + + match error { + "authorization_pending" => Ok(None), + "slow_down" => { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + Ok(None) + } + "expired_token" => Err(io::Error::new( + io::ErrorKind::Other, + "Device authorization expired. 
Please try again.", + )), + "access_denied" => Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "User denied authorization.", + )), + _ => Err(io::Error::new( + io::ErrorKind::Other, + format!("Device token error: {error}: {body}"), + )), + } +} + pub fn parse_oauth_callback_request_target(target: &str) -> Result { let (path, query) = target .split_once('?') diff --git a/rust/crates/rusty-claude-cli/Cargo.toml b/rust/crates/rusty-claude-cli/Cargo.toml index 635fdb32f7..a1e8fba543 100644 --- a/rust/crates/rusty-claude-cli/Cargo.toml +++ b/rust/crates/rusty-claude-cli/Cargo.toml @@ -15,6 +15,7 @@ commands = { path = "../commands" } compat-harness = { path = "../compat-harness" } crossterm = "0.28" pulldown-cmark = "0.13" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } rustyline = "15" runtime = { path = "../runtime" } plugins = { path = "../plugins" } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 6e2e2edf9c..2448b9b1bf 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1578,6 +1578,7 @@ struct LoginProviderTemplate { api_key_env: &'static str, models: &'static [&'static str], default_model: &'static str, + oauth: Option, } const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ @@ -1603,6 +1604,7 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ "glm-4-32b-0414-128k", ], default_model: "glm-5.1", + oauth: None, }, LoginProviderTemplate { id: "zai-coding-plan", @@ -1618,6 +1620,7 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ "glm-5v-turbo", ], default_model: "glm-5.1", + oauth: None, }, LoginProviderTemplate { id: "minimax-coding-plan", @@ -1634,6 +1637,7 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ "MiniMax-M2.7-highspeed", ], default_model: "MiniMax-M2.7-highspeed", + oauth: None, }, LoginProviderTemplate { id: "openai", @@ -1659,6 +1663,15 @@ 
const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ "gpt-5.5-pro", ], default_model: "gpt-5.5", + oauth: Some(ProviderOAuthConfig { + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + callback_port: 1455, + flow: OAuthFlowType::Pkce { + authorize_url: "https://auth.openai.com/oauth/authorize", + token_url: "https://auth.openai.com/oauth/token", + scopes: &["openid", "profile", "email", "offline_access"], + }, + }), }, LoginProviderTemplate { id: "kimi-for-coding", @@ -1668,6 +1681,7 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ api_key_env: "KIMI_API_KEY", models: &["k2p5", "k2p6", "kimi-k2-thinking"], default_model: "k2p6", + oauth: None, }, LoginProviderTemplate { id: "moonshot", @@ -1691,6 +1705,15 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ "moonshot-v1-128k-vision-preview", ], default_model: "kimi-k2.6", + oauth: Some(ProviderOAuthConfig { + client_id: "17e5f671-d194-4dfb-9706-5516cb48c098", + callback_port: 4546, + flow: OAuthFlowType::Device { + device_auth_url: "https://auth.kimi.com/api/oauth/device_authorization", + token_url: "https://auth.kimi.com/api/oauth/token", + scopes: &["openid", "profile", "email"], + }, + }), }, ]; @@ -8257,12 +8280,34 @@ fn resolve_cli_auth_source_for_cwd() -> Result { resolve_startup_auth_source(|| Ok(None)) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OAuthFlowType { + Pkce { + authorize_url: &'static str, + token_url: &'static str, + scopes: &'static [&'static str], + }, + Device { + device_auth_url: &'static str, + token_url: &'static str, + scopes: &'static [&'static str], + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ProviderOAuthConfig { + client_id: &'static str, + callback_port: u16, + flow: OAuthFlowType, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct BuiltinProvider { id: &'static str, label: &'static str, env_var: &'static str, default_model: &'static str, + oauth: Option, } const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ @@ -8271,18 
+8316,29 @@ const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ label: "Anthropic (Claude)", env_var: "ANTHROPIC_API_KEY", default_model: "claude-opus-4-6", + oauth: None, }, BuiltinProvider { id: "openai", label: "OpenAI", env_var: "OPENAI_API_KEY", default_model: "gpt-4o", + oauth: Some(ProviderOAuthConfig { + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + callback_port: 1455, + flow: OAuthFlowType::Pkce { + authorize_url: "https://auth.openai.com/oauth/authorize", + token_url: "https://auth.openai.com/oauth/token", + scopes: &["openid", "profile", "email", "offline_access"], + }, + }), }, BuiltinProvider { id: "xai", label: "xAI (Grok)", env_var: "XAI_API_KEY", default_model: "grok-3", + oauth: None, }, ]; @@ -8291,8 +8347,14 @@ fn check_model_auth_available(model: &str) -> Result anthropic_has_auth().unwrap_or(false), - ProviderKind::Xai => has_api_key("XAI_API_KEY"), - ProviderKind::OpenAi => has_api_key("OPENAI_API_KEY") || has_api_key("DASHSCOPE_API_KEY"), + ProviderKind::Xai => { + has_api_key("XAI_API_KEY") + } + ProviderKind::OpenAi => { + has_api_key("OPENAI_API_KEY") + || has_api_key("DASHSCOPE_API_KEY") + || runtime::load_provider_oauth("openai").ok().flatten().is_some() + } }; Ok(available) } @@ -8318,22 +8380,101 @@ fn run_provider_welcome( Choose a provider to authenticate with:\n" ); + let builtin_count = BUILTIN_PROVIDERS.len(); + let template_count = LOGIN_PROVIDER_TEMPLATES.len(); + let total = builtin_count + template_count; + + println!(" Built-in:"); for (i, provider) in BUILTIN_PROVIDERS.iter().enumerate() { - println!(" {}. {} ({})", i + 1, provider.label, provider.env_var); + let oauth_tag = if provider.oauth.is_some() { " [OAuth]" } else { "" }; + println!(" {}. 
{}{}", i + 1, provider.label, oauth_tag); } - print!("\nEnter number (1-{}): ", BUILTIN_PROVIDERS.len()); + println!("\n Additional providers:"); + for (i, template) in LOGIN_PROVIDER_TEMPLATES.iter().enumerate() { + let oauth_tag = if template.oauth.is_some() { " [OAuth]" } else { "" }; + println!( + " {}. {}{}", + builtin_count + i + 1, + template.label, + oauth_tag + ); + } + + print!("\nEnter number (1-{total}): "); std::io::stdout().flush()?; let mut choice = String::new(); std::io::stdin().read_line(&mut choice)?; let choice = choice.trim(); let index: usize = choice.parse().map_err(|_| "invalid selection")?; - let provider = BUILTIN_PROVIDERS - .get(index.wrapping_sub(1)) - .ok_or("invalid selection")?; + if index == 0 || index > total { + return Err("invalid selection".into()); + } + + // Built-in provider selected + if index <= builtin_count { + let provider = BUILTIN_PROVIDERS.get(index - 1).expect("valid builtin index"); + + // Offer OAuth if available + if let Some(ref oauth) = provider.oauth { + if prompt_oauth_or_api_key(provider.label, true)? { + run_pkce_oauth_flow(provider.id, oauth)?; + return Ok(Some(format!("{}/{}", provider.id, provider.default_model))); + } + } + + print!("Enter {} (or press Enter to cancel): ", provider.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + + std::env::set_var(provider.env_var, key); + + // Optionally save to ~/.claw/settings.json as a simple model config. 
+ if let Some(home) = std::env::var_os("HOME") { + let claw_dir = std::path::PathBuf::from(home).join(".claw"); + if claw_dir.exists() || std::fs::create_dir_all(&claw_dir).is_ok() { + let settings_path = claw_dir.join("settings.json"); + let model_str = format!("{}/{}", provider.id, provider.default_model); + let content = format!("{{\"model\": \"{model_str}\"}}"); + let _ = std::fs::write(&settings_path, content); + } + } + + return Ok(Some(format!("{}/{}", provider.id, provider.default_model))); + } + + // Template provider selected + let template = LOGIN_PROVIDER_TEMPLATES + .get(index - builtin_count - 1) + .expect("valid template index"); + + // Offer OAuth if available + if let Some(ref oauth) = template.oauth { + if prompt_oauth_or_api_key(template.label, true)? { + match oauth.flow { + OAuthFlowType::Pkce { .. } => { + run_pkce_oauth_flow(template.id, oauth)?; + } + OAuthFlowType::Device { .. } => { + run_device_oauth_flow(template.id, oauth)?; + } + } + return Ok(Some(format!("{}/{}", template.id, template.default_model))); + } + } - print!("Enter {} (or press Enter to cancel): ", provider.env_var); + print!( + "Enter {} (or press Enter to cancel): ", + template.api_key_env + ); std::io::stdout().flush()?; let mut key = String::new(); std::io::stdin().read_line(&mut key)?; @@ -8344,65 +8485,376 @@ fn run_provider_welcome( return Ok(None); } - std::env::set_var(provider.env_var, key); - - // Optionally save to ~/.claw/settings.json as a simple model config. 
- if let Some(home) = std::env::var_os("HOME") { - let claw_dir = std::path::PathBuf::from(home).join(".claw"); - if claw_dir.exists() || std::fs::create_dir_all(&claw_dir).is_ok() { - let settings_path = claw_dir.join("settings.json"); - let model_str = format!("{}/{}", provider.id, provider.default_model); - let content = format!("{{\"model\": \"{model_str}\"}}"); - let _ = std::fs::write(&settings_path, content); - } - } + std::env::set_var(template.api_key_env, key); + save_model_provider_profile( + template.id, + template.provider_type, + template.base_url, + template.api_key_env, + Some(key), + &template.models.iter().map(|m| (*m).to_string()).collect::>(), + template.default_model, + )?; - Ok(Some(format!("{}/{}", provider.id, provider.default_model))) + Ok(Some(format!("{}/{}", template.id, template.default_model))) } fn run_auth_command(provider: Option<&str>) -> Result<(), Box> { if let Some(provider_id) = provider { - let builtin = BUILTIN_PROVIDERS - .iter() - .find(|p| p.id == provider_id) - .ok_or_else(|| format!("unknown provider: {provider_id}"))?; - - if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { - return Err(format!( - "Authentication requires a terminal. Set the environment variable instead:\n\ - \n\ - export {}=", - builtin.env_var - ) - .into()); + // Try built-in provider first + if let Some(builtin) = BUILTIN_PROVIDERS.iter().find(|p| p.id == provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + builtin.env_var + ) + .into()); + } + + // Offer OAuth if available + if let Some(ref oauth) = builtin.oauth { + if prompt_oauth_or_api_key(builtin.label, true)? 
{ + run_pkce_oauth_flow(builtin.id, oauth)?; + println!("Authenticated with {} via OAuth.", builtin.label); + return Ok(()); + } + } + + print!("Enter {} (or press Enter to cancel): ", builtin.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(builtin.env_var, key); + println!("Authentication set for {}.", builtin.label); + return Ok(()); } - print!("Enter {} (or press Enter to cancel): ", builtin.env_var); - std::io::stdout().flush()?; - let mut key = String::new(); - std::io::stdin().read_line(&mut key)?; - let key = key.trim(); + // Try template provider + if let Some(template) = LOGIN_PROVIDER_TEMPLATES.iter().find(|p| p.id == provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + template.api_key_env + ) + .into()); + } - if key.is_empty() { - println!("Cancelled."); + // Offer OAuth if available + if let Some(ref oauth) = template.oauth { + if prompt_oauth_or_api_key(template.label, true)? { + match oauth.flow { + OAuthFlowType::Pkce { .. } => { + run_pkce_oauth_flow(template.id, oauth)?; + } + OAuthFlowType::Device { .. } => { + run_device_oauth_flow(template.id, oauth)?; + } + } + println!( + "Authenticated with {label} via OAuth. 
Model: {id}/{model}", + label = template.label, + id = template.id, + model = template.default_model + ); + return Ok(()); + } + } + + print!( + "Enter {} (or press Enter to cancel): ", + template.api_key_env + ); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(template.api_key_env, key); + save_model_provider_profile( + template.id, + template.provider_type, + template.base_url, + template.api_key_env, + Some(key), + &template.models.iter().map(|m| (*m).to_string()).collect::>(), + template.default_model, + )?; + println!( + "Authenticated with {label}. Model: {id}/{model}", + label = template.label, + id = template.id, + model = template.default_model + ); return Ok(()); } - std::env::set_var(builtin.env_var, key); - println!("Authentication set for {}.", builtin.label); - Ok(()) + return Err(format!( + "unknown provider: '{provider_id}'. Run `claw auth` to see available providers." + ) + .into()); + } + + match run_provider_welcome(DEFAULT_MODEL)? { + Some(model) => { + println!("Authentication configured. 
Model set to {model}."); + Ok(()) + } + None => { + println!("Cancelled."); + Ok(()) + } + } +} + +// --------------------------------------------------------------------------- +// OAuth flow implementations +// --------------------------------------------------------------------------- + +fn run_pkce_oauth_flow( + provider_id: &str, + oauth: &ProviderOAuthConfig, +) -> Result> { + use runtime::{ + generate_pkce_pair, generate_state, loopback_redirect_uri, open_browser, + run_oauth_callback_server, OAuthAuthorizationRequest, OAuthTokenExchangeRequest, + }; + + let pkce = generate_pkce_pair()?; + let state = generate_state()?; + + let (authorize_url, token_url, scopes) = match oauth.flow { + OAuthFlowType::Pkce { + authorize_url, + token_url, + scopes, + } => (authorize_url, token_url, scopes), + _ => return Err("Expected PKCE flow".into()), + }; + + let config = runtime::OAuthConfig { + client_id: oauth.client_id.to_string(), + authorize_url: authorize_url.to_string(), + token_url: token_url.to_string(), + callback_port: Some(oauth.callback_port), + manual_redirect_url: None, + scopes: scopes.iter().map(|s| (*s).to_string()).collect(), + }; + + let auth_request = + OAuthAuthorizationRequest::from_config(&config, loopback_redirect_uri(oauth.callback_port), &state, &pkce); + let auth_url = auth_request.build_url(); + + println!("Opening browser for OAuth authentication..."); + println!("If the browser doesn't open automatically, visit:"); + println!(" {auth_url}"); + open_browser(&auth_url)?; + + println!("Waiting for authentication..."); + let callback = run_oauth_callback_server(oauth.callback_port, std::time::Duration::from_secs(300))?; + + if callback.state != state { + return Err("OAuth state mismatch. 
Possible CSRF attack.".into()); + } + + println!("Exchanging authorization code for tokens..."); + + let rt = tokio::runtime::Runtime::new()?; + let token_set = rt.block_on(async { + let client = reqwest::Client::new(); + let exchange = OAuthTokenExchangeRequest::from_config( + &config, + &callback.code, + &state, + &pkce.verifier, + &loopback_redirect_uri(oauth.callback_port), + ); + + let response = client + .post(&config.token_url) + .form(&exchange.form_params()) + .send() + .await + .map_err(|e| format!("Token exchange request failed: {e}"))?; + + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| format!("Failed to read token response: {e}"))?; + + if !status.is_success() { + return Err(format!("Token exchange failed ({status}): {body}").into()); + } + + let json: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| format!("Invalid token response JSON: {e}"))?; + + let access_token = json["access_token"] + .as_str() + .ok_or("access_token missing from response")? + .to_string(); + let refresh_token = json["refresh_token"].as_str().map(String::from); + let expires_at = json["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = json["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + + Ok::<_, Box>(runtime::OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + }) + })?; + + runtime::save_provider_oauth(provider_id, &token_set)?; + println!("✓ OAuth authentication successful. 
Tokens saved."); + + Ok(token_set) +} + +fn run_device_oauth_flow( + provider_id: &str, + oauth: &ProviderOAuthConfig, +) -> Result> { + use runtime::{open_browser, poll_device_token, DeviceAuthRequest}; + + let (device_auth_url, token_url, scopes) = match oauth.flow { + OAuthFlowType::Device { + device_auth_url, + token_url, + scopes, + } => (device_auth_url, token_url, scopes), + _ => return Err("Expected Device flow".into()), + }; + + let rt = tokio::runtime::Runtime::new()?; + + // Step 1: Request device code + let device_response = rt.block_on(async { + let client = reqwest::Client::new(); + let scope_str = scopes.join(" "); + let params = [ + ("client_id", oauth.client_id), + ("scope", scope_str.as_str()), + ]; + let response = client + .post(device_auth_url) + .form(¶ms) + .send() + .await + .map_err(|e| format!("Device auth request failed: {e}"))?; + + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| format!("Failed to read device auth response: {e}"))?; + + if !status.is_success() { + return Err(format!("Device auth failed ({status}): {body}").into()); + } + + let resp: runtime::DeviceAuthResponse = serde_json::from_str(&body) + .map_err(|e| format!("Invalid device auth response: {e}"))?; + Ok::<_, Box>(resp) + })?; + + println!("\nDevice authentication started."); + println!("User code: {}", device_response.user_code); + println!( + "Please visit: {}", + device_response + .verification_uri_complete + .as_deref() + .unwrap_or(&device_response.verification_uri) + ); + + if let Some(ref complete_uri) = device_response.verification_uri_complete { + open_browser(complete_uri)?; } else { - match run_provider_welcome(DEFAULT_MODEL)? { - Some(model) => { - println!("Authentication configured. 
Model set to {model}."); - Ok(()) + open_browser(&device_response.verification_uri)?; + } + + // Step 2: Poll for token + let start = std::time::Instant::now(); + let expires_in = std::time::Duration::from_secs(device_response.expires_in); + let interval = std::time::Duration::from_secs(device_response.interval); + + let token_set = rt.block_on(async { + let client = reqwest::Client::new(); + loop { + if start.elapsed() > expires_in { + return Err::<_, Box>( + "Device authorization expired. Please try again.".into(), + ); } - None => { - println!("Cancelled."); - Ok(()) + + tokio::time::sleep(interval).await; + + match poll_device_token( + &client, + &device_response.device_code, + oauth.client_id, + token_url, + ) + .await + { + Ok(Some(token_set)) => return Ok(token_set), + Ok(None) => { + println!("Waiting for authorization..."); + continue; + } + Err(e) => return Err(e.into()), } } + })?; + + runtime::save_provider_oauth(provider_id, &token_set)?; + println!("✓ OAuth authentication successful. Tokens saved."); + + Ok(token_set) +} + +fn prompt_oauth_or_api_key(provider_label: &str, has_oauth: bool) -> Result> { + if !has_oauth { + return Ok(false); } + + println!("\nChoose authentication method for {provider_label}:"); + println!(" 1. Sign in with {provider_label} account (OAuth) — recommended"); + println!(" 2. 
Enter API key manually"); + print!("Select [1]: "); + std::io::stdout().flush()?; + + let mut choice = String::new(); + std::io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + Ok(choice.is_empty() || choice == "1") } impl ApiClient for AnthropicRuntimeClient { From 702fbc240dd7768de8cb30e549b2fce63e782ea1 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 13:33:47 -0300 Subject: [PATCH 05/15] test(auth,oauth): add unit tests for OAuth infrastructure - Add runtime OAuth tests: per-provider storage round-trip, legacy key preservation, callback server code/state capture, callback server error handling, HTML escaping - Add CLI auth tests: OAuth config detection, API key fallback order, provider metadata checks Runtime: 15 tests pass, CLI: 7 tests pass --- rust/crates/runtime/src/oauth.rs | 336 +++++++++++++++++------ rust/crates/rusty-claude-cli/src/main.rs | 99 ++++++- 2 files changed, 356 insertions(+), 79 deletions(-) diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index 87572a44c8..e49ab1732c 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -405,107 +405,90 @@ pub fn run_oauth_callback_server( ) -> io::Result { use std::io::{BufRead, BufReader, Write}; use std::net::{SocketAddr, TcpListener}; + use std::sync::mpsc; let addr: SocketAddr = format!("127.0.0.1:{port}").parse().map_err(|e| { io::Error::new(io::ErrorKind::InvalidInput, format!("invalid address: {e}")) })?; let listener = TcpListener::bind(addr)?; - listener.set_nonblocking(false)?; - let start = std::time::Instant::now(); + let (tx, rx) = mpsc::channel::(); - loop { - if start.elapsed() > timeout { + std::thread::spawn(move || { + if let Ok((stream, _)) = listener.accept() { + let _ = tx.send(stream); + } + }); + + let mut stream = match rx.recv_timeout(timeout) { + Ok(stream) => stream, + Err(mpsc::RecvTimeoutError::Timeout) => { return Err(io::Error::new( io::ErrorKind::TimedOut, "OAuth 
callback timed out waiting for browser redirect", )); } + Err(mpsc::RecvTimeoutError::Disconnected) => { + return Err(io::Error::new( + io::ErrorKind::Other, + "callback server thread disconnected", + )); + } + }; - let (mut stream, _) = match listener.accept() { - Ok(conn) => conn, - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - std::thread::sleep(std::time::Duration::from_millis(100)); - continue; - } - Err(e) => return Err(e), - }; + let mut reader = BufReader::new(&mut stream); + let mut first_line = String::new(); + reader.read_line(&mut first_line)?; - let mut reader = BufReader::new(&mut stream); - let mut first_line = String::new(); - reader.read_line(&mut first_line)?; - - // Parse "GET /callback?code=...&state=... HTTP/1.1" - let target = first_line - .split_whitespace() - .nth(1) - .unwrap_or(""); - - // Consume remaining headers so browser doesn't reset connection - loop { - let mut line = String::new(); - if reader.read_line(&mut line)? == 0 { - break; - } - if line == "\r\n" || line == "\n" { - break; - } + // Parse "GET /callback?code=...&state=... HTTP/1.1" + let target = first_line + .split_whitespace() + .nth(1) + .unwrap_or(""); + + // Consume remaining headers so browser doesn't reset connection + loop { + let mut line = String::new(); + if reader.read_line(&mut line)? == 0 { + break; + } + if line == "\r\n" || line == "\n" { + break; } + } - match parse_oauth_callback_request_target(target) { - Ok(params) => { - if let (Some(code), Some(state)) = (¶ms.code, ¶ms.state) { - // Success page - let body = r#" + match parse_oauth_callback_request_target(target) { + Ok(params) => { + if let (Some(code), Some(state)) = (¶ms.code, ¶ms.state) { + // Success page + let body = r#" Authentication Successful

✅ Authentication Successful

You can close this tab and return to the terminal.

"#; - let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", - body.len(), - body - ); - stream.write_all(response.as_bytes())?; - return Ok(OAuthCallbackResult { - code: code.clone(), - state: state.clone(), - }); - } - if let Some(error) = ¶ms.error { - let err_desc = params.error_description.as_deref().unwrap_or(error); - let body = format!( - r#" -Authentication Failed - -

❌ Authentication Failed

-

{}

-

You can close this tab and return to the terminal.

-"#, - html_escape(err_desc) - ); - let response = format!( - "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", - body.len(), - body - ); - stream.write_all(response.as_bytes())?; - return Err(io::Error::new( - io::ErrorKind::Other, - format!("OAuth error: {error} - {err_desc}"), - )); - } + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Ok(OAuthCallbackResult { + code: code.clone(), + state: state.clone(), + }); } - Err(e) => { + if let Some(error) = ¶ms.error { + let err_desc = params.error_description.as_deref().unwrap_or(error); let body = format!( r#" Authentication Failed -

❌ Invalid Callback

+

❌ Authentication Failed

{}

+

You can close this tab and return to the terminal.

"#, - html_escape(&e) + html_escape(err_desc) ); let response = format!( "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", @@ -513,8 +496,33 @@ pub fn run_oauth_callback_server( body ); stream.write_all(response.as_bytes())?; - return Err(io::Error::new(io::ErrorKind::Other, e)); + return Err(io::Error::new( + io::ErrorKind::Other, + format!("OAuth error: {error} - {err_desc}"), + )); } + Err(io::Error::new( + io::ErrorKind::InvalidData, + "callback received neither code nor error", + )) + } + Err(e) => { + let body = format!( + r#" +Authentication Failed + +

❌ Invalid Callback

+

{}

+"#, + html_escape(&e) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + Err(io::Error::new(io::ErrorKind::Other, e)) } } } @@ -802,10 +810,12 @@ mod tests { use std::time::{SystemTime, UNIX_EPOCH}; use super::{ - clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, - generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, - parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, - OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, + generate_pkce_pair, generate_state, load_oauth_credentials, load_provider_oauth, + loopback_redirect_uri, parse_oauth_callback_query, parse_oauth_callback_request_target, + run_oauth_callback_server, save_oauth_credentials, save_provider_oauth, html_escape, + OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, + OAuthTokenSet, }; fn sample_config() -> OAuthConfig { @@ -937,4 +947,174 @@ mod tests { assert_eq!(params.state.as_deref(), Some("xyz")); assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err()); } + + #[test] + fn provider_oauth_credentials_round_trip_and_clear() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + let path = credentials_path().expect("credentials path"); + std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); + + let openai_tokens = OAuthTokenSet { + access_token: "openai-access".to_string(), + refresh_token: Some("openai-refresh".to_string()), + expires_at: Some(1000), + scopes: vec!["openid".to_string()], + }; + let moonshot_tokens = OAuthTokenSet { + access_token: 
"moonshot-access".to_string(), + refresh_token: None, + expires_at: Some(2000), + scopes: vec!["profile".to_string()], + }; + + save_provider_oauth("openai", &openai_tokens).expect("save openai"); + save_provider_oauth("moonshot", &moonshot_tokens).expect("save moonshot"); + + assert_eq!( + load_provider_oauth("openai").expect("load openai"), + Some(openai_tokens.clone()) + ); + assert_eq!( + load_provider_oauth("moonshot").expect("load moonshot"), + Some(moonshot_tokens.clone()) + ); + assert_eq!( + load_provider_oauth("unknown").expect("load unknown"), + None + ); + + let saved = std::fs::read_to_string(&path).expect("read saved file"); + assert!(saved.contains("\"oauth_providers\"")); + assert!(saved.contains("\"openai\"")); + assert!(saved.contains("\"moonshot\"")); + + clear_provider_oauth("openai").expect("clear openai"); + assert_eq!(load_provider_oauth("openai").expect("load cleared"), None); + assert_eq!( + load_provider_oauth("moonshot").expect("load moonshot after clear"), + Some(moonshot_tokens) + ); + + clear_provider_oauth("moonshot").expect("clear moonshot"); + let cleared = std::fs::read_to_string(&path).expect("read cleared file"); + assert!(!cleared.contains("\"oauth_providers\"")); + + std::env::remove_var("CLAW_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn provider_oauth_preserves_legacy_oauth_key() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + let path = credentials_path().expect("credentials path"); + std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); + + let legacy = OAuthTokenSet { + access_token: "legacy-access".to_string(), + refresh_token: Some("legacy-refresh".to_string()), + expires_at: Some(999), + scopes: vec!["org:read".to_string()], + }; + save_oauth_credentials(&legacy).expect("save legacy"); + + let provider = OAuthTokenSet { + access_token: 
"provider-access".to_string(), + refresh_token: None, + expires_at: Some(888), + scopes: vec!["user:read".to_string()], + }; + save_provider_oauth("openai", &provider).expect("save provider"); + + assert_eq!( + load_oauth_credentials().expect("load legacy"), + Some(legacy) + ); + assert_eq!( + load_provider_oauth("openai").expect("load provider"), + Some(provider) + ); + + std::env::remove_var("CLAW_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn callback_server_returns_code_and_state() { + use std::io::{Read, Write}; + use std::net::TcpStream; + use std::thread; + + let port = 4547; + let server_thread = thread::spawn(move || { + run_oauth_callback_server(port, std::time::Duration::from_secs(5)) + }); + + // Give the server a moment to bind + thread::sleep(std::time::Duration::from_millis(100)); + + // Simulate browser callback + let request = format!( + "GET /callback?code=test-code-123&state=test-state-456 HTTP/1.1\r\nHost: localhost:{port}\r\n\r\n" + ); + let mut stream = TcpStream::connect(format!("127.0.0.1:{port}")).expect("connect to callback server"); + stream.write_all(request.as_bytes()).expect("send request"); + stream.flush().expect("flush"); + + // Read response (should be HTML success page) + let mut response = String::new(); + stream.read_to_string(&mut response).expect("read response"); + assert!(response.contains("200 OK"), "expected 200 OK, got: {response}"); + assert!(response.contains("Authentication Successful"), "expected success page"); + + let result = server_thread.join().expect("server thread join"); + let callback = result.expect("callback result"); + assert_eq!(callback.code, "test-code-123"); + assert_eq!(callback.state, "test-state-456"); + } + + #[test] + fn callback_server_returns_error_on_oauth_error() { + use std::io::{Read, Write}; + use std::net::TcpStream; + use std::thread; + + let port = 4548; + let server_thread = thread::spawn(move || { + run_oauth_callback_server(port, 
std::time::Duration::from_secs(5)) + }); + + thread::sleep(std::time::Duration::from_millis(100)); + + let request = format!( + "GET /callback?error=access_denied&error_description=user%20denied HTTP/1.1\r\nHost: localhost:{port}\r\n\r\n" + ); + let mut stream = TcpStream::connect(format!("127.0.0.1:{port}")).expect("connect"); + stream.write_all(request.as_bytes()).expect("send"); + stream.flush().expect("flush"); + + let mut response = String::new(); + stream.read_to_string(&mut response).expect("read"); + assert!(response.contains("400 Bad Request"), "expected 400, got: {response}"); + + let result = server_thread.join().expect("join"); + assert!(result.is_err(), "expected error for OAuth error response"); + } + + #[test] + fn callback_server_times_out_when_no_request() { + let port = 4549; + let result = run_oauth_callback_server(port, std::time::Duration::from_millis(50)); + assert!(result.is_err(), "expected timeout error"); + } + + #[test] + fn html_escape_works() { + assert_eq!(html_escape(""), "<script>alert('xss')</script>"); + assert_eq!(html_escape("foo & bar"), "foo & bar"); + assert_eq!(html_escape("\"quoted\""), ""quoted""); + } } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 2448b9b1bf..a8994f9674 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -14851,7 +14851,15 @@ mod dump_manifests_tests { #[cfg(test)] mod auth_tests { - use super::{parse_args, CliAction}; + use std::sync::{Mutex, OnceLock}; + use super::{parse_args, check_model_auth_available, BUILTIN_PROVIDERS, LOGIN_PROVIDER_TEMPLATES, CliAction}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } #[test] fn parse_args_auth_without_provider() { @@ -14872,4 +14880,93 @@ mod auth_tests { other => panic!("expected CliAction::Auth, 
got {other:?}"), } } + + #[test] + fn openai_builtin_provider_has_oauth_config() { + let openai = BUILTIN_PROVIDERS + .iter() + .find(|p| p.id == "openai") + .expect("openai provider exists"); + assert!(openai.oauth.is_some(), "openai should have OAuth config"); + let oauth = openai.oauth.unwrap(); + assert_eq!(oauth.client_id, "app_EMoamEEZ73f0CkXaXp7hrann"); + assert_eq!(oauth.callback_port, 1455); + } + + #[test] + fn anthropic_and_xai_have_no_oauth() { + for provider in BUILTIN_PROVIDERS.iter() { + if provider.id == "anthropic" || provider.id == "xai" { + assert!(provider.oauth.is_none(), "{} should not have OAuth", provider.id); + } + } + } + + #[test] + fn moonshot_template_has_device_oauth() { + let moonshot = LOGIN_PROVIDER_TEMPLATES + .iter() + .find(|p| p.id == "moonshot") + .expect("moonshot template exists"); + assert!(moonshot.oauth.is_some(), "moonshot should have OAuth config"); + let oauth = moonshot.oauth.unwrap(); + assert_eq!(oauth.client_id, "17e5f671-d194-4dfb-9706-5516cb48c098"); + } + + #[test] + fn zai_and_minimax_have_no_oauth() { + for template in LOGIN_PROVIDER_TEMPLATES.iter() { + if template.id == "zai" || template.id == "minimax-coding-plan" { + assert!( + template.oauth.is_none(), + "{} should not have OAuth", + template.id + ); + } + } + } + + #[test] + fn check_model_auth_available_detects_saved_oauth() { + let _guard = env_lock(); + let config_home = std::env::temp_dir().join(format!( + "claw-oauth-auth-test-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::fs::create_dir_all(&config_home).expect("create config home"); + + // Ensure no env var is set + std::env::remove_var("OPENAI_API_KEY"); + std::env::remove_var("DASHSCOPE_API_KEY"); + + // Without saved OAuth, auth should not be available + assert!( + !check_model_auth_available("gpt-4o").expect("check auth"), + "auth should be 
unavailable without env or saved tokens" + ); + + // Save an OAuth token for openai + let token_set = runtime::OAuthTokenSet { + access_token: "test-access-token".to_string(), + refresh_token: Some("test-refresh".to_string()), + expires_at: Some(9999999999), + scopes: vec!["openid".to_string()], + }; + runtime::save_provider_oauth("openai", &token_set).expect("save token"); + + // With saved OAuth, auth should be available + assert!( + check_model_auth_available("gpt-4o").expect("check auth with oauth"), + "auth should be available with saved OAuth token" + ); + + // Clean up + std::env::remove_var("CLAW_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).ok(); + } } From 6660f7fce1d9c73e28a315cf9aac3d2391a3d9dd Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 14:33:18 -0300 Subject: [PATCH 06/15] fix(auth): remove duplicate OpenAI entry from login provider templates OpenAI was listed twice in the welcome screen because it existed in both BUILTIN_PROVIDERS and LOGIN_PROVIDER_TEMPLATES. Remove the duplicate from templates since OpenAI is already a built-in provider with OAuth support. 
--- rust/crates/rusty-claude-cli/src/main.rs | 34 ------------------------ 1 file changed, 34 deletions(-) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index a8994f9674..f44e7eacdb 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1639,40 +1639,6 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ default_model: "MiniMax-M2.7-highspeed", oauth: None, }, - LoginProviderTemplate { - id: "openai", - label: "OpenAI", - provider_type: "openai-compatible", - base_url: "https://api.openai.com/v1", - api_key_env: "OPENAI_API_KEY", - models: &[ - "gpt-5-codex", - "gpt-5.1-codex", - "gpt-5.1-codex-max", - "gpt-5.1-codex-mini", - "gpt-5.2", - "gpt-5.2-codex", - "gpt-5.3-codex", - "gpt-5.3-codex-spark", - "gpt-5.4", - "gpt-5.4-fast", - "gpt-5.4-mini", - "gpt-5.4-mini-fast", - "gpt-5.5", - "gpt-5.5-fast", - "gpt-5.5-pro", - ], - default_model: "gpt-5.5", - oauth: Some(ProviderOAuthConfig { - client_id: "app_EMoamEEZ73f0CkXaXp7hrann", - callback_port: 1455, - flow: OAuthFlowType::Pkce { - authorize_url: "https://auth.openai.com/oauth/authorize", - token_url: "https://auth.openai.com/oauth/token", - scopes: &["openid", "profile", "email", "offline_access"], - }, - }), - }, LoginProviderTemplate { id: "kimi-for-coding", label: "Kimi For Coding", From f8be7cc1e5a73c3a4159958911b72a39fcfe5028 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 14:41:31 -0300 Subject: [PATCH 07/15] fix(oauth): fix OpenAI OAuth PKCE flow compatibility - Add configurable redirect_path to ProviderOAuthConfig (default /callback, OpenAI uses /auth/callback to match Codex CLI) - Update callback server to validate against configurable path instead of hardcoded /callback - Remove non-standard 'state' parameter from token exchange request body - Add OpenAI-specific query params: id_token_add_organizations=true and codex_cli_simplified_flow=true for drop-in Codex CLI 
compatibility - Export loopback_redirect_uri_with_path from runtime --- rust/crates/runtime/src/lib.rs | 2 +- rust/crates/runtime/src/oauth.rs | 30 +++++++++++++++--------- rust/crates/rusty-claude-cli/src/main.rs | 27 +++++++++++++++++---- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 365e7ad6c8..6bafc11abe 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -114,7 +114,7 @@ pub use mcp_stdio::{ pub use oauth::{ clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, generate_pkce_pair, generate_state, load_oauth_credentials, load_provider_oauth, - loopback_redirect_uri, open_browser, parse_oauth_callback_query, + loopback_redirect_uri, loopback_redirect_uri_with_path, open_browser, parse_oauth_callback_query, parse_oauth_callback_request_target, poll_device_token, run_oauth_callback_server, save_oauth_credentials, save_provider_oauth, DeviceAuthRequest, DeviceAuthResponse, OAuthAuthorizationRequest, OAuthCallbackParams, OAuthCallbackResult, OAuthRefreshRequest, diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index e49ab1732c..21794e96e3 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -207,7 +207,6 @@ impl OAuthTokenExchangeRequest { ("redirect_uri", self.redirect_uri.clone()), ("client_id", self.client_id.clone()), ("code_verifier", self.code_verifier.clone()), - ("state", self.state.clone()), ]) } } @@ -262,6 +261,11 @@ pub fn loopback_redirect_uri(port: u16) -> String { format!("http://localhost:{port}/callback") } +#[must_use] +pub fn loopback_redirect_uri_with_path(port: u16, path: &str) -> String { + format!("http://localhost:{port}{path}") +} + pub fn credentials_path() -> io::Result { Ok(credentials_home_dir()?.join("credentials.json")) } @@ -396,12 +400,13 @@ pub struct OAuthCallbackResult { pub state: String, } -/// Run a blocking 
local HTTP server that waits for a single `/callback` request. +/// Run a blocking local HTTP server that waits for a single callback request. /// Returns the authorization code and state on success. /// Times out after `timeout` duration. pub fn run_oauth_callback_server( port: u16, timeout: std::time::Duration, + callback_path: &str, ) -> io::Result { use std::io::{BufRead, BufReader, Write}; use std::net::{SocketAddr, TcpListener}; @@ -457,7 +462,7 @@ pub fn run_oauth_callback_server( } } - match parse_oauth_callback_request_target(target) { + match parse_oauth_callback_request_target(target, callback_path) { Ok(params) => { if let (Some(code), Some(state)) = (¶ms.code, ¶ms.state) { // Success page @@ -643,12 +648,15 @@ pub async fn poll_device_token( } } -pub fn parse_oauth_callback_request_target(target: &str) -> Result { +pub fn parse_oauth_callback_request_target( + target: &str, + expected_path: &str, +) -> Result { let (path, query) = target .split_once('?') .map_or((target, ""), |(path, query)| (path, query)); - if path != "/callback" { - return Err(format!("unexpected callback path: {path}")); + if path != expected_path { + return Err(format!("unexpected callback path: {path}, expected: {expected_path}")); } parse_oauth_callback_query(query) } @@ -941,11 +949,11 @@ mod tests { assert_eq!(params.state.as_deref(), Some("state-1")); assert_eq!(params.error_description.as_deref(), Some("needs login")); - let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz") + let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz", "/callback") .expect("parse callback target"); assert_eq!(params.code.as_deref(), Some("abc")); assert_eq!(params.state.as_deref(), Some("xyz")); - assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err()); + assert!(parse_oauth_callback_request_target("/wrong?code=abc", "/callback").is_err()); } #[test] @@ -1050,7 +1058,7 @@ mod tests { let port = 4547; let server_thread = 
thread::spawn(move || { - run_oauth_callback_server(port, std::time::Duration::from_secs(5)) + run_oauth_callback_server(port, std::time::Duration::from_secs(5), "/callback") }); // Give the server a moment to bind @@ -1084,7 +1092,7 @@ mod tests { let port = 4548; let server_thread = thread::spawn(move || { - run_oauth_callback_server(port, std::time::Duration::from_secs(5)) + run_oauth_callback_server(port, std::time::Duration::from_secs(5), "/callback") }); thread::sleep(std::time::Duration::from_millis(100)); @@ -1107,7 +1115,7 @@ mod tests { #[test] fn callback_server_times_out_when_no_request() { let port = 4549; - let result = run_oauth_callback_server(port, std::time::Duration::from_millis(50)); + let result = run_oauth_callback_server(port, std::time::Duration::from_millis(50), "/callback"); assert!(result.is_err(), "expected timeout error"); } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index f44e7eacdb..668029cc2f 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1674,6 +1674,7 @@ const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ oauth: Some(ProviderOAuthConfig { client_id: "17e5f671-d194-4dfb-9706-5516cb48c098", callback_port: 4546, + redirect_path: "/callback", flow: OAuthFlowType::Device { device_auth_url: "https://auth.kimi.com/api/oauth/device_authorization", token_url: "https://auth.kimi.com/api/oauth/token", @@ -8264,6 +8265,7 @@ enum OAuthFlowType { struct ProviderOAuthConfig { client_id: &'static str, callback_port: u16, + redirect_path: &'static str, flow: OAuthFlowType, } @@ -8292,6 +8294,7 @@ const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ oauth: Some(ProviderOAuthConfig { client_id: "app_EMoamEEZ73f0CkXaXp7hrann", callback_port: 1455, + redirect_path: "/auth/callback", flow: OAuthFlowType::Pkce { authorize_url: "https://auth.openai.com/oauth/authorize", token_url: "https://auth.openai.com/oauth/token", @@ -8597,7 +8600,7 
@@ fn run_pkce_oauth_flow( oauth: &ProviderOAuthConfig, ) -> Result> { use runtime::{ - generate_pkce_pair, generate_state, loopback_redirect_uri, open_browser, + generate_pkce_pair, generate_state, loopback_redirect_uri_with_path, open_browser, run_oauth_callback_server, OAuthAuthorizationRequest, OAuthTokenExchangeRequest, }; @@ -8622,8 +8625,18 @@ fn run_pkce_oauth_flow( scopes: scopes.iter().map(|s| (*s).to_string()).collect(), }; - let auth_request = - OAuthAuthorizationRequest::from_config(&config, loopback_redirect_uri(oauth.callback_port), &state, &pkce); + let redirect_uri = loopback_redirect_uri_with_path(oauth.callback_port, oauth.redirect_path); + + let mut auth_request = + OAuthAuthorizationRequest::from_config(&config, &redirect_uri, &state, &pkce); + + // OpenAI-specific parameters required for drop-in Codex CLI compatibility + if provider_id == "openai" { + auth_request = auth_request + .with_extra_param("id_token_add_organizations", "true") + .with_extra_param("codex_cli_simplified_flow", "true"); + } + let auth_url = auth_request.build_url(); println!("Opening browser for OAuth authentication..."); @@ -8632,7 +8645,11 @@ fn run_pkce_oauth_flow( open_browser(&auth_url)?; println!("Waiting for authentication..."); - let callback = run_oauth_callback_server(oauth.callback_port, std::time::Duration::from_secs(300))?; + let callback = run_oauth_callback_server( + oauth.callback_port, + std::time::Duration::from_secs(300), + oauth.redirect_path, + )?; if callback.state != state { return Err("OAuth state mismatch. 
Possible CSRF attack.".into()); @@ -8648,7 +8665,7 @@ fn run_pkce_oauth_flow( &callback.code, &state, &pkce.verifier, - &loopback_redirect_uri(oauth.callback_port), + &redirect_uri, ); let response = client From 3dabd28db53dca7c34edc8519fd761353b8744b5 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 14:50:04 -0300 Subject: [PATCH 08/15] fix(oauth): add originator parameter to OpenAI OAuth flow Add originator=codex_cli_rs to the OpenAI authorization URL to match the official Codex CLI OAuth flow exactly. This parameter is required by auth.openai.com for proper request handling. Refs: openai/codex#7184, 7shi/codex-oauth --- rust/crates/rusty-claude-cli/src/main.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 668029cc2f..0e0c8992fc 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -8634,7 +8634,8 @@ fn run_pkce_oauth_flow( if provider_id == "openai" { auth_request = auth_request .with_extra_param("id_token_add_organizations", "true") - .with_extra_param("codex_cli_simplified_flow", "true"); + .with_extra_param("codex_cli_simplified_flow", "true") + .with_extra_param("originator", "codex_cli_rs"); } let auth_url = auth_request.build_url(); From 4b2fd073ff8be06277c93bf0d84e12a21ae72560 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 15:19:59 -0300 Subject: [PATCH 09/15] fix(auth,oauth): resolve OAuth tokens for configured providers and add Moonshot routing - Fix configured_provider_for_model to check saved OAuth tokens when env var is unset. Previously, users with modelProviders.openai in settings.json who authenticated via OAuth would get 'requires env var OPENAI_API_KEY' because the function only checked apiKey and apiKeyEnv. 
- Fix check_model_auth_available to use metadata_for_model for prefix-aware auth checking (openai/, moonshot/, gpt-, etc.), ensuring each provider gets its correct env var and OAuth store key. - Add moonshot/ prefix to metadata_for_model so detect_provider_kind routes Moonshot models correctly instead of falling through to Anthropic default. - Add OpenAiCompatConfig::moonshot() with DEFAULT_MOONSHOT_BASE_URL for native Moonshot API endpoint support. - Update ProviderClient construction to use metadata_for_model for prefix-aware config selection, enabling OAuth fallback for Moonshot too. --- rust/crates/api/src/client.rs | 21 ++++---- rust/crates/api/src/lib.rs | 2 +- rust/crates/api/src/providers/mod.rs | 9 ++++ .../crates/api/src/providers/openai_compat.rs | 16 +++++++ rust/crates/rusty-claude-cli/src/main.rs | 48 +++++++++++++++---- 5 files changed, 76 insertions(+), 20 deletions(-) diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index c96165d4aa..10f60d7bec 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -32,19 +32,22 @@ impl ProviderClient { OpenAiCompatConfig::xai(), )?)), ProviderKind::OpenAi => { - // DashScope models (qwen-*) also return ProviderKind::OpenAi because they - // speak the OpenAI wire format, but they need the DashScope config which - // reads DASHSCOPE_API_KEY and points at dashscope.aliyuncs.com. - let config = match providers::metadata_for_model(&resolved_model) { + // Use metadata_for_model for prefix-aware config selection. + // DashScope models (qwen-*, kimi-*) and Moonshot models (moonshot/*) + // all speak the OpenAI wire format but need different configs. 
+ let (config, oauth_provider_id) = match providers::metadata_for_model(&resolved_model) { Some(meta) if meta.auth_env == "DASHSCOPE_API_KEY" => { - OpenAiCompatConfig::dashscope() + (OpenAiCompatConfig::dashscope(), None) } - _ => OpenAiCompatConfig::openai(), + Some(meta) if meta.auth_env == "MOONSHOT_API_KEY" => { + (OpenAiCompatConfig::moonshot(), Some("moonshot")) + } + _ => (OpenAiCompatConfig::openai(), Some("openai")), }; - // Try OAuth for OpenAI if env var is not set - if config.provider_name == "OpenAI" { + // Try OAuth if the provider supports it and env var is not set + if let Some(provider_id) = oauth_provider_id { Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( - config, "openai", + config, provider_id, )?)) } else { Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 0f04905e7a..5c8ffdfa07 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -25,7 +25,7 @@ pub use providers::openai_compat::{ }; pub use providers::{ detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override, - resolve_model_alias, ProviderKind, + metadata_for_model, resolve_model_alias, ProviderKind, ProviderMetadata, }; pub use sse::{parse_frame, SseParser}; pub use types::{ diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index 86871a82a1..834c564464 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -216,6 +216,15 @@ pub fn metadata_for_model(model: &str) -> Option { default_base_url: openai_compat::DEFAULT_DASHSCOPE_BASE_URL, }); } + // Moonshot / Kimi models via native Moonshot API endpoint. 
+ if canonical.starts_with("moonshot/") { + return Some(ProviderMetadata { + provider: ProviderKind::OpenAi, + auth_env: "MOONSHOT_API_KEY", + base_url_env: "MOONSHOT_BASE_URL", + default_base_url: openai_compat::DEFAULT_MOONSHOT_BASE_URL, + }); + } None } diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 42eaf7fb9e..2d1eac6e61 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -19,6 +19,7 @@ use super::{preflight_message_request, Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; pub const DEFAULT_DASHSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1"; +pub const DEFAULT_MOONSHOT_BASE_URL: &str = "https://api.moonshot.ai/v1"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1); @@ -41,11 +42,13 @@ pub struct OpenAiCompatConfig { const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; const DASHSCOPE_ENV_VARS: &[&str] = &["DASHSCOPE_API_KEY"]; +const MOONSHOT_ENV_VARS: &[&str] = &["MOONSHOT_API_KEY"]; // Provider-specific request body size limits in bytes const XAI_MAX_REQUEST_BODY_BYTES: usize = 52_428_800; // 50MB const OPENAI_MAX_REQUEST_BODY_BYTES: usize = 104_857_600; // 100MB const DASHSCOPE_MAX_REQUEST_BODY_BYTES: usize = 6_291_456; // 6MB (observed limit in dogfood) +const MOONSHOT_MAX_REQUEST_BODY_BYTES: usize = 52_428_800; // 50MB impl OpenAiCompatConfig { #[must_use] @@ -85,12 +88,25 @@ impl OpenAiCompatConfig { } } + /// Moonshot AI native endpoint (Kimi family models). 
+ #[must_use] + pub const fn moonshot() -> Self { + Self { + provider_name: "Moonshot", + api_key_env: "MOONSHOT_API_KEY", + base_url_env: "MOONSHOT_BASE_URL", + default_base_url: DEFAULT_MOONSHOT_BASE_URL, + max_request_body_bytes: MOONSHOT_MAX_REQUEST_BODY_BYTES, + } + } + #[must_use] pub fn credential_env_vars(self) -> &'static [&'static str] { match self.provider_name { "xAI" => XAI_ENV_VARS, "OpenAI" => OPENAI_ENV_VARS, "DashScope" => DASHSCOPE_ENV_VARS, + "Moonshot" => MOONSHOT_ENV_VARS, _ => &[], } } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 0e0c8992fc..08329b514d 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -24,11 +24,11 @@ use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant, UNIX_EPOCH}; use api::{ - anthropic_has_auth, detect_provider_kind, has_api_key, resolve_startup_auth_source, - AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, PromptCache, - ProviderClient as ApiProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, - ToolDefinition, ToolResultContentBlock, + anthropic_has_auth, detect_provider_kind, has_api_key, metadata_for_model, + resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache, + ProviderClient as ApiProviderClient, ProviderKind, ProviderMetadata, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use commands::{ @@ -1555,8 +1555,17 @@ fn configured_provider_for_model( let api_key = if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) { api_key.to_string() } else if let Some(env_name) = provider.api_key_env() { - env::var(env_name) - .map_err(|_| format!("model provider '{provider_name}' requires env var {env_name}"))? 
+ if let Ok(key) = env::var(env_name) { + key + } else if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { + // Fall back to saved OAuth bearer token when env var is unset + token_set.access_token + } else { + return Err(format!( + "model provider '{provider_name}' requires env var {env_name}" + ) + .into()); + } } else { return Err( format!("model provider '{provider_name}' requires apiKeyEnv or apiKey").into(), @@ -8313,16 +8322,35 @@ const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ fn check_model_auth_available(model: &str) -> Result> { let resolved = api::resolve_model_alias(model); + + // Use metadata_for_model for prefix-aware auth checking so each provider + // gets its correct env var and OAuth store key. + if let Some(meta) = api::metadata_for_model(&resolved) { + let has_env = has_api_key(meta.auth_env); + let has_oauth = match meta.auth_env { + "OPENAI_API_KEY" => { + runtime::load_provider_oauth("openai").ok().flatten().is_some() + } + "MOONSHOT_API_KEY" => { + runtime::load_provider_oauth("moonshot").ok().flatten().is_some() + } + _ => false, + }; + return Ok(has_env || has_oauth); + } + + // For bare model names without recognized prefix, fall back to + // provider-kind detection (env-var sniffing + last-resort defaults). 
let provider = detect_provider_kind(&resolved); let available = match provider { ProviderKind::Anthropic => anthropic_has_auth().unwrap_or(false), - ProviderKind::Xai => { - has_api_key("XAI_API_KEY") - } + ProviderKind::Xai => has_api_key("XAI_API_KEY"), ProviderKind::OpenAi => { has_api_key("OPENAI_API_KEY") || has_api_key("DASHSCOPE_API_KEY") + || has_api_key("MOONSHOT_API_KEY") || runtime::load_provider_oauth("openai").ok().flatten().is_some() + || runtime::load_provider_oauth("moonshot").ok().flatten().is_some() } }; Ok(available) From 367a341bbc7c92413fb6eede93fa433b1137da39 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 15:30:33 -0300 Subject: [PATCH 10/15] fix(auth): make auth gates fully provider-agnostic - check_model_auth_available: for ANY provider/model prefix, check OAuth tokens under that provider name AND .claw.json config. Previously only hardcoded openai/moonshot OAuth checks existed. - run_auth_command: now recognizes custom providers from .claw.json, allowing to work for any configured provider. - run_provider_welcome: shows custom providers from .claw.json in the interactive picker, so users can authenticate with them without memorizing env var names. - Template/built-in OAuth flows were already generic; these changes make the surrounding auth gates and CLI commands match that generality. 
--- rust/crates/rusty-claude-cli/src/main.rs | 157 ++++++++++++++++++++++- 1 file changed, 152 insertions(+), 5 deletions(-) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 08329b514d..bb4b7c6071 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -8323,10 +8323,43 @@ const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ fn check_model_auth_available(model: &str) -> Result> { let resolved = api::resolve_model_alias(model); - // Use metadata_for_model for prefix-aware auth checking so each provider - // gets its correct env var and OAuth store key. + // If model has a provider/ prefix, check auth for that specific provider. + // This works for ANY provider (built-in, template, or custom from .claw.json). + if let Some((provider_name, _)) = resolved.split_once('/') { + // Generic OAuth check: any provider may have a saved OAuth token + if runtime::load_provider_oauth(provider_name).ok().flatten().is_some() { + return Ok(true); + } + // Check the provider's config from .claw.json for env var or hardcoded key + if let Ok(cwd) = env::current_dir() { + if let Ok(config) = ConfigLoader::default_for(&cwd).load() { + if let Some(provider) = config.model_providers().get(provider_name) { + if provider.api_key().filter(|k| !k.is_empty()).is_some() { + return Ok(true); + } + if let Some(env_name) = provider.api_key_env() { + if has_api_key(env_name) { + return Ok(true); + } + } + // Configured but no auth available + return Ok(false); + } + } + } + // Not configured in .claw.json - check built-in provider env vars + return Ok(match provider_name { + "openai" => has_api_key("OPENAI_API_KEY"), + "xai" => has_api_key("XAI_API_KEY"), + "anthropic" => anthropic_has_auth().unwrap_or(false), + _ => false, + }); + } + + // No prefix - use metadata_for_model for built-in model detection if let Some(meta) = api::metadata_for_model(&resolved) { let has_env = has_api_key(meta.auth_env); + 
// Generic OAuth check based on auth_env mapping let has_oauth = match meta.auth_env { "OPENAI_API_KEY" => { runtime::load_provider_oauth("openai").ok().flatten().is_some() @@ -8339,8 +8372,7 @@ fn check_model_auth_available(model: &str) -> Result anthropic_has_auth().unwrap_or(false), @@ -8398,6 +8430,43 @@ fn run_provider_welcome( ); } + // Load custom providers from .claw.json that need authentication + let custom_providers: Vec<(String, String)> = if let Ok(cwd) = env::current_dir() { + if let Ok(config) = ConfigLoader::default_for(&cwd).load() { + config + .model_providers() + .iter() + .filter(|(name, _)| { + // Exclude providers already shown as built-in or templates + !BUILTIN_PROVIDERS.iter().any(|p| p.id == name.as_str()) + && !LOGIN_PROVIDER_TEMPLATES.iter().any(|t| t.id == name.as_str()) + }) + .filter(|(_, provider)| { + // Only show providers that actually need auth (have apiKeyEnv) + provider.api_key_env().is_some() + }) + .map(|(name, provider)| (name.clone(), provider.provider_type().to_string())) + .collect() + } else { + Vec::new() + } + } else { + Vec::new() + }; + let custom_count = custom_providers.len(); + + if !custom_providers.is_empty() { + println!("\n Custom providers:"); + for (i, (name, provider_type)) in custom_providers.iter().enumerate() { + println!( + " {}. 
{name} ({provider_type})", + builtin_count + template_count + i + 1 + ); + } + } + + let total = builtin_count + template_count + custom_count; + print!("\nEnter number (1-{total}): "); std::io::stdout().flush()?; let mut choice = String::new(); @@ -8493,7 +8562,50 @@ fn run_provider_welcome( template.default_model, )?; - Ok(Some(format!("{}/{}", template.id, template.default_model))) + // Custom provider from .claw.json selected + let custom_index = index - builtin_count - template_count - 1; + let custom_name = custom_providers + .get(custom_index) + .map(|(name, _)| name.as_str()) + .expect("valid custom provider index"); + + if let Ok(cwd) = env::current_dir() { + if let Ok(config) = ConfigLoader::default_for(&cwd).load() { + if let Some(provider) = config.model_providers().get(custom_name) { + if let Some(env_name) = provider.api_key_env() { + print!("Enter {} (or press Enter to cancel): ", env_name); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + std::env::set_var(env_name, key); + // Update the provider profile with the new key + save_model_provider_profile( + custom_name, + provider.provider_type(), + provider.base_url(), + env_name, + Some(key), + &provider.models().iter().cloned().collect::>(), + provider.default_model().unwrap_or(""), + )?; + if let Some(default_model) = provider.default_model() { + return Ok(Some(format!("{custom_name}/{default_model}"))); + } + return Ok(Some(custom_name.to_string())); + } + } + } + } + + Err(format!( + "Could not authenticate with custom provider '{custom_name}'." 
+ ) + .into()) } fn run_auth_command(provider: Option<&str>) -> Result<(), Box> { @@ -8601,6 +8713,41 @@ fn run_auth_command(provider: Option<&str>) -> Result<(), Box", + provider.api_key_env().unwrap_or("API_KEY") + ) + .into()); + } + if let Some(env_name) = provider.api_key_env() { + print!("Enter {} (or press Enter to cancel): ", env_name); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + std::env::set_var(env_name, key); + println!("Authentication set for {provider_id}."); + return Ok(()); + } + return Err(format!( + "provider '{provider_id}' has no apiKeyEnv configured. Add it to your .claw.json." + ) + .into()); + } + } + } + return Err(format!( "unknown provider: '{provider_id}'. Run `claw auth` to see available providers." ) From bc5559ecd45f5a9241b9e15c7f3cb3af935623fd Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 16:14:26 -0300 Subject: [PATCH 11/15] fix(auth): remove OAuth from built-in OpenAI provider OpenAI's OAuth (auth.openai.com, Codex CLI client_id) produces ChatGPT/WHAM-backend tokens, NOT Platform API tokens. These tokens authenticate your ChatGPT account, not your OpenAI Platform account. They work with chatgpt.com/backend-api but return 401 Unauthorized (Missing scopes: model.request) on api.openai.com/v1. OpenAI Platform API requires API keys (sk-...) only. 
Changes: - Remove OAuth config from BUILTIN_PROVIDERS[openai] - Skip OAuth fallback in check_model_auth_available for openai prefix & bare names - Skip OAuth fallback in configured_provider_for_model for openai provider - Use from_env (not from_env_or_oauth) for OpenAI in ProviderClient - Update tests: assert openai has no OAuth, test moonshot OAuth instead --- .gitignore | 1 + rust/crates/api/src/client.rs | 5 +- rust/crates/rusty-claude-cli/src/main.rs | 72 ++++++++++++++---------- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index 919ab8387f..6d3a0fc459 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ archive/ .claw/sessions/ .clawhip/ status-help.txt +.DS_Store diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 10f60d7bec..796e24e92c 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -42,7 +42,10 @@ impl ProviderClient { Some(meta) if meta.auth_env == "MOONSHOT_API_KEY" => { (OpenAiCompatConfig::moonshot(), Some("moonshot")) } - _ => (OpenAiCompatConfig::openai(), Some("openai")), + // OpenAI Platform API requires API keys, not OAuth tokens. + // OAuth tokens from auth.openai.com are WHAM-backend tokens + // (for chatgpt.com/backend-api) and return 401 on api.openai.com. 
+ _ => (OpenAiCompatConfig::openai(), None), }; // Try OAuth if the provider supports it and env var is not set if let Some(provider_id) = oauth_provider_id { diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index bb4b7c6071..30a89b169a 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1557,9 +1557,19 @@ fn configured_provider_for_model( } else if let Some(env_name) = provider.api_key_env() { if let Ok(key) = env::var(env_name) { key - } else if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { - // Fall back to saved OAuth bearer token when env var is unset - token_set.access_token + } else if provider_name != "openai" { + // Fall back to saved OAuth bearer token when env var is unset. + // Skip for "openai": OAuth tokens from auth.openai.com are + // ChatGPT/WHAM-backend tokens, NOT Platform API tokens. They + // return 401 Unauthorized on api.openai.com/v1. + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { + token_set.access_token + } else { + return Err(format!( + "model provider '{provider_name}' requires env var {env_name}" + ) + .into()); + } } else { return Err(format!( "model provider '{provider_name}' requires env var {env_name}" @@ -8300,16 +8310,13 @@ const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ label: "OpenAI", env_var: "OPENAI_API_KEY", default_model: "gpt-4o", - oauth: Some(ProviderOAuthConfig { - client_id: "app_EMoamEEZ73f0CkXaXp7hrann", - callback_port: 1455, - redirect_path: "/auth/callback", - flow: OAuthFlowType::Pkce { - authorize_url: "https://auth.openai.com/oauth/authorize", - token_url: "https://auth.openai.com/oauth/token", - scopes: &["openid", "profile", "email", "offline_access"], - }, - }), + // NOTE: OpenAI's OAuth (auth.openai.com, Codex CLI client_id) produces + // ChatGPT/WHAM-backend tokens, NOT Platform API tokens. 
These tokens + // authenticate your ChatGPT account, not your OpenAI Platform account. + // They work with chatgpt.com/backend-api but return 401 Unauthorized + // (Missing scopes: model.request) on api.openai.com/v1. + // OpenAI Platform API requires API keys (sk-...) only. + oauth: None, }, BuiltinProvider { id: "xai", @@ -8326,8 +8333,12 @@ fn check_model_auth_available(model: &str) -> Result Result { - runtime::load_provider_oauth("openai").ok().flatten().is_some() - } "MOONSHOT_API_KEY" => { runtime::load_provider_oauth("moonshot").ok().flatten().is_some() } @@ -8381,7 +8391,6 @@ fn check_model_auth_available(model: &str) -> Result Date: Fri, 1 May 2026 16:27:01 -0300 Subject: [PATCH 12/15] feat(auth): remove custom providers from welcome screen Only show built-in and additional (template) providers in the interactive welcome/auth flow. Custom providers from .claw.json are no longer listed. --- rust/crates/rusty-claude-cli/src/main.rs | 83 +----------------------- 1 file changed, 3 insertions(+), 80 deletions(-) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 30a89b169a..afa803661d 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -8439,42 +8439,7 @@ fn run_provider_welcome( ); } - // Load custom providers from .claw.json that need authentication - let custom_providers: Vec<(String, String)> = if let Ok(cwd) = env::current_dir() { - if let Ok(config) = ConfigLoader::default_for(&cwd).load() { - config - .model_providers() - .iter() - .filter(|(name, _)| { - // Exclude providers already shown as built-in or templates - !BUILTIN_PROVIDERS.iter().any(|p| p.id == name.as_str()) - && !LOGIN_PROVIDER_TEMPLATES.iter().any(|t| t.id == name.as_str()) - }) - .filter(|(_, provider)| { - // Only show providers that actually need auth (have apiKeyEnv) - provider.api_key_env().is_some() - }) - .map(|(name, provider)| (name.clone(), 
provider.provider_type().to_string())) - .collect() - } else { - Vec::new() - } - } else { - Vec::new() - }; - let custom_count = custom_providers.len(); - - if !custom_providers.is_empty() { - println!("\n Custom providers:"); - for (i, (name, provider_type)) in custom_providers.iter().enumerate() { - println!( - " {}. {name} ({provider_type})", - builtin_count + template_count + i + 1 - ); - } - } - - let total = builtin_count + template_count + custom_count; + let total = builtin_count + template_count; print!("\nEnter number (1-{total}): "); std::io::stdout().flush()?; @@ -8571,50 +8536,8 @@ fn run_provider_welcome( template.default_model, )?; - // Custom provider from .claw.json selected - let custom_index = index - builtin_count - template_count - 1; - let custom_name = custom_providers - .get(custom_index) - .map(|(name, _)| name.as_str()) - .expect("valid custom provider index"); - - if let Ok(cwd) = env::current_dir() { - if let Ok(config) = ConfigLoader::default_for(&cwd).load() { - if let Some(provider) = config.model_providers().get(custom_name) { - if let Some(env_name) = provider.api_key_env() { - print!("Enter {} (or press Enter to cancel): ", env_name); - std::io::stdout().flush()?; - let mut key = String::new(); - std::io::stdin().read_line(&mut key)?; - let key = key.trim(); - if key.is_empty() { - println!("Cancelled."); - return Ok(None); - } - std::env::set_var(env_name, key); - // Update the provider profile with the new key - save_model_provider_profile( - custom_name, - provider.provider_type(), - provider.base_url(), - env_name, - Some(key), - &provider.models().iter().cloned().collect::>(), - provider.default_model().unwrap_or(""), - )?; - if let Some(default_model) = provider.default_model() { - return Ok(Some(format!("{custom_name}/{default_model}"))); - } - return Ok(Some(custom_name.to_string())); - } - } - } - } - - Err(format!( - "Could not authenticate with custom provider '{custom_name}'." 
- ) - .into()) + // No custom providers in welcome — only built-in and templates. + Err("Invalid selection".into()) } fn run_auth_command(provider: Option<&str>) -> Result<(), Box> { From eb24a64f3118daba9dc5eaee402c47a984f47995 Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 17:19:24 -0300 Subject: [PATCH 13/15] feat(auth): implement OpenAI WHAM backend with OAuth + token refresh OpenAI OAuth tokens (from auth.openai.com) are ChatGPT/WHAM-backend tokens, NOT Platform API tokens. This implements full support for using them: - Add id_token to OAuthTokenSet for chatgpt_account_id extraction - Add extract_chatgpt_account_id() from JWT payload (no sig verification) - Add refresh_oauth_token() async function for token refresh - Create WhamClient with Responses API streaming support - Endpoint: chatgpt.com/backend-api/wham/responses - SSE parser for response.output_text.delta events - ChatGPT-Account-Id header from JWT - Automatic token refresh before requests if <60s until expiry - Route OpenAI OAuth through WhamClient in ProviderClient and CLI - Override base_url to WHAM backend when OpenAI OAuth is used in config Token refresh: POST to auth.openai.com/oauth/token with refresh_token grant. New tokens are persisted back to ~/.claw/credentials.json automatically. 
--- rust/Cargo.lock | 1 + rust/crates/api/src/client.rs | 41 +- rust/crates/api/src/lib.rs | 1 + rust/crates/api/src/providers/anthropic.rs | 5 + rust/crates/api/src/providers/mod.rs | 1 + rust/crates/api/src/providers/wham.rs | 738 +++++++++++++++++++++ rust/crates/runtime/Cargo.toml | 1 + rust/crates/runtime/src/lib.rs | 13 +- rust/crates/runtime/src/oauth.rs | 119 ++++ rust/crates/rusty-claude-cli/src/main.rs | 105 +-- 10 files changed, 972 insertions(+), 53 deletions(-) create mode 100644 rust/crates/api/src/providers/wham.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 78b9feeff4..d25dcbabc9 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1361,6 +1361,7 @@ dependencies = [ name = "runtime" version = "0.1.0" dependencies = [ + "base64", "glob", "plugins", "regex", diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 796e24e92c..388d60e6a4 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -2,6 +2,7 @@ use crate::error::ApiError; use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; use crate::providers::anthropic::{self, AnthropicClient, AuthSource}; use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::wham::{self, WhamClient}; use crate::providers::{self, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; @@ -11,6 +12,8 @@ pub enum ProviderClient { Anthropic(AnthropicClient), Xai(OpenAiCompatClient), OpenAi(OpenAiCompatClient), + /// OpenAI WHAM backend (chatgpt.com/backend-api/wham) using ChatGPT OAuth tokens. + Wham(WhamClient), } impl ProviderClient { @@ -42,13 +45,29 @@ impl ProviderClient { Some(meta) if meta.auth_env == "MOONSHOT_API_KEY" => { (OpenAiCompatConfig::moonshot(), Some("moonshot")) } - // OpenAI Platform API requires API keys, not OAuth tokens. 
- // OAuth tokens from auth.openai.com are WHAM-backend tokens - // (for chatgpt.com/backend-api) and return 401 on api.openai.com. - _ => (OpenAiCompatConfig::openai(), None), + _ => (OpenAiCompatConfig::openai(), Some("openai")), }; // Try OAuth if the provider supports it and env var is not set if let Some(provider_id) = oauth_provider_id { + if provider_id == "openai" { + // OpenAI OAuth tokens are WHAM-backend tokens (chatgpt.com/backend-api/wham), + // NOT Platform API tokens. Route to WhamClient when using OAuth. + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + let account_id = token_set + .id_token + .as_deref() + .and_then(runtime::extract_chatgpt_account_id) + .or_else(|| { + runtime::extract_chatgpt_account_id(&token_set.access_token) + }); + return Ok(Self::Wham(WhamClient::from_oauth_token_set( + token_set, + account_id, + "https://auth.openai.com/oauth/token", + "app_EMoamEEZ73f0CkXaXp7hrann", + ))); + } + } Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( config, provider_id, )?)) @@ -82,7 +101,7 @@ impl ProviderClient { match self { Self::Anthropic(_) => ProviderKind::Anthropic, Self::Xai(_) => ProviderKind::Xai, - Self::OpenAi(_) => ProviderKind::OpenAi, + Self::OpenAi(_) | Self::Wham(_) => ProviderKind::OpenAi, } } @@ -98,7 +117,7 @@ impl ProviderClient { pub fn prompt_cache_stats(&self) -> Option { match self { Self::Anthropic(client) => client.prompt_cache_stats(), - Self::Xai(_) | Self::OpenAi(_) => None, + Self::Xai(_) | Self::OpenAi(_) | Self::Wham(_) => None, } } @@ -106,7 +125,7 @@ impl ProviderClient { pub fn take_last_prompt_cache_record(&self) -> Option { match self { Self::Anthropic(client) => client.take_last_prompt_cache_record(), - Self::Xai(_) | Self::OpenAi(_) => None, + Self::Xai(_) | Self::OpenAi(_) | Self::Wham(_) => None, } } @@ -117,6 +136,7 @@ impl ProviderClient { match self { Self::Anthropic(client) => client.send_message(request).await, Self::Xai(client) | Self::OpenAi(client) => 
client.send_message(request).await, + Self::Wham(client) => client.send_message(request).await, } } @@ -133,6 +153,10 @@ impl ProviderClient { .stream_message(request) .await .map(MessageStream::OpenAiCompat), + Self::Wham(client) => client + .stream_message(request) + .await + .map(MessageStream::Wham), } } } @@ -141,6 +165,7 @@ impl ProviderClient { pub enum MessageStream { Anthropic(anthropic::MessageStream), OpenAiCompat(openai_compat::MessageStream), + Wham(wham::WhamMessageStream), } impl MessageStream { @@ -149,6 +174,7 @@ impl MessageStream { match self { Self::Anthropic(stream) => stream.request_id(), Self::OpenAiCompat(stream) => stream.request_id(), + Self::Wham(stream) => stream.request_id(), } } @@ -156,6 +182,7 @@ impl MessageStream { match self { Self::Anthropic(stream) => stream.next_event().await, Self::OpenAiCompat(stream) => stream.next_event().await, + Self::Wham(stream) => stream.next_event().await, } } } diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 5c8ffdfa07..e14e9f0074 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -23,6 +23,7 @@ pub use providers::openai_compat::{ build_chat_completion_request, flatten_tool_result_content, has_api_key, is_reasoning_model, model_rejects_is_error_field, translate_message, OpenAiCompatClient, OpenAiCompatConfig, }; +pub use providers::wham::{WhamClient, DEFAULT_WHAM_BASE_URL}; pub use providers::{ detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override, metadata_for_model, resolve_model_alias, ProviderKind, ProviderMetadata, diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index 7c9f02945e..bb7e19d072 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -707,6 +707,7 @@ fn resolve_saved_oauth_token_set( refresh_token: resolved.refresh_token.clone(), expires_at: resolved.expires_at, scopes: resolved.scopes.clone(), + 
id_token: None, }) .map_err(ApiError::from)?; Ok(resolved) @@ -1165,6 +1166,7 @@ mod tests { refresh_token: Some("refresh".to_string()), expires_at: Some(now_unix_timestamp() + 300), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save oauth credentials"); @@ -1204,6 +1206,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(1), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save expired oauth credentials"); @@ -1236,6 +1239,7 @@ mod tests { refresh_token: Some("refresh".to_string()), expires_at: Some(now_unix_timestamp() + 300), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save oauth credentials"); @@ -1260,6 +1264,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(1), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save expired oauth credentials"); diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index 834c564464..d4562d2b4d 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -9,6 +9,7 @@ use crate::types::{MessageRequest, MessageResponse}; pub mod anthropic; pub mod openai_compat; +pub mod wham; #[allow(dead_code)] pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; diff --git a/rust/crates/api/src/providers/wham.rs b/rust/crates/api/src/providers/wham.rs new file mode 100644 index 0000000000..fb715b71ab --- /dev/null +++ b/rust/crates/api/src/providers/wham.rs @@ -0,0 +1,738 @@ +use std::collections::VecDeque; + +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::http_client::build_http_client_or_default; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, 
StreamEvent, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_WHAM_BASE_URL: &str = "https://chatgpt.com/backend-api/wham"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; + +#[derive(Debug, Clone)] +pub struct WhamClient { + http: reqwest::Client, + token: std::sync::Arc>, + base_url: String, +} + +#[derive(Debug, Clone)] +struct WhamToken { + access_token: String, + refresh_token: Option, + expires_at: Option, + account_id: Option, + token_url: String, + client_id: String, +} + +impl WhamClient { + #[must_use] + pub fn new(access_token: impl Into, account_id: Option) -> Self { + Self { + http: build_http_client_or_default(), + token: std::sync::Arc::new(std::sync::Mutex::new(WhamToken { + access_token: access_token.into(), + refresh_token: None, + expires_at: None, + account_id, + token_url: "https://auth.openai.com/oauth/token".to_string(), + client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(), + })), + base_url: DEFAULT_WHAM_BASE_URL.to_string(), + } + } + + /// Create a WHAM client with full OAuth token info, enabling automatic refresh. 
+ #[must_use] + pub fn from_oauth_token_set( + token_set: runtime::OAuthTokenSet, + account_id: Option, + token_url: impl Into, + client_id: impl Into, + ) -> Self { + Self { + http: build_http_client_or_default(), + token: std::sync::Arc::new(std::sync::Mutex::new(WhamToken { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + account_id, + token_url: token_url.into(), + client_id: client_id.into(), + })), + base_url: DEFAULT_WHAM_BASE_URL.to_string(), + } + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + async fn ensure_token_valid(&self) -> Result<(), ApiError> { + let needs_refresh = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + match token.expires_at { + None => false, + Some(expires_at) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + // Refresh if fewer than 60 seconds remain. 
+ now + 60 >= expires_at + } + } + }; + + if !needs_refresh { + return Ok(()); + } + + let (refresh_token, token_url, client_id) = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + let refresh = token.refresh_token.clone().ok_or_else(|| { + ApiError::Auth("OAuth token expired and no refresh token available".to_string()) + })?; + (refresh, token.token_url.clone(), token.client_id.clone()) + }; + + let new_token = runtime::refresh_oauth_token(&self.http, &token_url, &client_id, &refresh_token) + .await + .map_err(|e| ApiError::Auth(format!("OAuth token refresh failed: {e}")))?; + + { + let mut token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + token.access_token = new_token.access_token.clone(); + token.refresh_token = new_token.refresh_token.clone(); + token.expires_at = new_token.expires_at; + } + + // Persist the refreshed token. + let _ = runtime::save_provider_oauth("openai", &new_token); + + Ok(()) + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + self.ensure_token_valid().await?; + let (access_token, account_id) = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + (token.access_token.clone(), token.account_id.clone()) + }; + + let request_url = responses_endpoint(&self.base_url); + let body = build_responses_request(request); + + let mut req_builder = self + .http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&access_token); + + if let Some(ref id) = account_id { + req_builder = req_builder.header("ChatGPT-Account-Id", id); + } + + let response = req_builder.json(&body).send().await.map_err(ApiError::from)?; + let request_id = request_id_from_headers(response.headers()); + + if !response.status().is_success() { + let status = response.status(); + let body_text = response.text().await.unwrap_or_default(); + 
return Err(ApiError::Api { + status, + error_type: Some("wham_error".to_string()), + message: Some(body_text.clone()), + request_id, + body: body_text, + retryable: status.is_server_error(), + suggested_action: suggested_action_for_status(status), + }); + } + + let resp_body = response.text().await.map_err(ApiError::from)?; + let wham_resp: ResponsesResponse = serde_json::from_str(&resp_body).map_err(|error| { + ApiError::json_deserialize("OpenAI WHAM", &request.model, &resp_body, error) + })?; + + Ok(convert_to_message_response(wham_resp, request_id)) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + self.ensure_token_valid().await?; + let (access_token, account_id) = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + (token.access_token.clone(), token.account_id.clone()) + }; + + let request_url = responses_endpoint(&self.base_url); + let body = build_responses_request(request); + + let mut req_builder = self + .http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&access_token); + + if let Some(ref id) = account_id { + req_builder = req_builder.header("ChatGPT-Account-Id", id); + } + + let response = req_builder.json(&body).send().await.map_err(ApiError::from)?; + let request_id = request_id_from_headers(response.headers()); + + if !response.status().is_success() { + let status = response.status(); + let body_text = response.text().await.unwrap_or_default(); + return Err(ApiError::Api { + status, + error_type: Some("wham_error".to_string()), + message: Some(body_text.clone()), + request_id, + body: body_text, + retryable: status.is_server_error(), + suggested_action: suggested_action_for_status(status), + }); + } + + Ok(WhamMessageStream { + response, + buffer: Vec::new(), + pending: VecDeque::new(), + done: false, + state: WhamStreamState::new(request_id), + }) + } +} + +impl Provider for WhamClient { + type Stream = 
WhamMessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +struct WhamStreamState { + request_id: Option, + message_started: bool, + content_index: u32, + text_started: bool, + finished: bool, + usage: Option, +} + +impl WhamStreamState { + fn new(request_id: Option) -> Self { + Self { + request_id, + message_started: false, + content_index: 0, + text_started: false, + finished: false, + usage: None, + } + } + + fn ingest_event(&mut self, event: WhamSseEvent) -> Vec { + let mut events = Vec::new(); + + match event { + WhamSseEvent::Created { response } | WhamSseEvent::InProgress { response } => { + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: response.id, + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: response.model, + stop_reason: None, + stop_sequence: None, + usage: Usage::default(), + request_id: self.request_id.clone(), + }, + })); + } + } + WhamSseEvent::OutputTextDelta { content_index, delta } => { + if !self.text_started { + self.text_started = true; + self.content_index = content_index; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: content_index, + content_block: OutputContentBlock::Text { text: String::new() }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: content_index, + delta: ContentBlockDelta::TextDelta { text: delta }, + })); + } + WhamSseEvent::OutputTextDone { content_index } => { + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: content_index, + })); + 
self.text_started = false; + } + WhamSseEvent::Completed { usage } => { + self.usage = usage; + } + _ => {} + } + + events + } + + fn finish(&mut self) -> Vec { + let mut events = Vec::new(); + if self.text_started { + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: self.content_index, + })); + self.text_started = false; + } + if self.message_started && !self.finished { + self.finished = true; + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or_default(), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + events + } +} + +#[derive(Debug)] +pub struct WhamMessageStream { + response: reqwest::Response, + buffer: Vec, + pending: VecDeque, + done: bool, + state: WhamStreamState, +} + +impl WhamMessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.state.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.buffer.extend_from_slice(&chunk); + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_wham_sse_frame(&frame)? 
{ + self.pending.extend(self.state.ingest_event(event)); + } + } + } + None => { + self.done = true; + } + } + } + } +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +#[derive(Debug, Clone, Deserialize)] +struct WhamResponseStub { + id: String, + model: String, +} + +#[derive(Debug, Clone, Deserialize)] +struct WhamUsageStub { + input_tokens: u32, + output_tokens: u32, +} + +#[derive(Debug, Clone)] +enum WhamSseEvent { + Created { response: WhamResponseStub }, + InProgress { response: WhamResponseStub }, + OutputTextDelta { content_index: u32, delta: String }, + OutputTextDone { content_index: u32 }, + Completed { usage: Option }, + Other, +} + +fn parse_wham_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() || trimmed.starts_with(':') { + return Ok(None); + } + + let mut event_type = String::new(); + let mut data_json = String::new(); + + for line in trimmed.lines() { + if let Some(et) = line.strip_prefix("event:") { + event_type = et.trim().to_string(); + } else if let Some(data) = line.strip_prefix("data:") { + data_json.push_str(data.trim_start()); + } + } + + if data_json.is_empty() { + return Ok(None); + } + + let data: serde_json::Value = serde_json::from_str(&data_json) + .map_err(|e| ApiError::json_deserialize("OpenAI WHAM", "", &data_json, e))?; + + let event = match event_type.as_str() { + "response.created" => WhamSseEvent::Created { + response: serde_json::from_value(data.get("response").cloned().unwrap_or_default()) + 
.unwrap_or(WhamResponseStub { id: String::new(), model: String::new() }), + }, + "response.in_progress" => WhamSseEvent::InProgress { + response: serde_json::from_value(data.get("response").cloned().unwrap_or_default()) + .unwrap_or(WhamResponseStub { id: String::new(), model: String::new() }), + }, + "response.output_text.delta" => WhamSseEvent::OutputTextDelta { + content_index: data["content_index"].as_u64().unwrap_or(0) as u32, + delta: data["delta"].as_str().unwrap_or("").to_string(), + }, + "response.output_text.done" => WhamSseEvent::OutputTextDone { + content_index: data["content_index"].as_u64().unwrap_or(0) as u32, + }, + "response.completed" => { + let usage = data.get("response").and_then(|r| r.get("usage")).map(|u| Usage { + input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32, + output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }); + WhamSseEvent::Completed { usage } + } + _ => WhamSseEvent::Other, + }; + + Ok(Some(event)) +} + +fn responses_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/responses") { + trimmed.to_string() + } else { + format!("{trimmed}/responses") + } +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|v| v.to_str().ok()) + .map(String::from) +} + +fn suggested_action_for_status(status: reqwest::StatusCode) -> Option { + match status.as_u16() { + 401 => Some("OAuth token may be expired. Try re-authenticating with `claw auth login openai`".to_string()), + 403 => Some("Verify ChatGPT subscription is active (Plus/Pro required)".to_string()), + 429 => Some("Wait a moment before retrying; consider reducing request rate".to_string()), + _ => None, + } +} + +/// Build a Responses API request body from our internal MessageRequest. 
+fn build_responses_request(request: &MessageRequest) -> Value { + // WHAM backend requires streaming mode. + let mut body = json!({ + "model": request.model, + "store": false, + "stream": true, + }); + + if let Some(ref system) = request.system { + body["instructions"] = json!(system); + } + + // Convert messages to Responses API `input` format + let input: Vec = request + .messages + .iter() + .map(|msg| { + let content_blocks: Vec = msg + .content + .iter() + .map(|block| match block { + InputContentBlock::Text { text } => { + json!({"type": "input_text", "text": text}) + } + InputContentBlock::ToolUse { id, name, input } => { + if msg.role == "assistant" { + json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input, + }) + } else { + json!({"type": "input_text", "text": "[tool input]"}) + } + } + InputContentBlock::ToolResult { tool_use_id, content, is_error } => { + let text = content + .iter() + .map(|c| match c { + crate::types::ToolResultContentBlock::Text { text } => text.clone(), + crate::types::ToolResultContentBlock::Json { value } => { + value.to_string() + } + }) + .collect::>() + .join("\n"); + json!({ + "type": "input_text", + "text": format!( + "[tool result {} {}]\n{}", + tool_use_id, + if *is_error { "(error)" } else { "" }, + text + ), + }) + } + }) + .collect(); + json!({"role": msg.role, "content": content_blocks}) + }) + .collect(); + + body["input"] = json!(input); + + // Note: WHAM backend does not support `max_output_tokens`. + // if request.max_tokens > 0 { + // body["max_output_tokens"] = json!(request.max_tokens); + // } + + if let Some(temp) = request.temperature { + body["temperature"] = json!(temp); + } + if let Some(top_p) = request.top_p { + body["top_p"] = json!(top_p); + } + + body +} + +/// Responses API response shape. 
+#[derive(Debug, Clone, Deserialize)] +struct ResponsesResponse { + id: String, + model: String, + output: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ResponsesOutputItem { + Message { + id: String, + role: String, + content: Vec, + }, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ResponsesContentItem { + OutputText { text: String }, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Deserialize, Default)] +struct ResponsesUsage { + input_tokens: u32, + output_tokens: u32, +} + +fn convert_to_message_response( + wham: ResponsesResponse, + request_id: Option, +) -> MessageResponse { + let mut content = Vec::new(); + for item in &wham.output { + if let ResponsesOutputItem::Message { content: blocks, .. } = item { + for block in blocks { + if let ResponsesContentItem::OutputText { text } = block { + content.push(OutputContentBlock::Text { text: text.clone() }); + } + } + } + } + + let usage = wham.usage.map(|u| Usage { + input_tokens: u.input_tokens, + output_tokens: u.output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }).unwrap_or_default(); + + MessageResponse { + id: wham.id, + kind: "message".to_string(), + role: "assistant".to_string(), + content, + model: wham.model, + stop_reason: None, + stop_sequence: None, + usage, + request_id, + } +} + +fn response_to_stream_events(response: MessageResponse) -> Vec { + let mut events = Vec::new(); + + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: response.clone(), + })); + + for (index, block) in response.content.iter().enumerate() { + let block_index = index as u32; + match block { + OutputContentBlock::Text { text } => { + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: block_index, + content_block: OutputContentBlock::Text { text: String::new() }, + })); + 
events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: block_index, + delta: ContentBlockDelta::TextDelta { text: text.clone() }, + })); + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + OutputContentBlock::ToolUse { id, name, input } => { + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: block_index, + content_block: OutputContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: input.clone(), + }, + })); + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: block_index, + delta: ContentBlockDelta::InputJsonDelta { + partial_json: input.to_string(), + }, + })); + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + OutputContentBlock::Thinking { thinking, .. } => { + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: block_index, + content_block: OutputContentBlock::Text { text: String::new() }, + })); + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: block_index, + delta: ContentBlockDelta::ThinkingDelta { thinking: thinking.clone() }, + })); + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + OutputContentBlock::RedactedThinking { .. 
} => { + // Skip redacted thinking blocks in stream replay + } + } + } + + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: response.stop_reason.clone(), + stop_sequence: response.stop_sequence.clone(), + }, + usage: response.usage.clone(), + })); + + events +} diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index 382e38fc29..1357edd98b 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -6,6 +6,7 @@ license.workspace = true publish.workspace = true [dependencies] +base64 = "0.22" sha2 = "0.10" glob = "0.3" plugins = { path = "../plugins" } diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 6bafc11abe..126a067fe5 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -113,12 +113,13 @@ pub use mcp_stdio::{ }; pub use oauth::{ clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, - generate_pkce_pair, generate_state, load_oauth_credentials, load_provider_oauth, - loopback_redirect_uri, loopback_redirect_uri_with_path, open_browser, parse_oauth_callback_query, - parse_oauth_callback_request_target, poll_device_token, run_oauth_callback_server, - save_oauth_credentials, save_provider_oauth, DeviceAuthRequest, DeviceAuthResponse, - OAuthAuthorizationRequest, OAuthCallbackParams, OAuthCallbackResult, OAuthRefreshRequest, - OAuthTokenExchangeRequest, OAuthTokenSet, PkceChallengeMethod, PkceCodePair, + extract_chatgpt_account_id, generate_pkce_pair, generate_state, load_oauth_credentials, + load_provider_oauth, loopback_redirect_uri, loopback_redirect_uri_with_path, open_browser, + parse_oauth_callback_query, parse_oauth_callback_request_target, poll_device_token, + refresh_oauth_token, run_oauth_callback_server, save_oauth_credentials, save_provider_oauth, + DeviceAuthRequest, DeviceAuthResponse, OAuthAuthorizationRequest, OAuthCallbackParams, + OAuthCallbackResult, 
OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index 21794e96e3..92477fc602 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -16,6 +16,10 @@ pub struct OAuthTokenSet { pub refresh_token: Option, pub expires_at: Option, pub scopes: Vec, + /// OpenID token from auth.openai.com, needed to extract `chatgpt_account_id` + /// for the WHAM backend (`ChatGPT-Account-Id` header). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id_token: Option, } /// PKCE verifier/challenge pair generated for an OAuth authorization flow. @@ -93,6 +97,8 @@ struct StoredOAuthCredentials { expires_at: Option, #[serde(default)] scopes: Vec, + #[serde(default)] + id_token: Option, } impl From for StoredOAuthCredentials { @@ -102,6 +108,7 @@ impl From for StoredOAuthCredentials { refresh_token: value.refresh_token, expires_at: value.expires_at, scopes: value.scopes, + id_token: value.id_token, } } } @@ -113,6 +120,7 @@ impl From for OAuthTokenSet { refresh_token: value.refresh_token, expires_at: value.expires_at, scopes: value.scopes, + id_token: value.id_token, } } } @@ -615,11 +623,13 @@ pub async fn poll_device_token( .as_str() .map(|s| s.split(' ').map(String::from).collect()) .unwrap_or_default(); + let id_token = token["id_token"].as_str().map(String::from); return Ok(Some(OAuthTokenSet { access_token, refresh_token, expires_at, scopes, + id_token, })); } @@ -763,6 +773,110 @@ fn base64url_encode(bytes: &[u8]) -> String { output } +/// Extract `chatgpt_account_id` from a JWT payload (id_token or access_token). +/// No signature verification — just base64-decode the payload section. 
+pub fn extract_chatgpt_account_id(token: &str) -> Option { + let payload_b64 = token.split('.').nth(1)?; + let payload_json = base64_decode_urlsafe(payload_b64).ok()?; + let payload: serde_json::Value = serde_json::from_slice(&payload_json).ok()?; + + // 3-level fallback, matching Codex CLI behaviour: + // 1. top-level `chatgpt_account_id` + if let Some(id) = payload.get("chatgpt_account_id").and_then(|v| v.as_str()) { + return Some(id.to_string()); + } + // 2. `https://api.openai.com/auth` -> `chatgpt_account_id` + if let Some(auth) = payload.get("https://api.openai.com/auth").and_then(|v| v.as_object()) { + if let Some(id) = auth.get("chatgpt_account_id").and_then(|v| v.as_str()) { + return Some(id.to_string()); + } + } + // 3. `organizations[0].id` + if let Some(orgs) = payload.get("organizations").and_then(|v| v.as_array()) { + if let Some(first) = orgs.first() { + if let Some(id) = first.get("id").and_then(|v| v.as_str()) { + return Some(id.to_string()); + } + } + } + None +} + +fn base64_decode_urlsafe(input: &str) -> Result, String> { + let padded = match input.len() % 4 { + 0 => input.to_string(), + 2 => format!("{input}=="), + 3 => format!("{input}="), + _ => return Err("invalid base64url length".to_string()), + }; + let standard = padded.replace('-', "+").replace('_', "/"); + base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &standard) + .map_err(|e| e.to_string()) +} + +/// Refresh an OAuth token set using the refresh_token grant. +/// Returns a new token set on success. 
+pub async fn refresh_oauth_token( + client: &reqwest::Client, + token_url: &str, + client_id: &str, + refresh_token: &str, +) -> io::Result { + let params = [ + ("grant_type", "refresh_token"), + ("client_id", client_id), + ("refresh_token", refresh_token), + ]; + + let response = client + .post(token_url) + .form(¶ms) + .send() + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("HTTP request failed: {e}")))?; + + let status = response.status(); + let body = response.text().await.map_err(|e| { + io::Error::new(io::ErrorKind::Other, format!("Failed to read response body: {e}")) + })?; + + if !status.is_success() { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Token refresh failed ({status}): {body}"), + )); + } + + let token: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let access_token = token["access_token"] + .as_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "access_token missing from refresh response"))? 
+ .to_string(); + let refresh_token = token["refresh_token"].as_str().map(String::from); + let id_token = token["id_token"].as_str().map(String::from); + let expires_at = token["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = token["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + + Ok(OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + id_token, + }) +} + fn percent_encode(value: &str) -> String { let mut encoded = String::new(); for byte in value.bytes() { @@ -920,6 +1034,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(123), scopes: vec!["scope:a".to_string()], + id_token: None, }; save_oauth_credentials(&token_set).expect("save credentials"); assert_eq!( @@ -969,12 +1084,14 @@ mod tests { refresh_token: Some("openai-refresh".to_string()), expires_at: Some(1000), scopes: vec!["openid".to_string()], + id_token: None, }; let moonshot_tokens = OAuthTokenSet { access_token: "moonshot-access".to_string(), refresh_token: None, expires_at: Some(2000), scopes: vec!["profile".to_string()], + id_token: None, }; save_provider_oauth("openai", &openai_tokens).expect("save openai"); @@ -1026,6 +1143,7 @@ mod tests { refresh_token: Some("legacy-refresh".to_string()), expires_at: Some(999), scopes: vec!["org:read".to_string()], + id_token: None, }; save_oauth_credentials(&legacy).expect("save legacy"); @@ -1034,6 +1152,7 @@ mod tests { refresh_token: None, expires_at: Some(888), scopes: vec!["user:read".to_string()], + id_token: None, }; save_provider_oauth("openai", &provider).expect("save provider"); diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index afa803661d..a841516f2e 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1552,24 +1552,21 @@ 
fn configured_provider_for_model( ) .into()); } - let api_key = if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) { - api_key.to_string() + let (api_key, base_url) = if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) { + (api_key.to_string(), provider.base_url().to_string()) } else if let Some(env_name) = provider.api_key_env() { if let Ok(key) = env::var(env_name) { - key - } else if provider_name != "openai" { + (key, provider.base_url().to_string()) + } else if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { // Fall back to saved OAuth bearer token when env var is unset. - // Skip for "openai": OAuth tokens from auth.openai.com are - // ChatGPT/WHAM-backend tokens, NOT Platform API tokens. They - // return 401 Unauthorized on api.openai.com/v1. - if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { - token_set.access_token + // For OpenAI, OAuth tokens are WHAM-backend tokens, not Platform API tokens. + // Override the base URL to the WHAM backend. + let base_url = if provider_name == "openai" { + api::DEFAULT_WHAM_BASE_URL.to_string() } else { - return Err(format!( - "model provider '{provider_name}' requires env var {env_name}" - ) - .into()); - } + provider.base_url().to_string() + }; + (token_set.access_token, base_url) } else { return Err(format!( "model provider '{provider_name}' requires env var {env_name}" @@ -1585,7 +1582,7 @@ fn configured_provider_for_model( wire_model: wire_model.to_string(), provider_type: provider.provider_type().to_string(), api_key, - base_url: provider.base_url().to_string(), + base_url, })) } @@ -8206,10 +8203,34 @@ impl AnthropicRuntimeClient { .with_prompt_cache(PromptCache::new(session_id)) } "openai-compatible" | "openai" => { - ApiProviderClient::from_openai_compatible_profile( - provider.api_key, - provider.base_url, - ) + // Route to WhamClient when using ChatGPT OAuth (WHAM backend). 
+ if provider.base_url.contains("backend-api/wham") { + if let Ok(Some(token_set)) = runtime::load_provider_oauth("openai") { + let account_id = token_set + .id_token + .as_deref() + .and_then(runtime::extract_chatgpt_account_id) + .or_else(|| { + runtime::extract_chatgpt_account_id(&token_set.access_token) + }); + ApiProviderClient::Wham(api::WhamClient::from_oauth_token_set( + token_set, + account_id, + "https://auth.openai.com/oauth/token", + "app_EMoamEEZ73f0CkXaXp7hrann", + )) + } else { + ApiProviderClient::from_openai_compatible_profile( + provider.api_key, + provider.base_url, + ) + } + } else { + ApiProviderClient::from_openai_compatible_profile( + provider.api_key, + provider.base_url, + ) + } } other => { return Err(format!("unsupported provider type: {other}").into()); @@ -8310,13 +8331,16 @@ const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ label: "OpenAI", env_var: "OPENAI_API_KEY", default_model: "gpt-4o", - // NOTE: OpenAI's OAuth (auth.openai.com, Codex CLI client_id) produces - // ChatGPT/WHAM-backend tokens, NOT Platform API tokens. These tokens - // authenticate your ChatGPT account, not your OpenAI Platform account. - // They work with chatgpt.com/backend-api but return 401 Unauthorized - // (Missing scopes: model.request) on api.openai.com/v1. - // OpenAI Platform API requires API keys (sk-...) only. 
- oauth: None, + oauth: Some(ProviderOAuthConfig { + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + callback_port: 1455, + redirect_path: "/auth/callback", + flow: OAuthFlowType::Pkce { + authorize_url: "https://auth.openai.com/oauth/authorize", + token_url: "https://auth.openai.com/oauth/token", + scopes: &["openid", "profile", "email", "offline_access"], + }, + }), }, BuiltinProvider { id: "xai", @@ -8334,11 +8358,7 @@ fn check_model_auth_available(model: &str) -> Result Result { + runtime::load_provider_oauth("openai").ok().flatten().is_some() + } "MOONSHOT_API_KEY" => { runtime::load_provider_oauth("moonshot").ok().flatten().is_some() } @@ -8391,6 +8412,7 @@ fn check_model_auth_available(model: &str) -> Result Date: Fri, 1 May 2026 18:26:07 -0300 Subject: [PATCH 14/15] feat: custom provider OAuth refresh + model command + test isolation - Add auto token refresh for custom providers using OAuth - Extend ConfiguredModelProvider with oauth_token_set, token_url, client_id - Lookup LOGIN_PROVIDER_TEMPLATES for refresh config when custom provider falls back to saved OAuth tokens - Add ProviderClient::from_openai_compatible_oauth() constructor - OpenAiCompatClient now supports from_oauth_token_set with auto-refresh - Add `claw model [MODEL]` CLI command - Lists all available models (built-in + templates + custom providers) - Sets default model in ~/.claw/settings.json when given a model name - Accepts both `model` and `models` aliases - Fix test isolation for users with workspace-write permission config - Set RUSTY_CLAUDE_PERMISSION_MODE=danger-full-access in 10 tests - Prevents user ~/.claw/settings.json from leaking into test assertions --- rust/crates/api/src/client.rs | 38 ++- .../crates/api/src/providers/openai_compat.rs | 148 +++++++++- rust/crates/rusty-claude-cli/src/main.rs | 266 +++++++++++++++--- 3 files changed, 415 insertions(+), 37 deletions(-) diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 388d60e6a4..f2f7385cf0 100644 
--- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -68,9 +68,20 @@ impl ProviderClient { ))); } } - Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( - config, provider_id, - )?)) + if provider_id == "moonshot" { + Ok(Self::OpenAi( + OpenAiCompatClient::from_env_or_oauth_with_refresh( + config, + provider_id, + "https://auth.kimi.com/api/oauth/token", + "17e5f671-d194-4dfb-9706-5516cb48c098", + )?, + )) + } else { + Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( + config, provider_id, + )?)) + } } else { Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) } @@ -88,6 +99,27 @@ impl ProviderClient { ) } + /// Create an OpenAI-compatible client from an OAuth token set with automatic refresh. + /// Used for custom providers that authenticate via OAuth rather than API keys. + #[must_use] + pub fn from_openai_compatible_oauth( + base_url: impl Into, + token_set: runtime::OAuthTokenSet, + token_url: impl Into, + client_id: impl Into, + ) -> Self { + Self::OpenAi( + OpenAiCompatClient::from_oauth_token_set( + token_set, + OpenAiCompatConfig::openai(), + token_url, + client_id, + "custom", + ) + .with_base_url(base_url), + ) + } + #[must_use] pub fn from_anthropic_compatible_profile( api_key: impl Into, diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 2d1eac6e61..5647569d46 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -112,6 +112,16 @@ impl OpenAiCompatConfig { } } +#[derive(Debug, Clone)] +struct OpenAiCompatOAuthState { + access_token: String, + refresh_token: Option, + expires_at: Option, + token_url: String, + client_id: String, + provider_id: String, +} + #[derive(Debug, Clone)] pub struct OpenAiCompatClient { http: reqwest::Client, @@ -121,6 +131,7 @@ pub struct OpenAiCompatClient { max_retries: u32, initial_backoff: Duration, max_backoff: Duration, + oauth_state: Option>>, } impl 
OpenAiCompatClient { @@ -142,6 +153,7 @@ impl OpenAiCompatClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + oauth_state: None, } } @@ -157,13 +169,45 @@ impl OpenAiCompatClient { /// Create a client using an OAuth access token instead of an API key. /// The token is sent as `Authorization: Bearer {token}`. + /// No automatic refresh is performed; use [`from_oauth_token_set`] for refresh support. #[must_use] pub fn from_oauth_token(token: impl Into, config: OpenAiCompatConfig) -> Self { Self::new(token, config) } + /// Create a client from a full OAuth token set with automatic refresh support. + #[must_use] + pub fn from_oauth_token_set( + token_set: runtime::OAuthTokenSet, + config: OpenAiCompatConfig, + token_url: impl Into, + client_id: impl Into, + provider_id: impl Into, + ) -> Self { + Self { + http: build_http_client_or_default(), + api_key: token_set.access_token.clone(), + config, + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + oauth_state: Some(std::sync::Arc::new(std::sync::Mutex::new( + OpenAiCompatOAuthState { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + token_url: token_url.into(), + client_id: client_id.into(), + provider_id: provider_id.into(), + }, + ))), + } + } + /// Try env var first, then fall back to saved OAuth token for the provider. /// `provider_id` is the key used in `~/.claw/credentials.json` under `oauth_providers`. + /// No automatic refresh is performed; use [`from_env_or_oauth_with_refresh`] for refresh support. pub fn from_env_or_oauth( config: OpenAiCompatConfig, provider_id: &str, @@ -180,6 +224,32 @@ impl OpenAiCompatClient { )) } + /// Try env var first, then fall back to saved OAuth token with automatic refresh. 
+ /// `token_url` and `client_id` are used to refresh expired access tokens. + pub fn from_env_or_oauth_with_refresh( + config: OpenAiCompatConfig, + provider_id: &str, + token_url: &str, + client_id: &str, + ) -> Result { + if let Some(api_key) = read_env_non_empty(config.api_key_env)? { + return Ok(Self::new(api_key, config)); + } + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + return Ok(Self::from_oauth_token_set( + token_set, + config, + token_url, + client_id, + provider_id, + )); + } + Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )) + } + #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); @@ -303,6 +373,69 @@ impl OpenAiCompatClient { }) } + async fn ensure_token_valid(&self) -> Result<(), ApiError> { + let Some(oauth_state) = &self.oauth_state else { + return Ok(()); + }; + + let needs_refresh = { + let state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + match state.expires_at { + None => false, + Some(expires_at) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + now + 60 >= expires_at + } + } + }; + + if !needs_refresh { + return Ok(()); + } + + let (refresh_token, token_url, client_id, provider_id) = { + let state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + let refresh = state.refresh_token.clone().ok_or_else(|| { + ApiError::Auth("OAuth token expired and no refresh token available".to_string()) + })?; + ( + refresh, + state.token_url.clone(), + state.client_id.clone(), + state.provider_id.clone(), + ) + }; + + let new_token = runtime::refresh_oauth_token( + &self.http, + &token_url, + &client_id, + &refresh_token, + ) + .await + .map_err(|e| ApiError::Auth(format!("OAuth token refresh failed: {e}")))?; + + { + let mut state = 
oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + state.access_token = new_token.access_token.clone(); + state.refresh_token = new_token.refresh_token.clone(); + state.expires_at = new_token.expires_at; + } + + let _ = runtime::save_provider_oauth(&provider_id, &new_token); + + Ok(()) + } + async fn send_raw_request( &self, request: &MessageRequest, @@ -310,11 +443,24 @@ impl OpenAiCompatClient { // Pre-flight check: verify request body size against provider limits check_request_body_size(request, self.config())?; + self.ensure_token_valid().await?; + + let access_token = { + if let Some(oauth_state) = &self.oauth_state { + let state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + state.access_token.clone() + } else { + self.api_key.clone() + } + }; + let request_url = chat_completions_endpoint(&self.base_url); self.http .post(&request_url) .header("content-type", "application/json") - .bearer_auth(&self.api_key) + .bearer_auth(&access_token) .json(&build_chat_completion_request(request, self.config())) .send() .await diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index a841516f2e..1d2199a5ab 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -103,6 +103,13 @@ struct ConfiguredModelProvider { provider_type: String, api_key: String, base_url: String, + /// Full OAuth token set when auth came from saved OAuth credentials. + /// Enables automatic token refresh for custom providers. + oauth_token_set: Option, + /// OAuth token URL for refreshing the access token. + oauth_token_url: Option, + /// OAuth client ID for refreshing the access token. 
+ oauth_client_id: Option<String>, } impl ModelProvenance { @@ -488,6 +495,7 @@ fn run() -> Result<(), Box<dyn std::error::Error>> { allow_broad_cwd, )?, CliAction::Auth { provider } => run_auth_command(provider.as_deref())?, + CliAction::Model { name } => run_model_command(name.as_deref())?, CliAction::HelpTopic(topic) => print_help_topic(topic), CliAction::Help { output_format } => print_help(output_format)?, } @@ -595,6 +603,9 @@ enum CliAction { Auth { provider: Option<String>, }, + Model { + name: Option<String>, + }, HelpTopic(LocalHelpTopic), // prompt-mode formatting is only supported for non-interactive runs Help { @@ -992,6 +1003,20 @@ } Ok(CliAction::Auth { provider }) } + "model" | "models" => { + if rest.len() == 2 && is_help_flag(&rest[1]) { + return Ok(CliAction::Help { output_format }); + } + let name = rest.get(1).cloned(); + if rest.len() > 2 { + return Err(format!( + "unexpected extra arguments after `claw model {}`: {}", + name.as_deref().unwrap_or(""), + rest[2..].join(" ") + )); + } + Ok(CliAction::Model { name }) + } "init" => Ok(CliAction::Init { output_format }), "export" => parse_export_args(&rest[1..], output_format), "prompt" => { @@ -1151,11 +1176,11 @@ "sandbox" => Some(Ok(CliAction::Sandbox { output_format })), "doctor" => Some(Ok(CliAction::Doctor { output_format })), "state" => Some(Ok(CliAction::State { output_format })), - // #146: let `config` and `diff` fall through to parse_subcommand + // #146: let `config`, `diff`, and `model` fall through to parse_subcommand // where they are wired as pure-local introspection, instead of // producing the "is a slash command" guidance. Zero-arg cases // reach parse_subcommand too via this None. 
- "config" | "diff" => None, + "config" | "diff" | "model" | "models" => None, other => bare_slash_command_guidance(other).map(Err), } } @@ -1552,37 +1577,81 @@ fn configured_provider_for_model( ) .into()); } - let (api_key, base_url) = if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) { - (api_key.to_string(), provider.base_url().to_string()) - } else if let Some(env_name) = provider.api_key_env() { - if let Ok(key) = env::var(env_name) { - (key, provider.base_url().to_string()) - } else if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { - // Fall back to saved OAuth bearer token when env var is unset. - // For OpenAI, OAuth tokens are WHAM-backend tokens, not Platform API tokens. - // Override the base URL to the WHAM backend. - let base_url = if provider_name == "openai" { - api::DEFAULT_WHAM_BASE_URL.to_string() + let (api_key, base_url, oauth_token_set, oauth_token_url, oauth_client_id) = + if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) { + ( + api_key.to_string(), + provider.base_url().to_string(), + None, + None, + None, + ) + } else if let Some(env_name) = provider.api_key_env() { + if let Ok(key) = env::var(env_name) { + ( + key, + provider.base_url().to_string(), + None, + None, + None, + ) + } else if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) { + // Fall back to saved OAuth bearer token when env var is unset. + // For OpenAI, OAuth tokens are WHAM-backend tokens, not Platform API tokens. + // Override the base URL to the WHAM backend. + let base_url = if provider_name == "openai" { + api::DEFAULT_WHAM_BASE_URL.to_string() + } else { + provider.base_url().to_string() + }; + // Look up OAuth refresh config from login templates so custom + // providers that match built-in templates get automatic refresh. 
+ let template_oauth = LOGIN_PROVIDER_TEMPLATES + .iter() + .find(|t| t.id == provider_name) + .and_then(|t| t.oauth.as_ref()) + .or_else(|| { + LOGIN_PROVIDER_TEMPLATES + .iter() + .find(|t| t.base_url == provider.base_url()) + .and_then(|t| t.oauth.as_ref()) + }); + let (token_url, client_id) = match template_oauth { + Some(oauth) => { + let token_url = match oauth.flow { + OAuthFlowType::Device { token_url, .. } => token_url, + OAuthFlowType::Pkce { token_url, .. } => token_url, + }; + (Some(token_url.to_string()), Some(oauth.client_id.to_string())) + } + None => (None, None), + }; + ( + token_set.access_token.clone(), + base_url, + Some(token_set), + token_url, + client_id, + ) } else { - provider.base_url().to_string() - }; - (token_set.access_token, base_url) + return Err(format!( + "model provider '{provider_name}' requires env var {env_name}" + ) + .into()); + } } else { - return Err(format!( - "model provider '{provider_name}' requires env var {env_name}" - ) - .into()); - } - } else { - return Err( - format!("model provider '{provider_name}' requires apiKeyEnv or apiKey").into(), - ); - }; + return Err( + format!("model provider '{provider_name}' requires apiKeyEnv or apiKey").into(), + ); + }; Ok(Some(ConfiguredModelProvider { wire_model: wire_model.to_string(), provider_type: provider.provider_type().to_string(), api_key, base_url, + oauth_token_set, + oauth_token_url, + oauth_client_id, })) } @@ -8225,6 +8294,16 @@ impl AnthropicRuntimeClient { provider.base_url, ) } + } else if let (Some(token_set), Some(token_url), Some(client_id)) = + (provider.oauth_token_set, provider.oauth_token_url, provider.oauth_client_id) + { + // Custom provider with OAuth: use auto-refreshing client. 
+ ApiProviderClient::from_openai_compatible_oauth( + provider.base_url, + token_set, + token_url, + client_id, + ) } else { ApiProviderClient::from_openai_compatible_profile( provider.api_key, provider.base_url, ) } @@ -8562,6 +8641,93 @@ fn run_provider_welcome( Err("Invalid selection".into()) } +fn run_model_command(name: Option<&str>) -> Result<(), Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let config = loader.load()?; + + if let Some(name) = name { + let trimmed = name.trim(); + if trimmed.is_empty() { + return Err("model name cannot be empty".into()); + } + validate_model_syntax(trimmed)?; + let resolved = resolve_model_alias_with_config(trimmed); + + // Write to user settings.json + let settings_path = loader.config_home().join("settings.json"); + let mut settings: serde_json::Value = if settings_path.exists() { + let contents = fs::read_to_string(&settings_path)?; + serde_json::from_str(&contents).unwrap_or_else(|_| serde_json::json!({})) + } else { + serde_json::json!({}) + }; + if let Some(obj) = settings.as_object_mut() { + obj.insert("model".to_string(), serde_json::Value::String(resolved.clone())); + } + fs::create_dir_all(loader.config_home())?; + fs::write(&settings_path, serde_json::to_string_pretty(&settings)?)?; + + println!("Default model set to: {}", resolved); + return Ok(()); + } + + // No name provided: show current model and list available models + let env_model = env::var("ANTHROPIC_MODEL").ok(); + let current_model = config.model() + .or_else(|| env_model.as_deref()) + .unwrap_or(DEFAULT_MODEL); + + println!("Current model: {}", current_model); + println!(); + println!("Available models:"); + + // Built-in providers + println!(" Built-in:"); + for builtin in BUILTIN_PROVIDERS { + let label = match builtin.id { + "anthropic" => "Anthropic (Claude)", + "openai" => "OpenAI", + "xai" => "xAI (Grok)", + other => other, + }; + println!(" {}:", label); + println!(" {}/{}", builtin.id, builtin.default_model); + } + + // 
Login templates + println!(" Additional providers:"); + for template in LOGIN_PROVIDER_TEMPLATES { + println!(" {}:", template.label); + for model in template.models { + println!(" {}/{}", template.id, model); + } + } + + // Custom providers from config + let custom_providers = config.model_providers(); + if !custom_providers.is_empty() { + println!(" Custom providers:"); + for (name, provider) in custom_providers { + println!(" {}:", name); + if !provider.models().is_empty() { + for model in provider.models() { + println!(" {}/{}", name, model); + } + } else if let Some(default) = provider.default_model() { + println!(" {}/{}", name, default); + } + } + } + + println!(); + println!("Usage:"); + println!(" claw model <model> Set default model"); + println!(" claw --model <model> prompt Use model for one prompt"); + + Ok(()) +} + fn run_auth_command(provider: Option<&str>) -> Result<(), Box<dyn std::error::Error>> { if let Some(provider_id) = provider { // Try built-in provider first @@ -10365,6 +10531,11 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { out, " Authenticate with a model provider (anthropic, openai, xai)" )?; + writeln!(out, " claw model [MODEL]")?; + writeln!( + out, + " Show available models or set the default model" )?; writeln!( out, " claw export [PATH] [--session SESSION] [--output PATH]" @@ -10804,7 +10975,7 @@ mod tests { #[test] fn defaults_to_repl_when_no_args() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); assert_eq!( parse_args(&[]).expect("args should parse"), CliAction::Repl { @@ -10931,7 +11102,7 @@ #[test] fn parses_prompt_subcommand() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "prompt".to_string(), "hello".to_string(), @@ -10975,6 +11146,32 @@ ); } + #[test] + fn parse_args_model_without_name() 
{ + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["model".to_string()]).expect("args should parse"), + CliAction::Model { name: None } + ); + assert_eq!( + parse_args(&["models".to_string()]).expect("args should parse"), + CliAction::Model { name: None } + ); + } + + #[test] + fn parse_args_model_with_name() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["model".to_string(), "moonshot/kimi-k2.6".to_string()]).expect("args should parse"), + CliAction::Model { + name: Some("moonshot/kimi-k2.6".to_string()), + } + ); + } + #[test] fn merge_prompt_with_stdin_returns_prompt_unchanged_when_no_pipe() { // given @@ -11042,7 +11239,7 @@ mod tests { #[test] fn parses_bare_prompt_and_json_output_flag() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--output-format=json".to_string(), "--model".to_string(), @@ -11070,7 +11267,7 @@ mod tests { fn parses_compact_flag_for_prompt_mode() { // given a bare prompt invocation that includes the --compact flag let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--compact".to_string(), "summarize".to_string(), @@ -11117,7 +11314,7 @@ mod tests { #[test] fn resolves_model_aliases_in_args() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--model".to_string(), "opus".to_string(), @@ -11347,7 +11544,7 @@ mod tests { #[test] fn parses_allowed_tools_flags_with_aliases_and_lists() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args 
= vec![ "--allowedTools".to_string(), "read,glob".to_string(), @@ -11991,7 +12188,7 @@ mod tests { #[test] fn parses_single_word_command_aliases_without_falling_back_to_prompt_mode() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); assert_eq!( parse_args(&["help".to_string()]).expect("help should parse"), CliAction::Help { @@ -12657,6 +12854,8 @@ mod tests { #[test] fn prompt_subcommand_allows_literal_typo_word() { + let _guard = env_lock(); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); assert_eq!( parse_args(&["prompt".to_string(), "doctorr".to_string()]) .expect("explicit prompt subcommand should allow literal typo word"), @@ -12680,6 +12879,7 @@ mod tests { // doesn't pick up a stale .claw/settings.json from other tests that // may have set `permissionMode: acceptEdits` in a shared cwd. let _guard = env_lock(); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let root = temp_dir(); let cwd = root.join("project"); std::fs::create_dir_all(&cwd).expect("project dir should exist"); From cdb910a98d6feada1e653e6d05cd86b5b32a3b3f Mon Sep 17 00:00:00 2001 From: Heitor Ramon Ribeiro Date: Fri, 1 May 2026 19:16:52 -0300 Subject: [PATCH 15/15] feat: Kimi For Coding support via User-Agent spoofing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Kimi For Coding API (api.kimi.com/coding/v1) restricts access to whitelisted coding agents (Claude Code, Kilo Code, etc.). It validates the User-Agent header case-sensitively — only lowercase variants like `claude-code/1.0` are accepted; `Claude-Code/1.0` is rejected. 
Changes: - Add optional `user_agent` field to OpenAiCompatClient - Add `with_user_agent()` builder method - Add `ProviderClient::with_user_agent()` to propagate header - In AnthropicRuntimeClient::new, detect api.kimi.com base URLs and automatically set User-Agent to `claude-code/0.1.0` - Restore Kimi model names (k2p5, k2p6, kimi-k2-thinking) in config; all are accepted by the API Also fixed minimax-coding-plan config: changed from anthropic-compatible (wrong) to openai-compatible with baseUrl https://api.minimax.io/v1. Tested live: - zai-coding-plan/glm-5.1 ✅ - minimax-coding-plan/MiniMax-M2.7 ✅ - kimi-for-coding/k2p6 ✅ --- rust/crates/api/src/client.rs | 10 ++++++++ .../crates/api/src/providers/openai_compat.rs | 22 +++++++++++++---- rust/crates/rusty-claude-cli/src/main.rs | 24 ++++++++++++++----- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index f2f7385cf0..f2e13461c8 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -128,6 +128,16 @@ impl ProviderClient { Self::Anthropic(AnthropicClient::new(api_key).with_base_url(base_url)) } + /// Set a custom User-Agent header on the underlying OpenAI-compatible client. + /// No-op for Anthropic, xAI, or WHAM variants. 
+ #[must_use] + pub fn with_user_agent(self, user_agent: impl Into<String>) -> Self { + match self { + Self::OpenAi(client) => Self::OpenAi(client.with_user_agent(user_agent)), + other => other, + } + } + #[must_use] pub const fn provider_kind(&self) -> ProviderKind { match self { diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 5647569d46..fd4de80d9d 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -132,6 +132,8 @@ pub struct OpenAiCompatClient { initial_backoff: Duration, max_backoff: Duration, oauth_state: Option<Arc<Mutex<OAuthState>>>, + /// Custom User-Agent header for providers that gate access by client identity. + user_agent: Option<String>, } impl OpenAiCompatClient { @@ -154,6 +156,7 @@ initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, oauth_state: None, + user_agent: None, } } @@ -202,6 +205,7 @@ provider_id: provider_id.into(), }, ))), + user_agent: None, } } @@ -256,6 +260,12 @@ self } + #[must_use] + pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self { + self.user_agent = Some(user_agent.into()); + self + } + #[must_use] pub fn with_retry_policy( mut self, @@ -457,14 +467,16 @@ }; let request_url = chat_completions_endpoint(&self.base_url); - self.http + let mut req = self + .http .post(&request_url) .header("content-type", "application/json") .bearer_auth(&access_token) - .json(&build_chat_completion_request(request, self.config())) - .send() - .await - .map_err(ApiError::from) + .json(&build_chat_completion_request(request, self.config())); + if let Some(ref ua) = self.user_agent { + req = req.header("user-agent", ua); + } + req.send().await.map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> { diff --git a/rust/crates/rusty-claude-cli/src/main.rs 
b/rust/crates/rusty-claude-cli/src/main.rs index 1d2199a5ab..273a527785 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -8272,6 +8272,18 @@ impl AnthropicRuntimeClient { .with_prompt_cache(PromptCache::new(session_id)) } "openai-compatible" | "openai" => { + // Kimi For Coding requires a whitelisted User-Agent. + let user_agent = if provider.base_url.contains("api.kimi.com") { + Some("claude-code/0.1.0") + } else { + None + }; + let apply_ua = |client: ApiProviderClient| { + match user_agent { + Some(ua) => client.with_user_agent(ua), + None => client, + } + }; // Route to WhamClient when using ChatGPT OAuth (WHAM backend). if provider.base_url.contains("backend-api/wham") { if let Ok(Some(token_set)) = runtime::load_provider_oauth("openai") { @@ -8289,26 +8301,26 @@ impl AnthropicRuntimeClient { "app_EMoamEEZ73f0CkXaXp7hrann", )) } else { - ApiProviderClient::from_openai_compatible_profile( + apply_ua(ApiProviderClient::from_openai_compatible_profile( provider.api_key, provider.base_url, - ) + )) } } else if let (Some(token_set), Some(token_url), Some(client_id)) = (provider.oauth_token_set, provider.oauth_token_url, provider.oauth_client_id) { // Custom provider with OAuth: use auto-refreshing client. - ApiProviderClient::from_openai_compatible_oauth( + apply_ua(ApiProviderClient::from_openai_compatible_oauth( provider.base_url, token_set, token_url, client_id, - ) + )) } else { - ApiProviderClient::from_openai_compatible_profile( + apply_ua(ApiProviderClient::from_openai_compatible_profile( provider.api_key, provider.base_url, - ) + )) } } other => {