diff --git a/.gitignore b/.gitignore index 919ab8387f..6d3a0fc459 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ archive/ .claw/sessions/ .clawhip/ status-help.txt +.DS_Store diff --git a/USAGE.md b/USAGE.md index c8e7b09692..4d7d6e0c44 100644 --- a/USAGE.md +++ b/USAGE.md @@ -308,6 +308,147 @@ The OpenAI-compatible backend also serves as the gateway for **OpenRouter**, **O **Model-name prefix routing:** If a model name starts with `openai/`, `gpt-`, `qwen/`, or `qwen-`, the provider is selected by the prefix regardless of which env vars are set. This prevents accidental misrouting to Anthropic when multiple credentials exist in the environment. +### Configured compatible providers + +If you use several compatible providers, define named provider profiles in `settings.json` instead of changing `OPENAI_BASE_URL` before every run. Each profile gets its own protocol, base URL, credential env var, and model allow-list. + +Run `claw login` or `/login` to create these profiles interactively for Z.AI, MiniMax, OpenAI, Kimi, Moonshot, or a custom compatible endpoint. 
The wizard can store a pasted local token in `~/.claw/settings.json`, but `apiKeyEnv` is preferred when you can keep the secret in your shell environment: + +```json +{ + "model": "zai/glm-5.1", + "modelProviders": { + "zai": { + "type": "openai-compatible", + "baseUrl": "https://api.z.ai/api/paas/v4", + "apiKeyEnv": "Z_AI_API_KEY", + "models": [ + "glm-5.1", + "glm-5", + "glm-5-turbo", + "glm-4.7", + "glm-4.7-flashx", + "glm-4.7-flash", + "glm-4.6", + "glm-4.5", + "glm-4.5-x", + "glm-4.5-air", + "glm-4.5-airx", + "glm-4.5-flash", + "glm-4-32b-0414-128k" + ], + "defaultModel": "glm-5.1" + }, + "zai-coding-plan": { + "type": "openai-compatible", + "baseUrl": "https://api.z.ai/api/coding/paas/v4", + "apiKeyEnv": "Z_AI_API_KEY", + "models": [ + "glm-4.5-air", + "glm-4.7", + "glm-5-turbo", + "glm-5.1", + "glm-5v-turbo" + ], + "defaultModel": "glm-5.1" + }, + "minimax-coding-plan": { + "type": "anthropic-compatible", + "baseUrl": "https://api.minimax.io/anthropic/v1", + "apiKeyEnv": "MINIMAX_API_KEY", + "models": [ + "MiniMax-M2", + "MiniMax-M2.1", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed" + ], + "defaultModel": "MiniMax-M2.7-highspeed" + }, + "openai": { + "type": "openai-compatible", + "baseUrl": "https://api.openai.com/v1", + "apiKeyEnv": "OPENAI_API_KEY", + "models": [ + "gpt-5-codex", + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.3-codex", + "gpt-5.3-codex-spark", + "gpt-5.4", + "gpt-5.4-fast", + "gpt-5.4-mini", + "gpt-5.4-mini-fast", + "gpt-5.5", + "gpt-5.5-fast", + "gpt-5.5-pro" + ], + "defaultModel": "gpt-5.5" + }, + "kimi-for-coding": { + "type": "anthropic-compatible", + "baseUrl": "https://api.kimi.com/coding/v1", + "apiKeyEnv": "KIMI_API_KEY", + "models": [ + "k2p5", + "k2p6", + "kimi-k2-thinking" + ], + "defaultModel": "k2p6" + }, + "moonshot": { + "type": "openai-compatible", + "baseUrl": "https://api.moonshot.ai/v1", + "apiKeyEnv": 
"MOONSHOT_API_KEY", + "models": [ + "kimi-k2.6", + "kimi-k2.5", + "kimi-k2-0905-preview", + "kimi-k2-0711-preview", + "kimi-k2-turbo-preview", + "kimi-k2-thinking", + "kimi-k2-thinking-turbo", + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "moonshot-v1-8k-vision-preview", + "moonshot-v1-32k-vision-preview", + "moonshot-v1-128k-vision-preview" + ], + "defaultModel": "kimi-k2.6" + } + } +} +``` + +Use `/model provider/model` in the REPL to switch without restarting: + +```text +/model zai/glm-5.1 +/model zai-coding-plan/glm-5.1 +/model minimax-coding-plan/MiniMax-M2.7-highspeed +/model openai/gpt-5.5 +/model kimi-for-coding/k2p6 +/model moonshot/kimi-k2.6 +``` + +You can also use the provider name alone when it has `defaultModel` configured: + +```text +/model zai +/model zai-coding-plan +/model minimax-coding-plan +/model openai +/model kimi-for-coding +/model moonshot +``` + +Prefer `apiKeyEnv` so secrets stay out of source-controlled project settings. `apiKey` is supported for local-only files when an environment variable is not practical. 
+ ### Tested models and aliases These are the models registered in the built-in alias table with known token limits: diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 740147e78e..d25dcbabc9 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1361,9 +1361,11 @@ dependencies = [ name = "runtime" version = "0.1.0" dependencies = [ + "base64", "glob", "plugins", "regex", + "reqwest", "serde", "serde_json", "sha2", @@ -1456,6 +1458,7 @@ dependencies = [ "mock-anthropic-service", "plugins", "pulldown-cmark", + "reqwest", "runtime", "rustyline", "serde", diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 6e68fd2e2c..f2e13461c8 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -2,6 +2,7 @@ use crate::error::ApiError; use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; use crate::providers::anthropic::{self, AnthropicClient, AuthSource}; use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::wham::{self, WhamClient}; use crate::providers::{self, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; @@ -11,6 +12,8 @@ pub enum ProviderClient { Anthropic(AnthropicClient), Xai(OpenAiCompatClient), OpenAi(OpenAiCompatClient), + /// OpenAI WHAM backend (chatgpt.com/backend-api/wham) using ChatGPT OAuth tokens. + Wham(WhamClient), } impl ProviderClient { @@ -32,26 +35,115 @@ impl ProviderClient { OpenAiCompatConfig::xai(), )?)), ProviderKind::OpenAi => { - // DashScope models (qwen-*) also return ProviderKind::OpenAi because they - // speak the OpenAI wire format, but they need the DashScope config which - // reads DASHSCOPE_API_KEY and points at dashscope.aliyuncs.com. - let config = match providers::metadata_for_model(&resolved_model) { + // Use metadata_for_model for prefix-aware config selection. 
+ // DashScope models (qwen-*, kimi-*) and Moonshot models (moonshot/*) + // all speak the OpenAI wire format but need different configs. + let (config, oauth_provider_id) = match providers::metadata_for_model(&resolved_model) { Some(meta) if meta.auth_env == "DASHSCOPE_API_KEY" => { - OpenAiCompatConfig::dashscope() + (OpenAiCompatConfig::dashscope(), None) } - _ => OpenAiCompatConfig::openai(), + Some(meta) if meta.auth_env == "MOONSHOT_API_KEY" => { + (OpenAiCompatConfig::moonshot(), Some("moonshot")) + } + _ => (OpenAiCompatConfig::openai(), Some("openai")), }; - Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + // Try OAuth if the provider supports it and env var is not set + if let Some(provider_id) = oauth_provider_id { + if provider_id == "openai" { + // OpenAI OAuth tokens are WHAM-backend tokens (chatgpt.com/backend-api/wham), + // NOT Platform API tokens. Route to WhamClient when using OAuth. + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + let account_id = token_set + .id_token + .as_deref() + .and_then(runtime::extract_chatgpt_account_id) + .or_else(|| { + runtime::extract_chatgpt_account_id(&token_set.access_token) + }); + return Ok(Self::Wham(WhamClient::from_oauth_token_set( + token_set, + account_id, + "https://auth.openai.com/oauth/token", + "app_EMoamEEZ73f0CkXaXp7hrann", + ))); + } + } + if provider_id == "moonshot" { + Ok(Self::OpenAi( + OpenAiCompatClient::from_env_or_oauth_with_refresh( + config, + provider_id, + "https://auth.kimi.com/api/oauth/token", + "17e5f671-d194-4dfb-9706-5516cb48c098", + )?, + )) + } else { + Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( + config, provider_id, + )?)) + } + } else { + Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + } } } } + #[must_use] + pub fn from_openai_compatible_profile( + api_key: impl Into, + base_url: impl Into, + ) -> Self { + Self::OpenAi( + OpenAiCompatClient::new(api_key, OpenAiCompatConfig::openai()).with_base_url(base_url), + ) + } 
+ + /// Create an OpenAI-compatible client from an OAuth token set with automatic refresh. + /// Used for custom providers that authenticate via OAuth rather than API keys. + #[must_use] + pub fn from_openai_compatible_oauth( + base_url: impl Into, + token_set: runtime::OAuthTokenSet, + token_url: impl Into, + client_id: impl Into, + ) -> Self { + Self::OpenAi( + OpenAiCompatClient::from_oauth_token_set( + token_set, + OpenAiCompatConfig::openai(), + token_url, + client_id, + "custom", + ) + .with_base_url(base_url), + ) + } + + #[must_use] + pub fn from_anthropic_compatible_profile( + api_key: impl Into, + base_url: impl Into, + ) -> Self { + Self::Anthropic(AnthropicClient::new(api_key).with_base_url(base_url)) + } + + /// Set a custom User-Agent header on the underlying OpenAI-compatible client. + /// No-op for Anthropic, xAI, or WHAM variants. + #[must_use] + pub fn with_user_agent(self, user_agent: impl Into) -> Self { + match self { + Self::OpenAi(client) => Self::OpenAi(client.with_user_agent(user_agent)), + other => other, + } + } + #[must_use] pub const fn provider_kind(&self) -> ProviderKind { match self { Self::Anthropic(_) => ProviderKind::Anthropic, Self::Xai(_) => ProviderKind::Xai, - Self::OpenAi(_) => ProviderKind::OpenAi, + Self::OpenAi(_) | Self::Wham(_) => ProviderKind::OpenAi, } } @@ -67,7 +159,7 @@ impl ProviderClient { pub fn prompt_cache_stats(&self) -> Option { match self { Self::Anthropic(client) => client.prompt_cache_stats(), - Self::Xai(_) | Self::OpenAi(_) => None, + Self::Xai(_) | Self::OpenAi(_) | Self::Wham(_) => None, } } @@ -75,7 +167,7 @@ impl ProviderClient { pub fn take_last_prompt_cache_record(&self) -> Option { match self { Self::Anthropic(client) => client.take_last_prompt_cache_record(), - Self::Xai(_) | Self::OpenAi(_) => None, + Self::Xai(_) | Self::OpenAi(_) | Self::Wham(_) => None, } } @@ -86,6 +178,7 @@ impl ProviderClient { match self { Self::Anthropic(client) => client.send_message(request).await, Self::Xai(client) | 
Self::OpenAi(client) => client.send_message(request).await, + Self::Wham(client) => client.send_message(request).await, } } @@ -102,6 +195,10 @@ impl ProviderClient { .stream_message(request) .await .map(MessageStream::OpenAiCompat), + Self::Wham(client) => client + .stream_message(request) + .await + .map(MessageStream::Wham), } } } @@ -110,6 +207,7 @@ impl ProviderClient { pub enum MessageStream { Anthropic(anthropic::MessageStream), OpenAiCompat(openai_compat::MessageStream), + Wham(wham::WhamMessageStream), } impl MessageStream { @@ -118,6 +216,7 @@ impl MessageStream { match self { Self::Anthropic(stream) => stream.request_id(), Self::OpenAiCompat(stream) => stream.request_id(), + Self::Wham(stream) => stream.request_id(), } } @@ -125,6 +224,7 @@ impl MessageStream { match self { Self::Anthropic(stream) => stream.next_event().await, Self::OpenAiCompat(stream) => stream.next_event().await, + Self::Wham(stream) => stream.next_event().await, } } } diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 40da29f140..e14e9f0074 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -18,14 +18,15 @@ pub use prompt_cache::{ CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, PromptCacheStats, }; -pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource}; +pub use providers::anthropic::{has_auth_from_env_or_saved as anthropic_has_auth, AnthropicClient, AnthropicClient as ApiClient, AuthSource}; pub use providers::openai_compat::{ - build_chat_completion_request, flatten_tool_result_content, is_reasoning_model, + build_chat_completion_request, flatten_tool_result_content, has_api_key, is_reasoning_model, model_rejects_is_error_field, translate_message, OpenAiCompatClient, OpenAiCompatConfig, }; +pub use providers::wham::{WhamClient, DEFAULT_WHAM_BASE_URL}; pub use providers::{ detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override, - 
resolve_model_alias, ProviderKind, + metadata_for_model, resolve_model_alias, ProviderKind, ProviderMetadata, }; pub use sse::{parse_frame, SseParser}; pub use types::{ diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index 7c9f02945e..bb7e19d072 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -707,6 +707,7 @@ fn resolve_saved_oauth_token_set( refresh_token: resolved.refresh_token.clone(), expires_at: resolved.expires_at, scopes: resolved.scopes.clone(), + id_token: None, }) .map_err(ApiError::from)?; Ok(resolved) @@ -1165,6 +1166,7 @@ mod tests { refresh_token: Some("refresh".to_string()), expires_at: Some(now_unix_timestamp() + 300), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save oauth credentials"); @@ -1204,6 +1206,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(1), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save expired oauth credentials"); @@ -1236,6 +1239,7 @@ mod tests { refresh_token: Some("refresh".to_string()), expires_at: Some(now_unix_timestamp() + 300), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save oauth credentials"); @@ -1260,6 +1264,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(1), scopes: vec!["scope:a".to_string()], + id_token: None, }) .expect("save expired oauth credentials"); diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index 86871a82a1..d4562d2b4d 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -9,6 +9,7 @@ use crate::types::{MessageRequest, MessageResponse}; pub mod anthropic; pub mod openai_compat; +pub mod wham; #[allow(dead_code)] pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; @@ -216,6 +217,15 @@ pub fn metadata_for_model(model: &str) -> Option { default_base_url: 
openai_compat::DEFAULT_DASHSCOPE_BASE_URL, }); } + // Moonshot / Kimi models via native Moonshot API endpoint. + if canonical.starts_with("moonshot/") { + return Some(ProviderMetadata { + provider: ProviderKind::OpenAi, + auth_env: "MOONSHOT_API_KEY", + base_url_env: "MOONSHOT_BASE_URL", + default_base_url: openai_compat::DEFAULT_MOONSHOT_BASE_URL, + }); + } None } diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index a810502e66..fd4de80d9d 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -19,6 +19,7 @@ use super::{preflight_message_request, Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; pub const DEFAULT_DASHSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1"; +pub const DEFAULT_MOONSHOT_BASE_URL: &str = "https://api.moonshot.ai/v1"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1); @@ -41,11 +42,13 @@ pub struct OpenAiCompatConfig { const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; const DASHSCOPE_ENV_VARS: &[&str] = &["DASHSCOPE_API_KEY"]; +const MOONSHOT_ENV_VARS: &[&str] = &["MOONSHOT_API_KEY"]; // Provider-specific request body size limits in bytes const XAI_MAX_REQUEST_BODY_BYTES: usize = 52_428_800; // 50MB const OPENAI_MAX_REQUEST_BODY_BYTES: usize = 104_857_600; // 100MB const DASHSCOPE_MAX_REQUEST_BODY_BYTES: usize = 6_291_456; // 6MB (observed limit in dogfood) +const MOONSHOT_MAX_REQUEST_BODY_BYTES: usize = 52_428_800; // 50MB impl OpenAiCompatConfig { #[must_use] @@ -85,17 +88,40 @@ impl OpenAiCompatConfig { } } + /// Moonshot AI native endpoint (Kimi family models). 
+ #[must_use] + pub const fn moonshot() -> Self { + Self { + provider_name: "Moonshot", + api_key_env: "MOONSHOT_API_KEY", + base_url_env: "MOONSHOT_BASE_URL", + default_base_url: DEFAULT_MOONSHOT_BASE_URL, + max_request_body_bytes: MOONSHOT_MAX_REQUEST_BODY_BYTES, + } + } + #[must_use] pub fn credential_env_vars(self) -> &'static [&'static str] { match self.provider_name { "xAI" => XAI_ENV_VARS, "OpenAI" => OPENAI_ENV_VARS, "DashScope" => DASHSCOPE_ENV_VARS, + "Moonshot" => MOONSHOT_ENV_VARS, _ => &[], } } } +#[derive(Debug, Clone)] +struct OpenAiCompatOAuthState { + access_token: String, + refresh_token: Option, + expires_at: Option, + token_url: String, + client_id: String, + provider_id: String, +} + #[derive(Debug, Clone)] pub struct OpenAiCompatClient { http: reqwest::Client, @@ -105,6 +131,9 @@ pub struct OpenAiCompatClient { max_retries: u32, initial_backoff: Duration, max_backoff: Duration, + oauth_state: Option>>, + /// Custom User-Agent header for providers that gate access by client identity. + user_agent: Option, } impl OpenAiCompatClient { @@ -126,6 +155,8 @@ impl OpenAiCompatClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + oauth_state: None, + user_agent: None, } } @@ -139,12 +170,102 @@ impl OpenAiCompatClient { Ok(Self::new(api_key, config)) } + /// Create a client using an OAuth access token instead of an API key. + /// The token is sent as `Authorization: Bearer {token}`. + /// No automatic refresh is performed; use [`from_oauth_token_set`] for refresh support. + #[must_use] + pub fn from_oauth_token(token: impl Into, config: OpenAiCompatConfig) -> Self { + Self::new(token, config) + } + + /// Create a client from a full OAuth token set with automatic refresh support. 
+ #[must_use] + pub fn from_oauth_token_set( + token_set: runtime::OAuthTokenSet, + config: OpenAiCompatConfig, + token_url: impl Into, + client_id: impl Into, + provider_id: impl Into, + ) -> Self { + Self { + http: build_http_client_or_default(), + api_key: token_set.access_token.clone(), + config, + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + oauth_state: Some(std::sync::Arc::new(std::sync::Mutex::new( + OpenAiCompatOAuthState { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + token_url: token_url.into(), + client_id: client_id.into(), + provider_id: provider_id.into(), + }, + ))), + user_agent: None, + } + } + + /// Try env var first, then fall back to saved OAuth token for the provider. + /// `provider_id` is the key used in `~/.claw/credentials.json` under `oauth_providers`. + /// No automatic refresh is performed; use [`from_env_or_oauth_with_refresh`] for refresh support. + pub fn from_env_or_oauth( + config: OpenAiCompatConfig, + provider_id: &str, + ) -> Result { + if let Some(api_key) = read_env_non_empty(config.api_key_env)? { + return Ok(Self::new(api_key, config)); + } + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + return Ok(Self::from_oauth_token(token_set.access_token, config)); + } + Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )) + } + + /// Try env var first, then fall back to saved OAuth token with automatic refresh. + /// `token_url` and `client_id` are used to refresh expired access tokens. + pub fn from_env_or_oauth_with_refresh( + config: OpenAiCompatConfig, + provider_id: &str, + token_url: &str, + client_id: &str, + ) -> Result { + if let Some(api_key) = read_env_non_empty(config.api_key_env)? 
{ + return Ok(Self::new(api_key, config)); + } + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + return Ok(Self::from_oauth_token_set( + token_set, + config, + token_url, + client_id, + provider_id, + )); + } + Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )) + } + #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); self } + #[must_use] + pub fn with_user_agent(mut self, user_agent: impl Into) -> Self { + self.user_agent = Some(user_agent.into()); + self + } + #[must_use] pub fn with_retry_policy( mut self, @@ -262,6 +383,69 @@ impl OpenAiCompatClient { }) } + async fn ensure_token_valid(&self) -> Result<(), ApiError> { + let Some(oauth_state) = &self.oauth_state else { + return Ok(()); + }; + + let needs_refresh = { + let state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + match state.expires_at { + None => false, + Some(expires_at) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + now + 60 >= expires_at + } + } + }; + + if !needs_refresh { + return Ok(()); + } + + let (refresh_token, token_url, client_id, provider_id) = { + let state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + let refresh = state.refresh_token.clone().ok_or_else(|| { + ApiError::Auth("OAuth token expired and no refresh token available".to_string()) + })?; + ( + refresh, + state.token_url.clone(), + state.client_id.clone(), + state.provider_id.clone(), + ) + }; + + let new_token = runtime::refresh_oauth_token( + &self.http, + &token_url, + &client_id, + &refresh_token, + ) + .await + .map_err(|e| ApiError::Auth(format!("OAuth token refresh failed: {e}")))?; + + { + let mut state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + 
})?; + state.access_token = new_token.access_token.clone(); + state.refresh_token = new_token.refresh_token.clone(); + state.expires_at = new_token.expires_at; + } + + let _ = runtime::save_provider_oauth(&provider_id, &new_token); + + Ok(()) + } + async fn send_raw_request( &self, request: &MessageRequest, @@ -269,15 +453,30 @@ impl OpenAiCompatClient { // Pre-flight check: verify request body size against provider limits check_request_body_size(request, self.config())?; + self.ensure_token_valid().await?; + + let access_token = { + if let Some(oauth_state) = &self.oauth_state { + let state = oauth_state.lock().map_err(|e| { + ApiError::Auth(format!("OAuth state mutex poisoned: {e}")) + })?; + state.access_token.clone() + } else { + self.api_key.clone() + } + }; + let request_url = chat_completions_endpoint(&self.base_url); - self.http + let mut req = self + .http .post(&request_url) .header("content-type", "application/json") - .bearer_auth(&self.api_key) - .json(&build_chat_completion_request(request, self.config())) - .send() - .await - .map_err(ApiError::from) + .bearer_auth(&access_token) + .json(&build_chat_completion_request(request, self.config())); + if let Some(ref ua) = self.user_agent { + req = req.header("user-agent", ua); + } + req.send().await.map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result { diff --git a/rust/crates/api/src/providers/wham.rs b/rust/crates/api/src/providers/wham.rs new file mode 100644 index 0000000000..fb715b71ab --- /dev/null +++ b/rust/crates/api/src/providers/wham.rs @@ -0,0 +1,738 @@ +use std::collections::VecDeque; + +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::http_client::build_http_client_or_default; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, 
MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_WHAM_BASE_URL: &str = "https://chatgpt.com/backend-api/wham"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; + +#[derive(Debug, Clone)] +pub struct WhamClient { + http: reqwest::Client, + token: std::sync::Arc>, + base_url: String, +} + +#[derive(Debug, Clone)] +struct WhamToken { + access_token: String, + refresh_token: Option, + expires_at: Option, + account_id: Option, + token_url: String, + client_id: String, +} + +impl WhamClient { + #[must_use] + pub fn new(access_token: impl Into, account_id: Option) -> Self { + Self { + http: build_http_client_or_default(), + token: std::sync::Arc::new(std::sync::Mutex::new(WhamToken { + access_token: access_token.into(), + refresh_token: None, + expires_at: None, + account_id, + token_url: "https://auth.openai.com/oauth/token".to_string(), + client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(), + })), + base_url: DEFAULT_WHAM_BASE_URL.to_string(), + } + } + + /// Create a WHAM client with full OAuth token info, enabling automatic refresh. 
+ #[must_use] + pub fn from_oauth_token_set( + token_set: runtime::OAuthTokenSet, + account_id: Option, + token_url: impl Into, + client_id: impl Into, + ) -> Self { + Self { + http: build_http_client_or_default(), + token: std::sync::Arc::new(std::sync::Mutex::new(WhamToken { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + account_id, + token_url: token_url.into(), + client_id: client_id.into(), + })), + base_url: DEFAULT_WHAM_BASE_URL.to_string(), + } + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + async fn ensure_token_valid(&self) -> Result<(), ApiError> { + let needs_refresh = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + match token.expires_at { + None => false, + Some(expires_at) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + // Refresh if fewer than 60 seconds remain. 
+ now + 60 >= expires_at + } + } + }; + + if !needs_refresh { + return Ok(()); + } + + let (refresh_token, token_url, client_id) = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + let refresh = token.refresh_token.clone().ok_or_else(|| { + ApiError::Auth("OAuth token expired and no refresh token available".to_string()) + })?; + (refresh, token.token_url.clone(), token.client_id.clone()) + }; + + let new_token = runtime::refresh_oauth_token(&self.http, &token_url, &client_id, &refresh_token) + .await + .map_err(|e| ApiError::Auth(format!("OAuth token refresh failed: {e}")))?; + + { + let mut token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + token.access_token = new_token.access_token.clone(); + token.refresh_token = new_token.refresh_token.clone(); + token.expires_at = new_token.expires_at; + } + + // Persist the refreshed token. + let _ = runtime::save_provider_oauth("openai", &new_token); + + Ok(()) + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + self.ensure_token_valid().await?; + let (access_token, account_id) = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + (token.access_token.clone(), token.account_id.clone()) + }; + + let request_url = responses_endpoint(&self.base_url); + let body = build_responses_request(request); + + let mut req_builder = self + .http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&access_token); + + if let Some(ref id) = account_id { + req_builder = req_builder.header("ChatGPT-Account-Id", id); + } + + let response = req_builder.json(&body).send().await.map_err(ApiError::from)?; + let request_id = request_id_from_headers(response.headers()); + + if !response.status().is_success() { + let status = response.status(); + let body_text = response.text().await.unwrap_or_default(); + 
return Err(ApiError::Api { + status, + error_type: Some("wham_error".to_string()), + message: Some(body_text.clone()), + request_id, + body: body_text, + retryable: status.is_server_error(), + suggested_action: suggested_action_for_status(status), + }); + } + + let resp_body = response.text().await.map_err(ApiError::from)?; + let wham_resp: ResponsesResponse = serde_json::from_str(&resp_body).map_err(|error| { + ApiError::json_deserialize("OpenAI WHAM", &request.model, &resp_body, error) + })?; + + Ok(convert_to_message_response(wham_resp, request_id)) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + self.ensure_token_valid().await?; + let (access_token, account_id) = { + let token = self.token.lock().map_err(|e| { + ApiError::Auth(format!("token mutex poisoned: {e}")) + })?; + (token.access_token.clone(), token.account_id.clone()) + }; + + let request_url = responses_endpoint(&self.base_url); + let body = build_responses_request(request); + + let mut req_builder = self + .http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&access_token); + + if let Some(ref id) = account_id { + req_builder = req_builder.header("ChatGPT-Account-Id", id); + } + + let response = req_builder.json(&body).send().await.map_err(ApiError::from)?; + let request_id = request_id_from_headers(response.headers()); + + if !response.status().is_success() { + let status = response.status(); + let body_text = response.text().await.unwrap_or_default(); + return Err(ApiError::Api { + status, + error_type: Some("wham_error".to_string()), + message: Some(body_text.clone()), + request_id, + body: body_text, + retryable: status.is_server_error(), + suggested_action: suggested_action_for_status(status), + }); + } + + Ok(WhamMessageStream { + response, + buffer: Vec::new(), + pending: VecDeque::new(), + done: false, + state: WhamStreamState::new(request_id), + }) + } +} + +impl Provider for WhamClient { + type Stream = 
WhamMessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +struct WhamStreamState { + request_id: Option, + message_started: bool, + content_index: u32, + text_started: bool, + finished: bool, + usage: Option, +} + +impl WhamStreamState { + fn new(request_id: Option) -> Self { + Self { + request_id, + message_started: false, + content_index: 0, + text_started: false, + finished: false, + usage: None, + } + } + + fn ingest_event(&mut self, event: WhamSseEvent) -> Vec { + let mut events = Vec::new(); + + match event { + WhamSseEvent::Created { response } | WhamSseEvent::InProgress { response } => { + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: response.id, + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: response.model, + stop_reason: None, + stop_sequence: None, + usage: Usage::default(), + request_id: self.request_id.clone(), + }, + })); + } + } + WhamSseEvent::OutputTextDelta { content_index, delta } => { + if !self.text_started { + self.text_started = true; + self.content_index = content_index; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: content_index, + content_block: OutputContentBlock::Text { text: String::new() }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: content_index, + delta: ContentBlockDelta::TextDelta { text: delta }, + })); + } + WhamSseEvent::OutputTextDone { content_index } => { + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: content_index, + })); + 
self.text_started = false; + } + WhamSseEvent::Completed { usage } => { + self.usage = usage; + } + _ => {} + } + + events + } + + fn finish(&mut self) -> Vec { + let mut events = Vec::new(); + if self.text_started { + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: self.content_index, + })); + self.text_started = false; + } + if self.message_started && !self.finished { + self.finished = true; + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or_default(), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + events + } +} + +#[derive(Debug)] +pub struct WhamMessageStream { + response: reqwest::Response, + buffer: Vec, + pending: VecDeque, + done: bool, + state: WhamStreamState, +} + +impl WhamMessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.state.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.buffer.extend_from_slice(&chunk); + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_wham_sse_frame(&frame)? 
{ + self.pending.extend(self.state.ingest_event(event)); + } + } + } + None => { + self.done = true; + } + } + } + } +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +#[derive(Debug, Clone, Deserialize)] +struct WhamResponseStub { + id: String, + model: String, +} + +#[derive(Debug, Clone, Deserialize)] +struct WhamUsageStub { + input_tokens: u32, + output_tokens: u32, +} + +#[derive(Debug, Clone)] +enum WhamSseEvent { + Created { response: WhamResponseStub }, + InProgress { response: WhamResponseStub }, + OutputTextDelta { content_index: u32, delta: String }, + OutputTextDone { content_index: u32 }, + Completed { usage: Option }, + Other, +} + +fn parse_wham_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() || trimmed.starts_with(':') { + return Ok(None); + } + + let mut event_type = String::new(); + let mut data_json = String::new(); + + for line in trimmed.lines() { + if let Some(et) = line.strip_prefix("event:") { + event_type = et.trim().to_string(); + } else if let Some(data) = line.strip_prefix("data:") { + data_json.push_str(data.trim_start()); + } + } + + if data_json.is_empty() { + return Ok(None); + } + + let data: serde_json::Value = serde_json::from_str(&data_json) + .map_err(|e| ApiError::json_deserialize("OpenAI WHAM", "", &data_json, e))?; + + let event = match event_type.as_str() { + "response.created" => WhamSseEvent::Created { + response: serde_json::from_value(data.get("response").cloned().unwrap_or_default()) + 
.unwrap_or(WhamResponseStub { id: String::new(), model: String::new() }), + }, + "response.in_progress" => WhamSseEvent::InProgress { + response: serde_json::from_value(data.get("response").cloned().unwrap_or_default()) + .unwrap_or(WhamResponseStub { id: String::new(), model: String::new() }), + }, + "response.output_text.delta" => WhamSseEvent::OutputTextDelta { + content_index: data["content_index"].as_u64().unwrap_or(0) as u32, + delta: data["delta"].as_str().unwrap_or("").to_string(), + }, + "response.output_text.done" => WhamSseEvent::OutputTextDone { + content_index: data["content_index"].as_u64().unwrap_or(0) as u32, + }, + "response.completed" => { + let usage = data.get("response").and_then(|r| r.get("usage")).map(|u| Usage { + input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32, + output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }); + WhamSseEvent::Completed { usage } + } + _ => WhamSseEvent::Other, + }; + + Ok(Some(event)) +} + +fn responses_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/responses") { + trimmed.to_string() + } else { + format!("{trimmed}/responses") + } +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|v| v.to_str().ok()) + .map(String::from) +} + +fn suggested_action_for_status(status: reqwest::StatusCode) -> Option { + match status.as_u16() { + 401 => Some("OAuth token may be expired. Try re-authenticating with `claw auth login openai`".to_string()), + 403 => Some("Verify ChatGPT subscription is active (Plus/Pro required)".to_string()), + 429 => Some("Wait a moment before retrying; consider reducing request rate".to_string()), + _ => None, + } +} + +/// Build a Responses API request body from our internal MessageRequest. 
fn build_responses_request(request: &MessageRequest) -> Value {
    // WHAM backend requires streaming mode.
    let mut body = json!({
        "model": request.model,
        "store": false,
        "stream": true,
    });

    // The Responses API carries the system prompt as a top-level
    // `instructions` field rather than a system message.
    if let Some(ref system) = request.system {
        body["instructions"] = json!(system);
    }

    // Convert messages to Responses API `input` format
    let input: Vec<Value> = request
        .messages
        .iter()
        .map(|msg| {
            let content_blocks: Vec<Value> = msg
                .content
                .iter()
                .map(|block| match block {
                    InputContentBlock::Text { text } => {
                        json!({"type": "input_text", "text": text})
                    }
                    InputContentBlock::ToolUse { id, name, input } => {
                        // Tool-use blocks are only meaningful on assistant turns;
                        // on any other role they are flattened to a placeholder
                        // text block so the request stays well-formed.
                        if msg.role == "assistant" {
                            json!({
                                "type": "tool_use",
                                "id": id,
                                "name": name,
                                "input": input,
                            })
                        } else {
                            json!({"type": "input_text", "text": "[tool input]"})
                        }
                    }
                    InputContentBlock::ToolResult { tool_use_id, content, is_error } => {
                        // Tool results have no native Responses API shape here,
                        // so serialize them into a tagged text block the model
                        // can still read (JSON payloads are stringified).
                        let text = content
                            .iter()
                            .map(|c| match c {
                                crate::types::ToolResultContentBlock::Text { text } => text.clone(),
                                crate::types::ToolResultContentBlock::Json { value } => {
                                    value.to_string()
                                }
                            })
                            .collect::<Vec<_>>()
                            .join("\n");
                        json!({
                            "type": "input_text",
                            "text": format!(
                                "[tool result {} {}]\n{}",
                                tool_use_id,
                                if *is_error { "(error)" } else { "" },
                                text
                            ),
                        })
                    }
                })
                .collect();
            json!({"role": msg.role, "content": content_blocks})
        })
        .collect();

    body["input"] = json!(input);

    // Note: WHAM backend does not support `max_output_tokens`.
    // if request.max_tokens > 0 {
    //     body["max_output_tokens"] = json!(request.max_tokens);
    // }

    // Sampling knobs are passed through only when explicitly set.
    if let Some(temp) = request.temperature {
        body["temperature"] = json!(temp);
    }
    if let Some(top_p) = request.top_p {
        body["top_p"] = json!(top_p);
    }

    body
}

/// Responses API response shape.
+#[derive(Debug, Clone, Deserialize)] +struct ResponsesResponse { + id: String, + model: String, + output: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ResponsesOutputItem { + Message { + id: String, + role: String, + content: Vec, + }, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ResponsesContentItem { + OutputText { text: String }, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Deserialize, Default)] +struct ResponsesUsage { + input_tokens: u32, + output_tokens: u32, +} + +fn convert_to_message_response( + wham: ResponsesResponse, + request_id: Option, +) -> MessageResponse { + let mut content = Vec::new(); + for item in &wham.output { + if let ResponsesOutputItem::Message { content: blocks, .. } = item { + for block in blocks { + if let ResponsesContentItem::OutputText { text } = block { + content.push(OutputContentBlock::Text { text: text.clone() }); + } + } + } + } + + let usage = wham.usage.map(|u| Usage { + input_tokens: u.input_tokens, + output_tokens: u.output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }).unwrap_or_default(); + + MessageResponse { + id: wham.id, + kind: "message".to_string(), + role: "assistant".to_string(), + content, + model: wham.model, + stop_reason: None, + stop_sequence: None, + usage, + request_id, + } +} + +fn response_to_stream_events(response: MessageResponse) -> Vec { + let mut events = Vec::new(); + + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: response.clone(), + })); + + for (index, block) in response.content.iter().enumerate() { + let block_index = index as u32; + match block { + OutputContentBlock::Text { text } => { + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: block_index, + content_block: OutputContentBlock::Text { text: String::new() }, + })); + 
events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: block_index, + delta: ContentBlockDelta::TextDelta { text: text.clone() }, + })); + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + OutputContentBlock::ToolUse { id, name, input } => { + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: block_index, + content_block: OutputContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: input.clone(), + }, + })); + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: block_index, + delta: ContentBlockDelta::InputJsonDelta { + partial_json: input.to_string(), + }, + })); + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + OutputContentBlock::Thinking { thinking, .. } => { + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: block_index, + content_block: OutputContentBlock::Text { text: String::new() }, + })); + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: block_index, + delta: ContentBlockDelta::ThinkingDelta { thinking: thinking.clone() }, + })); + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + OutputContentBlock::RedactedThinking { .. 
} => { + // Skip redacted thinking blocks in stream replay + } + } + } + + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: response.stop_reason.clone(), + stop_sequence: response.stop_sequence.clone(), + }, + usage: response.usage.clone(), + })); + + events +} diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index b1bd04f374..1357edd98b 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -6,10 +6,12 @@ license.workspace = true publish.workspace = true [dependencies] +base64 = "0.22" sha2 = "0.10" glob = "0.3" plugins = { path = "../plugins" } regex = "1" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } serde = { version = "1", features = ["derive"] } serde_json.workspace = true telemetry = { path = "../telemetry" } diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index 1566189282..46d871d54f 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -64,9 +64,21 @@ pub struct RuntimeFeatureConfig { permission_rules: RuntimePermissionRuleConfig, sandbox: SandboxConfig, provider_fallbacks: ProviderFallbackConfig, + model_providers: BTreeMap, trusted_roots: Vec, } +/// User-configured model provider profile. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelProviderConfig { + provider_type: String, + base_url: String, + api_key_env: Option, + api_key: Option, + models: Vec, + default_model: Option, +} + /// Ordered chain of fallback model identifiers used when the primary /// provider returns a retryable failure (429/500/503/etc.). The chain is /// strict: each entry is tried in order until one succeeds. 
@@ -314,6 +326,7 @@ impl ConfigLoader { permission_rules: parse_optional_permission_rules(&merged_value)?, sandbox: parse_optional_sandbox_config(&merged_value)?, provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?, + model_providers: parse_optional_model_providers(&merged_value)?, trusted_roots: parse_optional_trusted_roots(&merged_value)?, }; @@ -410,6 +423,11 @@ impl RuntimeConfig { &self.feature_config.provider_fallbacks } + #[must_use] + pub fn model_providers(&self) -> &BTreeMap { + &self.feature_config.model_providers + } + #[must_use] pub fn trusted_roots(&self) -> &[String] { &self.feature_config.trusted_roots @@ -479,12 +497,68 @@ impl RuntimeFeatureConfig { &self.provider_fallbacks } + #[must_use] + pub fn model_providers(&self) -> &BTreeMap { + &self.model_providers + } + #[must_use] pub fn trusted_roots(&self) -> &[String] { &self.trusted_roots } } +impl ModelProviderConfig { + #[must_use] + pub fn new( + provider_type: String, + base_url: String, + api_key_env: Option, + api_key: Option, + models: Vec, + default_model: Option, + ) -> Self { + Self { + provider_type, + base_url, + api_key_env, + api_key, + models, + default_model, + } + } + + #[must_use] + pub fn provider_type(&self) -> &str { + &self.provider_type + } + + #[must_use] + pub fn base_url(&self) -> &str { + &self.base_url + } + + #[must_use] + pub fn api_key_env(&self) -> Option<&str> { + self.api_key_env.as_deref() + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + self.api_key.as_deref() + } + + #[must_use] + pub fn models(&self) -> &[String] { + &self.models + } + + #[must_use] + pub fn default_model(&self) -> Option<&str> { + self.default_model.as_deref() + } +} + impl ProviderFallbackConfig { #[must_use] pub fn new(primary: Option, fallbacks: Vec) -> Self { @@ -904,6 +978,64 @@ fn parse_optional_provider_fallbacks( Ok(ProviderFallbackConfig { primary, fallbacks }) } +fn parse_optional_model_providers( + root: &JsonValue, +) -> Result, ConfigError> { + 
let Some(object) = root.as_object() else { + return Ok(BTreeMap::new()); + }; + let Some(value) = object.get("modelProviders") else { + return Ok(BTreeMap::new()); + }; + let providers = expect_object(value, "merged settings.modelProviders")?; + let mut parsed = BTreeMap::new(); + for (name, value) in providers { + let context = format!("merged settings.modelProviders.{name}"); + let provider = expect_object(value, &context)?; + let provider_type = optional_string(provider, "type", &context)? + .unwrap_or("openai-compatible") + .to_string(); + if !matches!( + provider_type.as_str(), + "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic" + ) { + return Err(ConfigError::Parse(format!( + "{context}: unsupported provider type {provider_type}" + ))); + } + let base_url = expect_string(provider, "baseUrl", &context)?.to_string(); + let api_key_env = optional_string(provider, "apiKeyEnv", &context)?.map(str::to_string); + let api_key = optional_string(provider, "apiKey", &context)?.map(str::to_string); + let models = optional_string_array(provider, "models", &context)?.unwrap_or_default(); + let default_model = + optional_string(provider, "defaultModel", &context)?.map(str::to_string); + if models.is_empty() && default_model.is_none() { + return Err(ConfigError::Parse(format!( + "{context}: expected at least one model in models or defaultModel" + ))); + } + if let Some(default_model) = &default_model { + if !models.is_empty() && !models.iter().any(|model| model == default_model) { + return Err(ConfigError::Parse(format!( + "{context}: defaultModel must be listed in models" + ))); + } + } + parsed.insert( + name.clone(), + ModelProviderConfig::new( + provider_type, + base_url, + api_key_env, + api_key, + models, + default_model, + ), + ); + } + Ok(parsed) +} + fn parse_optional_trusted_roots(root: &JsonValue) -> Result, ConfigError> { let Some(object) = root.as_object() else { return Ok(Vec::new()); @@ -1812,6 +1944,51 @@ mod tests { 
fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn parses_model_provider_profiles_from_settings() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + home.join("settings.json"), + r#"{ + "model": "zai/glm-5.1", + "modelProviders": { + "zai": { + "type": "openai-compatible", + "baseUrl": "https://api.z.ai/api/paas/v4", + "apiKeyEnv": "Z_AI_API_KEY", + "models": ["glm-5.1", "glm-4.6"], + "defaultModel": "glm-5.1" + } + } + }"#, + ) + .expect("write settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + let provider = loaded + .model_providers() + .get("zai") + .expect("zai provider should parse"); + + assert_eq!(loaded.model(), Some("zai/glm-5.1")); + assert_eq!(provider.provider_type(), "openai-compatible"); + assert_eq!(provider.base_url(), "https://api.z.ai/api/paas/v4"); + assert_eq!(provider.api_key_env(), Some("Z_AI_API_KEY")); + assert_eq!(provider.default_model(), Some("glm-5.1")); + assert_eq!( + provider.models(), + &["glm-5.1".to_string(), "glm-4.6".to_string()] + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn empty_settings_file_loads_defaults() { // given diff --git a/rust/crates/runtime/src/config_validate.rs b/rust/crates/runtime/src/config_validate.rs index 7a9c1c4adc..3a50bf5516 100644 --- a/rust/crates/runtime/src/config_validate.rs +++ b/rust/crates/runtime/src/config_validate.rs @@ -193,6 +193,10 @@ const TOP_LEVEL_FIELDS: &[FieldSpec] = &[ name: "providerFallbacks", expected: FieldType::Object, }, + FieldSpec { + name: "modelProviders", + expected: FieldType::Object, + }, FieldSpec { name: "trustedRoots", expected: FieldType::StringArray, @@ -310,6 +314,33 @@ const OAUTH_FIELDS: &[FieldSpec] = &[ }, ]; +const MODEL_PROVIDER_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: 
"type", + expected: FieldType::String, + }, + FieldSpec { + name: "baseUrl", + expected: FieldType::String, + }, + FieldSpec { + name: "apiKeyEnv", + expected: FieldType::String, + }, + FieldSpec { + name: "apiKey", + expected: FieldType::String, + }, + FieldSpec { + name: "models", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "defaultModel", + expected: FieldType::String, + }, +]; + const DEPRECATED_FIELDS: &[DeprecatedField] = &[ DeprecatedField { name: "permissionMode", @@ -501,6 +532,19 @@ pub fn validate_config_file( &path_display, )); } + if let Some(model_providers) = object.get("modelProviders").and_then(JsonValue::as_object) { + for (name, provider) in model_providers { + if let Some(provider) = provider.as_object() { + result.merge(validate_object_keys( + provider, + MODEL_PROVIDER_FIELDS, + &format!("modelProviders.{name}"), + source, + &path_display, + )); + } + } + } result } diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index c7d87091fa..126a067fe5 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -112,10 +112,13 @@ pub use mcp_stdio::{ UnsupportedMcpServer, }; pub use oauth::{ - clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, - generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, - parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, - OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, + extract_chatgpt_account_id, generate_pkce_pair, generate_state, load_oauth_credentials, + load_provider_oauth, loopback_redirect_uri, loopback_redirect_uri_with_path, open_browser, + parse_oauth_callback_query, parse_oauth_callback_request_target, poll_device_token, + refresh_oauth_token, run_oauth_callback_server, save_oauth_credentials, save_provider_oauth, 
+ DeviceAuthRequest, DeviceAuthResponse, OAuthAuthorizationRequest, OAuthCallbackParams, + OAuthCallbackResult, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index aa3ca158c7..92477fc602 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -16,6 +16,10 @@ pub struct OAuthTokenSet { pub refresh_token: Option, pub expires_at: Option, pub scopes: Vec, + /// OpenID token from auth.openai.com, needed to extract `chatgpt_account_id` + /// for the WHAM backend (`ChatGPT-Account-Id` header). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id_token: Option, } /// PKCE verifier/challenge pair generated for an OAuth authorization flow. @@ -93,6 +97,8 @@ struct StoredOAuthCredentials { expires_at: Option, #[serde(default)] scopes: Vec, + #[serde(default)] + id_token: Option, } impl From for StoredOAuthCredentials { @@ -102,6 +108,7 @@ impl From for StoredOAuthCredentials { refresh_token: value.refresh_token, expires_at: value.expires_at, scopes: value.scopes, + id_token: value.id_token, } } } @@ -113,6 +120,7 @@ impl From for OAuthTokenSet { refresh_token: value.refresh_token, expires_at: value.expires_at, scopes: value.scopes, + id_token: value.id_token, } } } @@ -207,7 +215,6 @@ impl OAuthTokenExchangeRequest { ("redirect_uri", self.redirect_uri.clone()), ("client_id", self.client_id.clone()), ("code_verifier", self.code_verifier.clone()), - ("state", self.state.clone()), ]) } } @@ -262,6 +269,11 @@ pub fn loopback_redirect_uri(port: u16) -> String { format!("http://localhost:{port}/callback") } +#[must_use] +pub fn loopback_redirect_uri_with_path(port: u16, path: &str) -> String { + format!("http://localhost:{port}{path}") +} + pub fn credentials_path() -> io::Result { Ok(credentials_home_dir()?.join("credentials.json")) } @@ -298,12 +310,363 @@ pub fn 
clear_oauth_credentials() -> io::Result<()> { write_credentials_root(&path, &root) } -pub fn parse_oauth_callback_request_target(target: &str) -> Result { +// --------------------------------------------------------------------------- +// Per-provider OAuth token storage +// --------------------------------------------------------------------------- + +/// Load OAuth credentials for a specific provider from `credentials.json`. +/// Credentials are stored under the `oauth_providers.{provider_id}` key. +pub fn load_provider_oauth(provider_id: &str) -> io::Result> { + let path = credentials_path()?; + let root = read_credentials_root(&path)?; + let Some(oauth_providers) = root.get("oauth_providers") else { + return Ok(None); + }; + let Some(provider_value) = oauth_providers.get(provider_id) else { + return Ok(None); + }; + if provider_value.is_null() { + return Ok(None); + } + let stored = serde_json::from_value::(provider_value.clone()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + Ok(Some(stored.into())) +} + +/// Save OAuth credentials for a specific provider to `credentials.json`. +/// Preserves other providers and the legacy `"oauth"` key. +pub fn save_provider_oauth(provider_id: &str, token_set: &OAuthTokenSet) -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + let oauth_providers = root + .entry("oauth_providers") + .or_insert_with(|| Value::Object(Map::new())); + let provider_map = oauth_providers + .as_object_mut() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "oauth_providers must be an object"))?; + provider_map.insert( + provider_id.to_string(), + serde_json::to_value(StoredOAuthCredentials::from(token_set.clone())) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, + ); + write_credentials_root(&path, &root) +} + +/// Clear OAuth credentials for a specific provider. 
pub fn clear_provider_oauth(provider_id: &str) -> io::Result<()> {
    let path = credentials_path()?;
    let mut root = read_credentials_root(&path)?;
    // No per-provider credentials stored at all: nothing to clear.
    let Some(oauth_providers) = root.get_mut("oauth_providers") else {
        return Ok(());
    };
    // Malformed (non-object) container: leave it alone rather than erroring.
    let Some(provider_map) = oauth_providers.as_object_mut() else {
        return Ok(());
    };
    provider_map.remove(provider_id);
    // Drop the container entirely once the last provider entry is gone so the
    // credentials file stays minimal.
    if provider_map.is_empty() {
        root.remove("oauth_providers");
    }
    write_credentials_root(&path, &root)
}

// ---------------------------------------------------------------------------
// Browser launcher
// ---------------------------------------------------------------------------

/// Open a URL in the user's default browser.
/// Falls back to printing the URL if the platform command fails.
pub fn open_browser(url: &str) -> io::Result<()> {
    // Select the platform-specific opener command.
    let (cmd, args): (&str, Vec<&str>) = if cfg!(target_os = "macos") {
        ("open", vec![url])
    } else if cfg!(target_os = "linux") {
        ("xdg-open", vec![url])
    } else if cfg!(target_os = "windows") {
        // `start` is a cmd.exe builtin; the empty "" argument is the window
        // title slot so a quoted URL is not misparsed as the title.
        ("cmd", vec!["/C", "start", "", url])
    } else {
        // Unknown platform: ask the user to open the URL manually.
        eprintln!("Please open this URL in your browser:");
        eprintln!(" {url}");
        return Ok(());
    };
    // Best effort: a failed launch degrades to printing the URL rather than
    // failing the surrounding auth flow, hence Ok(()) on both arms.
    match std::process::Command::new(cmd).args(&args).output() {
        Ok(output) if output.status.success() => Ok(()),
        _ => {
            eprintln!("Could not open browser automatically. Please open this URL:");
            eprintln!(" {url}");
            Ok(())
        }
    }
}

// ---------------------------------------------------------------------------
// Local HTTP callback server (blocking, single-request)
// ---------------------------------------------------------------------------

/// Result of a successful OAuth callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackResult {
    // Authorization code returned by the provider (exchanged for tokens).
    pub code: String,
    // Opaque state echoed back by the provider; caller must compare it
    // against the state it generated to defeat CSRF.
    pub state: String,
}

/// Run a blocking local HTTP server that waits for a single callback request.
/// Returns the authorization code and state on success.
/// Times out after `timeout` duration.
+pub fn run_oauth_callback_server( + port: u16, + timeout: std::time::Duration, + callback_path: &str, +) -> io::Result { + use std::io::{BufRead, BufReader, Write}; + use std::net::{SocketAddr, TcpListener}; + use std::sync::mpsc; + + let addr: SocketAddr = format!("127.0.0.1:{port}").parse().map_err(|e| { + io::Error::new(io::ErrorKind::InvalidInput, format!("invalid address: {e}")) + })?; + let listener = TcpListener::bind(addr)?; + + let (tx, rx) = mpsc::channel::(); + + std::thread::spawn(move || { + if let Ok((stream, _)) = listener.accept() { + let _ = tx.send(stream); + } + }); + + let mut stream = match rx.recv_timeout(timeout) { + Ok(stream) => stream, + Err(mpsc::RecvTimeoutError::Timeout) => { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "OAuth callback timed out waiting for browser redirect", + )); + } + Err(mpsc::RecvTimeoutError::Disconnected) => { + return Err(io::Error::new( + io::ErrorKind::Other, + "callback server thread disconnected", + )); + } + }; + + let mut reader = BufReader::new(&mut stream); + let mut first_line = String::new(); + reader.read_line(&mut first_line)?; + + // Parse "GET /callback?code=...&state=... HTTP/1.1" + let target = first_line + .split_whitespace() + .nth(1) + .unwrap_or(""); + + // Consume remaining headers so browser doesn't reset connection + loop { + let mut line = String::new(); + if reader.read_line(&mut line)? == 0 { + break; + } + if line == "\r\n" || line == "\n" { + break; + } + } + + match parse_oauth_callback_request_target(target, callback_path) { + Ok(params) => { + if let (Some(code), Some(state)) = (¶ms.code, ¶ms.state) { + // Success page + let body = r#" +Authentication Successful + +

✅ Authentication Successful

+

You can close this tab and return to the terminal.

+"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Ok(OAuthCallbackResult { + code: code.clone(), + state: state.clone(), + }); + } + if let Some(error) = ¶ms.error { + let err_desc = params.error_description.as_deref().unwrap_or(error); + let body = format!( + r#" +Authentication Failed + +

❌ Authentication Failed

+

{}

+

You can close this tab and return to the terminal.

+"#, + html_escape(err_desc) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Err(io::Error::new( + io::ErrorKind::Other, + format!("OAuth error: {error} - {err_desc}"), + )); + } + Err(io::Error::new( + io::ErrorKind::InvalidData, + "callback received neither code nor error", + )) + } + Err(e) => { + let body = format!( + r#" +Authentication Failed + +

❌ Invalid Callback

+

{}

+"#, + html_escape(&e) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + Err(io::Error::new(io::ErrorKind::Other, e)) + } + } +} + +#[must_use] +fn html_escape(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) +} + +// --------------------------------------------------------------------------- +// Device Authorization Flow (RFC 8628) +// --------------------------------------------------------------------------- + +/// Request body for starting a device authorization flow. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DeviceAuthRequest { + pub client_id: String, + pub scope: String, +} + +/// Response from a device authorization endpoint. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct DeviceAuthResponse { + pub device_code: String, + pub user_code: String, + pub verification_uri: String, + #[serde(default)] + pub verification_uri_complete: Option, + pub expires_in: u64, + pub interval: u64, +} + +/// Poll a device token endpoint until the user authorizes or the flow expires. +/// Returns `Ok(None)` if the user hasn't authorized yet but we should keep polling. +/// Returns `Ok(Some(token_set))` on success. +/// Returns `Err` on fatal errors. 
+pub async fn poll_device_token(
+    client: &reqwest::Client,
+    device_code: &str,
+    client_id: &str,
+    token_url: &str,
+) -> io::Result<Option<OAuthTokenSet>> {
+    let params = [
+        ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
+        ("device_code", device_code),
+        ("client_id", client_id),
+    ];
+
+    let response = client
+        .post(token_url)
+        .form(&params)
+        .send()
+        .await
+        .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("HTTP request failed: {e}")))?;
+    let status = response.status();
+    let body = response.text().await.map_err(|e| {
+        io::Error::new(io::ErrorKind::Other, format!("Failed to read response body: {e}"))
+    })?;
+
+    if status.is_success() {
+        let token: serde_json::Value = serde_json::from_str(&body)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        let access_token = token["access_token"]
+            .as_str()
+            .ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "access_token missing from device token response",
+                )
+            })?
+            .to_string();
+        let refresh_token = token["refresh_token"].as_str().map(String::from);
+        let expires_at = token["expires_in"].as_u64().map(|secs| {
+            std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_secs()
+                + secs
+        });
+        let scopes = token["scope"]
+            .as_str()
+            .map(|s| s.split(' ').map(String::from).collect())
+            .unwrap_or_default();
+        let id_token = token["id_token"].as_str().map(String::from);
+        return Ok(Some(OAuthTokenSet {
+            access_token,
+            refresh_token,
+            expires_at,
+            scopes,
+            id_token,
+        }));
+    }
+
+    // Parse OAuth error response
+    let error_json: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
+    let error = error_json["error"].as_str().unwrap_or("unknown");
+
+    match error {
+        "authorization_pending" => Ok(None),
+        "slow_down" => {
+            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+            Ok(None)
+        }
+        "expired_token" => Err(io::Error::new(
+            io::ErrorKind::Other,
+            "Device authorization expired. Please try again.",
+        )),
+        "access_denied" => Err(io::Error::new(
+            io::ErrorKind::PermissionDenied,
+            "User denied authorization.",
+        )),
+        _ => Err(io::Error::new(
+            io::ErrorKind::Other,
+            format!("Device token error: {error}: {body}"),
+        )),
+    }
+}
+
+pub fn parse_oauth_callback_request_target(
+    target: &str,
+    expected_path: &str,
+) -> Result<OAuthCallbackParams, String> {
     let (path, query) = target
         .split_once('?')
         .map_or((target, ""), |(path, query)| (path, query));
-    if path != "/callback" {
-        return Err(format!("unexpected callback path: {path}"));
+    if path != expected_path {
+        return Err(format!("unexpected callback path: {path}, expected: {expected_path}"));
     }
     parse_oauth_callback_query(query)
 }
@@ -410,6 +773,110 @@ fn base64url_encode(bytes: &[u8]) -> String {
     output
 }
 
+/// Extract `chatgpt_account_id` from a JWT payload (id_token or access_token).
+/// No signature verification — just base64-decode the payload section.
+pub fn extract_chatgpt_account_id(token: &str) -> Option<String> {
+    let payload_b64 = token.split('.').nth(1)?;
+    let payload_json = base64_decode_urlsafe(payload_b64).ok()?;
+    let payload: serde_json::Value = serde_json::from_slice(&payload_json).ok()?;
+
+    // 3-level fallback, matching Codex CLI behaviour:
+    // 1. top-level `chatgpt_account_id`
+    if let Some(id) = payload.get("chatgpt_account_id").and_then(|v| v.as_str()) {
+        return Some(id.to_string());
+    }
+    // 2. `https://api.openai.com/auth` -> `chatgpt_account_id`
+    if let Some(auth) = payload.get("https://api.openai.com/auth").and_then(|v| v.as_object()) {
+        if let Some(id) = auth.get("chatgpt_account_id").and_then(|v| v.as_str()) {
+            return Some(id.to_string());
+        }
+    }
+    // 3. `organizations[0].id`
+    if let Some(orgs) = payload.get("organizations").and_then(|v| v.as_array()) {
+        if let Some(first) = orgs.first() {
+            if let Some(id) = first.get("id").and_then(|v| v.as_str()) {
+                return Some(id.to_string());
+            }
+        }
+    }
+    None
+}
+
+fn base64_decode_urlsafe(input: &str) -> Result<Vec<u8>, String> {
+    let padded = match input.len() % 4 {
+        0 => input.to_string(),
+        2 => format!("{input}=="),
+        3 => format!("{input}="),
+        _ => return Err("invalid base64url length".to_string()),
+    };
+    let standard = padded.replace('-', "+").replace('_', "/");
+    base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &standard)
+        .map_err(|e| e.to_string())
+}
+
+/// Refresh an OAuth token set using the refresh_token grant.
+/// Returns a new token set on success.
+pub async fn refresh_oauth_token(
+    client: &reqwest::Client,
+    token_url: &str,
+    client_id: &str,
+    refresh_token: &str,
+) -> io::Result<OAuthTokenSet> {
+    let params = [
+        ("grant_type", "refresh_token"),
+        ("client_id", client_id),
+        ("refresh_token", refresh_token),
+    ];
+
+    let response = client
+        .post(token_url)
+        .form(&params)
+        .send()
+        .await
+        .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("HTTP request failed: {e}")))?;
+
+    let status = response.status();
+    let body = response.text().await.map_err(|e| {
+        io::Error::new(io::ErrorKind::Other, format!("Failed to read response body: {e}"))
+    })?;
+
+    if !status.is_success() {
+        return Err(io::Error::new(
+            io::ErrorKind::Other,
+            format!("Token refresh failed ({status}): {body}"),
+        ));
+    }
+
+    let token: serde_json::Value = serde_json::from_str(&body)
+        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+
+    let access_token = token["access_token"]
+        .as_str()
+        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "access_token missing from refresh response"))?
+ .to_string(); + let refresh_token = token["refresh_token"].as_str().map(String::from); + let id_token = token["id_token"].as_str().map(String::from); + let expires_at = token["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = token["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + + Ok(OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + id_token, + }) +} + fn percent_encode(value: &str) -> String { let mut encoded = String::new(); for byte in value.bytes() { @@ -465,10 +932,12 @@ mod tests { use std::time::{SystemTime, UNIX_EPOCH}; use super::{ - clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, - generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, - parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, - OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, + clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, + generate_pkce_pair, generate_state, load_oauth_credentials, load_provider_oauth, + loopback_redirect_uri, parse_oauth_callback_query, parse_oauth_callback_request_target, + run_oauth_callback_server, save_oauth_credentials, save_provider_oauth, html_escape, + OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, + OAuthTokenSet, }; fn sample_config() -> OAuthConfig { @@ -565,6 +1034,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(123), scopes: vec!["scope:a".to_string()], + id_token: None, }; save_oauth_credentials(&token_set).expect("save credentials"); assert_eq!( @@ -594,10 +1064,184 @@ mod tests { assert_eq!(params.state.as_deref(), Some("state-1")); assert_eq!(params.error_description.as_deref(), Some("needs login")); - let 
params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz") + let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz", "/callback") .expect("parse callback target"); assert_eq!(params.code.as_deref(), Some("abc")); assert_eq!(params.state.as_deref(), Some("xyz")); - assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err()); + assert!(parse_oauth_callback_request_target("/wrong?code=abc", "/callback").is_err()); + } + + #[test] + fn provider_oauth_credentials_round_trip_and_clear() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + let path = credentials_path().expect("credentials path"); + std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); + + let openai_tokens = OAuthTokenSet { + access_token: "openai-access".to_string(), + refresh_token: Some("openai-refresh".to_string()), + expires_at: Some(1000), + scopes: vec!["openid".to_string()], + id_token: None, + }; + let moonshot_tokens = OAuthTokenSet { + access_token: "moonshot-access".to_string(), + refresh_token: None, + expires_at: Some(2000), + scopes: vec!["profile".to_string()], + id_token: None, + }; + + save_provider_oauth("openai", &openai_tokens).expect("save openai"); + save_provider_oauth("moonshot", &moonshot_tokens).expect("save moonshot"); + + assert_eq!( + load_provider_oauth("openai").expect("load openai"), + Some(openai_tokens.clone()) + ); + assert_eq!( + load_provider_oauth("moonshot").expect("load moonshot"), + Some(moonshot_tokens.clone()) + ); + assert_eq!( + load_provider_oauth("unknown").expect("load unknown"), + None + ); + + let saved = std::fs::read_to_string(&path).expect("read saved file"); + assert!(saved.contains("\"oauth_providers\"")); + assert!(saved.contains("\"openai\"")); + assert!(saved.contains("\"moonshot\"")); + + clear_provider_oauth("openai").expect("clear openai"); + 
assert_eq!(load_provider_oauth("openai").expect("load cleared"), None); + assert_eq!( + load_provider_oauth("moonshot").expect("load moonshot after clear"), + Some(moonshot_tokens) + ); + + clear_provider_oauth("moonshot").expect("clear moonshot"); + let cleared = std::fs::read_to_string(&path).expect("read cleared file"); + assert!(!cleared.contains("\"oauth_providers\"")); + + std::env::remove_var("CLAW_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn provider_oauth_preserves_legacy_oauth_key() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + let path = credentials_path().expect("credentials path"); + std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); + + let legacy = OAuthTokenSet { + access_token: "legacy-access".to_string(), + refresh_token: Some("legacy-refresh".to_string()), + expires_at: Some(999), + scopes: vec!["org:read".to_string()], + id_token: None, + }; + save_oauth_credentials(&legacy).expect("save legacy"); + + let provider = OAuthTokenSet { + access_token: "provider-access".to_string(), + refresh_token: None, + expires_at: Some(888), + scopes: vec!["user:read".to_string()], + id_token: None, + }; + save_provider_oauth("openai", &provider).expect("save provider"); + + assert_eq!( + load_oauth_credentials().expect("load legacy"), + Some(legacy) + ); + assert_eq!( + load_provider_oauth("openai").expect("load provider"), + Some(provider) + ); + + std::env::remove_var("CLAW_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn callback_server_returns_code_and_state() { + use std::io::{Read, Write}; + use std::net::TcpStream; + use std::thread; + + let port = 4547; + let server_thread = thread::spawn(move || { + run_oauth_callback_server(port, std::time::Duration::from_secs(5), "/callback") + }); + + // Give the server a moment to bind + 
+        thread::sleep(std::time::Duration::from_millis(100));
+
+        // Simulate browser callback
+        let request = format!(
+            "GET /callback?code=test-code-123&state=test-state-456 HTTP/1.1\r\nHost: localhost:{port}\r\n\r\n"
+        );
+        let mut stream = TcpStream::connect(format!("127.0.0.1:{port}")).expect("connect to callback server");
+        stream.write_all(request.as_bytes()).expect("send request");
+        stream.flush().expect("flush");
+
+        // Read response (should be HTML success page)
+        let mut response = String::new();
+        stream.read_to_string(&mut response).expect("read response");
+        assert!(response.contains("200 OK"), "expected 200 OK, got: {response}");
+        assert!(response.contains("Authentication Successful"), "expected success page");
+
+        let result = server_thread.join().expect("server thread join");
+        let callback = result.expect("callback result");
+        assert_eq!(callback.code, "test-code-123");
+        assert_eq!(callback.state, "test-state-456");
+    }
+
+    #[test]
+    fn callback_server_returns_error_on_oauth_error() {
+        use std::io::{Read, Write};
+        use std::net::TcpStream;
+        use std::thread;
+
+        let port = 4548;
+        let server_thread = thread::spawn(move || {
+            run_oauth_callback_server(port, std::time::Duration::from_secs(5), "/callback")
+        });
+
+        thread::sleep(std::time::Duration::from_millis(100));
+
+        let request = format!(
+            "GET /callback?error=access_denied&error_description=user%20denied HTTP/1.1\r\nHost: localhost:{port}\r\n\r\n"
+        );
+        let mut stream = TcpStream::connect(format!("127.0.0.1:{port}")).expect("connect");
+        stream.write_all(request.as_bytes()).expect("send");
+        stream.flush().expect("flush");
+
+        let mut response = String::new();
+        stream.read_to_string(&mut response).expect("read");
+        assert!(response.contains("400 Bad Request"), "expected 400, got: {response}");
+
+        let result = server_thread.join().expect("join");
+        assert!(result.is_err(), "expected error for OAuth error response");
+    }
+
+    #[test]
+    fn callback_server_times_out_when_no_request() {
+        let port = 4549;
+        let result = run_oauth_callback_server(port, std::time::Duration::from_millis(50), "/callback");
+        assert!(result.is_err(), "expected timeout error");
+    }
+
+    #[test]
+    fn html_escape_works() {
+        assert_eq!(html_escape("<script>alert('xss')</script>"), "&lt;script&gt;alert('xss')&lt;/script&gt;");
+        assert_eq!(html_escape("foo & bar"), "foo &amp; bar");
+        assert_eq!(html_escape("\"quoted\""), "&quot;quoted&quot;");
+    }
 }
diff --git a/rust/crates/rusty-claude-cli/Cargo.toml b/rust/crates/rusty-claude-cli/Cargo.toml
index 635fdb32f7..a1e8fba543 100644
--- a/rust/crates/rusty-claude-cli/Cargo.toml
+++ b/rust/crates/rusty-claude-cli/Cargo.toml
@@ -15,6 +15,7 @@ commands = { path = "../commands" }
 compat-harness = { path = "../compat-harness" }
 crossterm = "0.28"
 pulldown-cmark = "0.13"
+reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
 rustyline = "15"
 runtime = { path = "../runtime" }
 plugins = { path = "../plugins" }
diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs
index dbdbd07b64..273a527785 100644
--- a/rust/crates/rusty-claude-cli/src/main.rs
+++ b/rust/crates/rusty-claude-cli/src/main.rs
@@ -24,9 +24,10 @@ use std::thread::{self, JoinHandle};
 use std::time::{Duration, Instant, UNIX_EPOCH};
 
 use api::{
-    detect_provider_kind, resolve_startup_auth_source, AnthropicClient, AuthSource,
-    ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, MessageResponse,
-    OutputContentBlock, PromptCache, ProviderClient as ApiProviderClient, ProviderKind,
+    anthropic_has_auth, detect_provider_kind, has_api_key, metadata_for_model,
+    resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
+    InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache,
+    ProviderClient as ApiProviderClient, ProviderKind, ProviderMetadata,
     StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
 };
@@ -96,6 +97,21 @@ struct ModelProvenance {
     source: ModelSource,
 }
 
+#[derive(Debug, Clone)]
+struct ConfiguredModelProvider {
+    wire_model: String,
+    provider_type: String,
+    api_key: String,
+    base_url: String,
+    /// Full OAuth token set when auth came from saved OAuth credentials.
+    /// Enables automatic token refresh for custom providers.
+    oauth_token_set: Option<runtime::OAuthTokenSet>,
+    /// OAuth token URL for refreshing the access token.
+    oauth_token_url: Option<String>,
+    /// OAuth client ID for refreshing the access token.
+    oauth_client_id: Option<String>,
+}
+
 impl ModelProvenance {
     fn default_fallback() -> Self {
         Self {
@@ -364,6 +380,11 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
             output_format,
         } => print_system_prompt(cwd, date, output_format)?,
         CliAction::Version { output_format } => print_version(output_format)?,
+        CliAction::Login => {
+            if let Some(model) = run_login_wizard()? {
+                println!("Configured provider. Use `claw --model {model}` or `/model {model}`.");
+            }
+        }
         CliAction::ResumeSession {
             session_path,
             commands,
@@ -407,7 +428,16 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
                 None
             };
             let effective_prompt = merge_prompt_with_stdin(&prompt, stdin_context.as_deref());
-            let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?;
+            let resolved_model = resolve_model_alias_with_config(&model);
+            let final_model = if !check_model_auth_available(&resolved_model)? {
+                match run_provider_welcome(&resolved_model)? {
+                    Some(new_model) => new_model,
+                    None => return Ok(()),
+                }
+            } else {
+                resolved_model
+            };
+            let mut cli = LiveCli::new(final_model, true, allowed_tools, permission_mode)?;
             cli.set_reasoning_effort(reasoning_effort);
             cli.run_turn_with_output(&effective_prompt, output_format, compact)?;
         }
@@ -464,6 +494,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
             reasoning_effort,
             allow_broad_cwd,
         )?,
+        CliAction::Auth { provider } => run_auth_command(provider.as_deref())?,
+        CliAction::Model { name } => run_model_command(name.as_deref())?,
        CliAction::HelpTopic(topic) => print_help_topic(topic),
         CliAction::Help { output_format } => print_help(output_format)?,
     }
@@ -504,6 +536,7 @@ enum CliAction {
     Version {
         output_format: CliOutputFormat,
     },
+    Login,
     ResumeSession {
         session_path: PathBuf,
         commands: Vec<String>,
@@ -567,6 +600,12 @@ enum CliAction {
        reasoning_effort: Option<String>,
        allow_broad_cwd: bool,
    },
+    Auth {
+        provider: Option<String>,
+    },
+    Model {
+        name: Option<String>,
+    },
     HelpTopic(LocalHelpTopic),
     // prompt-mode formatting is only supported for non-interactive runs
     Help {
@@ -948,7 +987,36 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
         }
         "system-prompt" => parse_system_prompt_args(&rest[1..], output_format),
         "acp" => parse_acp_args(&rest[1..], output_format),
-        "login" | "logout" => Err(removed_auth_surface_error(rest[0].as_str())),
+        "login" => Ok(CliAction::Login),
+        "logout" => Err(removed_auth_surface_error(rest[0].as_str())),
+        "auth" => {
+            if rest.len() == 2 && is_help_flag(&rest[1]) {
+                return Ok(CliAction::Help { output_format });
+            }
+            let provider = rest.get(1).cloned();
+            if rest.len() > 2 {
+                return Err(format!(
+                    "unexpected extra arguments after `claw auth {}`: {}",
+                    provider.as_deref().unwrap_or(""),
+                    rest[2..].join(" ")
+                ));
+            }
+            Ok(CliAction::Auth { provider })
+        }
+        "model" | "models" => {
+            if rest.len() == 2 && is_help_flag(&rest[1]) {
+                return Ok(CliAction::Help { output_format });
+            }
+            let name = rest.get(1).cloned();
+            if rest.len() > 2 {
+                return Err(format!(
+                    "unexpected extra arguments after `claw model {}`: {}",
+                    name.as_deref().unwrap_or(""),
+                    rest[2..].join(" ")
+                ));
+            }
+            Ok(CliAction::Model { name })
+        }
         "init" => Ok(CliAction::Init { output_format }),
         "export" => parse_export_args(&rest[1..], output_format),
         "prompt" => {
@@ -1108,11 +1176,11 @@ fn parse_single_word_command_alias(
         "sandbox" => Some(Ok(CliAction::Sandbox { output_format })),
         "doctor" => Some(Ok(CliAction::Doctor { output_format })),
         "state" => Some(Ok(CliAction::State { output_format })),
-        // #146: let `config` and `diff` fall through to parse_subcommand
+        // #146: let `config`, `diff`, and `model` fall through to parse_subcommand
         // where they are wired as pure-local introspection, instead of
         // producing the "is a slash command" guidance. Zero-arg cases
         // reach parse_subcommand too via this None.
-        "config" | "diff" => None,
+        "config" | "diff" | "model" | "models" => None,
         other => bare_slash_command_guidance(other).map(Err),
     }
 }
@@ -1339,6 +1407,7 @@ fn suggest_similar_subcommand(input: &str) -> Option<Vec<String>> {
         "init",
         "export",
         "prompt",
+        "auth",
     ];
 
     let normalized_input = input.to_ascii_lowercase();
@@ -1443,9 +1512,482 @@ fn resolve_model_alias_with_config(model: &str) -> String {
     if let Some(resolved) = config_alias_for_current_dir(trimmed) {
         return resolve_model_alias(&resolved).to_string();
     }
+    if let Some(resolved) = config_provider_default_model_for_current_dir(trimmed) {
+        return resolved;
+    }
     resolve_model_alias(trimmed).to_string()
 }
 
+fn config_provider_default_model_for_current_dir(provider_name: &str) -> Option<String> {
+    if provider_name.is_empty() || provider_name.contains('/') {
+        return None;
+    }
+    let cwd = env::current_dir().ok()?;
+    let loader = ConfigLoader::default_for(&cwd);
+    let config = loader.load().ok()?;
+    let provider = config.model_providers().get(provider_name)?;
+    let model = provider
+        .default_model()
+        .or_else(|| provider.models().first().map(String::as_str))?;
+    Some(format!("{provider_name}/{model}"))
+}
+
+fn configured_provider_names_for_current_dir() -> Vec<String> {
+    let Ok(cwd) = env::current_dir() else {
+        return Vec::new();
+    };
+    let loader = ConfigLoader::default_for(&cwd);
+    loader
+        .load()
+        .map(|config| config.model_providers().keys().cloned().collect())
+        .unwrap_or_default()
+}
+
+fn configured_provider_for_model(
+    model: &str,
+) -> Result<Option<ConfiguredModelProvider>, Box<dyn std::error::Error>> {
+    let Some((provider_name, requested_model)) = model.split_once('/') else {
+        return Ok(None);
+    };
+    let cwd = env::current_dir()?;
+    let config = ConfigLoader::default_for(&cwd).load()?;
+    let Some(provider) = config.model_providers().get(provider_name) else {
+        return Ok(None);
+    };
+    if !matches!(
+        provider.provider_type(),
+        "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic"
+    ) {
+        return Err(format!(
+            "model provider '{provider_name}' uses unsupported type '{}'",
+            provider.provider_type()
+        )
+        .into());
+    }
+    let wire_model = if requested_model.is_empty() {
+        provider.default_model().ok_or_else(|| {
+            format!("model provider '{provider_name}' does not define defaultModel")
+        })?
+    } else {
+        requested_model
+    };
+    if !provider.models().is_empty() && !provider.models().iter().any(|model| model == wire_model) {
+        return Err(format!(
+            "model '{wire_model}' is not listed in modelProviders.{provider_name}.models"
+        )
+        .into());
+    }
+    let (api_key, base_url, oauth_token_set, oauth_token_url, oauth_client_id) =
+        if let Some(api_key) = provider.api_key().filter(|value| !value.is_empty()) {
+            (
+                api_key.to_string(),
+                provider.base_url().to_string(),
+                None,
+                None,
+                None,
+            )
+        } else if let Some(env_name) = provider.api_key_env() {
+            if let Ok(key) = env::var(env_name) {
+                (
+                    key,
+                    provider.base_url().to_string(),
+                    None,
+                    None,
+                    None,
+                )
+            } else if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_name) {
+                // Fall back to saved OAuth bearer token when env var is unset.
+                // For OpenAI, OAuth tokens are WHAM-backend tokens, not Platform API tokens.
+                // Override the base URL to the WHAM backend.
+                let base_url = if provider_name == "openai" {
+                    api::DEFAULT_WHAM_BASE_URL.to_string()
+                } else {
+                    provider.base_url().to_string()
+                };
+                // Look up OAuth refresh config from login templates so custom
+                // providers that match built-in templates get automatic refresh.
+                let template_oauth = LOGIN_PROVIDER_TEMPLATES
+                    .iter()
+                    .find(|t| t.id == provider_name)
+                    .and_then(|t| t.oauth.as_ref())
+                    .or_else(|| {
+                        LOGIN_PROVIDER_TEMPLATES
+                            .iter()
+                            .find(|t| t.base_url == provider.base_url())
+                            .and_then(|t| t.oauth.as_ref())
+                    });
+                let (token_url, client_id) = match template_oauth {
+                    Some(oauth) => {
+                        let token_url = match oauth.flow {
+                            OAuthFlowType::Device { token_url, .. } => token_url,
+                            OAuthFlowType::Pkce { token_url, .. } => token_url,
+                        };
+                        (Some(token_url.to_string()), Some(oauth.client_id.to_string()))
+                    }
+                    None => (None, None),
+                };
+                (
+                    token_set.access_token.clone(),
+                    base_url,
+                    Some(token_set),
+                    token_url,
+                    client_id,
+                )
+            } else {
+                return Err(format!(
+                    "model provider '{provider_name}' requires env var {env_name}"
+                )
+                .into());
+            }
+        } else {
+            return Err(
+                format!("model provider '{provider_name}' requires apiKeyEnv or apiKey").into(),
+            );
+        };
+    Ok(Some(ConfiguredModelProvider {
+        wire_model: wire_model.to_string(),
+        provider_type: provider.provider_type().to_string(),
+        api_key,
+        base_url,
+        oauth_token_set,
+        oauth_token_url,
+        oauth_client_id,
+    }))
+}
+
+struct LoginProviderTemplate {
+    id: &'static str,
+    label: &'static str,
+    provider_type: &'static str,
+    base_url: &'static str,
+    api_key_env: &'static str,
+    models: &'static [&'static str],
+    default_model: &'static str,
+    oauth: Option<ProviderOAuthConfig>,
+}
+
+const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[
+    LoginProviderTemplate {
+        id: "zai",
+        label: "Z.AI",
+        provider_type: "openai-compatible",
+        base_url: "https://api.z.ai/api/paas/v4",
+        api_key_env: "Z_AI_API_KEY",
+        models: &[
+            "glm-5.1",
+            "glm-5",
+            "glm-5-turbo",
+            "glm-4.7",
+            "glm-4.7-flashx",
+            "glm-4.7-flash",
+            "glm-4.6",
+            "glm-4.5",
+            "glm-4.5-x",
+            "glm-4.5-air",
+            "glm-4.5-airx",
+            "glm-4.5-flash",
+            "glm-4-32b-0414-128k",
+        ],
+        default_model: "glm-5.1",
+        oauth: None,
+    },
+    LoginProviderTemplate {
+        id: "zai-coding-plan",
+        label: "Z.AI Coding Plan",
+        provider_type: "openai-compatible",
+        base_url: "https://api.z.ai/api/coding/paas/v4",
+        api_key_env: "Z_AI_API_KEY",
+        models: &[
+            "glm-4.5-air",
+            "glm-4.7",
+            "glm-5-turbo",
+            "glm-5.1",
+            "glm-5v-turbo",
+        ],
+        default_model: "glm-5.1",
+        oauth: None,
+    },
+    LoginProviderTemplate {
+        id: "minimax-coding-plan",
+        label: "MiniMax Coding Plan",
+        provider_type: "anthropic-compatible",
+        base_url: "https://api.minimax.io/anthropic/v1",
+        api_key_env: "MINIMAX_API_KEY",
+        models: &[
+            "MiniMax-M2",
+            "MiniMax-M2.1",
+            "MiniMax-M2.5",
+            "MiniMax-M2.5-highspeed",
+            "MiniMax-M2.7",
+            "MiniMax-M2.7-highspeed",
+        ],
+        default_model: "MiniMax-M2.7-highspeed",
+        oauth: None,
+    },
+    LoginProviderTemplate {
+        id: "kimi-for-coding",
+        label: "Kimi For Coding",
+        provider_type: "anthropic-compatible",
+        base_url: "https://api.kimi.com/coding/v1",
+        api_key_env: "KIMI_API_KEY",
+        models: &["k2p5", "k2p6", "kimi-k2-thinking"],
+        default_model: "k2p6",
+        oauth: None,
+    },
+    LoginProviderTemplate {
+        id: "moonshot",
+        label: "Moonshot / Kimi",
+        provider_type: "openai-compatible",
+        base_url: "https://api.moonshot.ai/v1",
+        api_key_env: "MOONSHOT_API_KEY",
+        models: &[
+            "kimi-k2.6",
+            "kimi-k2.5",
+            "kimi-k2-0905-preview",
+            "kimi-k2-0711-preview",
+            "kimi-k2-turbo-preview",
+            "kimi-k2-thinking",
+            "kimi-k2-thinking-turbo",
+            "moonshot-v1-8k",
+            "moonshot-v1-32k",
+            "moonshot-v1-128k",
+            "moonshot-v1-8k-vision-preview",
+            "moonshot-v1-32k-vision-preview",
+            "moonshot-v1-128k-vision-preview",
+        ],
+        default_model: "kimi-k2.6",
+        oauth: Some(ProviderOAuthConfig {
+            client_id: "17e5f671-d194-4dfb-9706-5516cb48c098",
+            callback_port: 4546,
+            redirect_path: "/callback",
+            flow: OAuthFlowType::Device {
+                device_auth_url: "https://auth.kimi.com/api/oauth/device_authorization",
+                token_url: "https://auth.kimi.com/api/oauth/token",
+                scopes: &["openid", "profile", "email"],
+            },
+        }),
+    },
+];
+
+fn run_login_wizard() -> Result<Option<String>, Box<dyn std::error::Error>> {
+    if !io::stdin().is_terminal() {
+        return Err("login requires an interactive terminal".into());
+    }
+
+    println!();
+    println!("Claw provider login");
+    println!("Configure a model provider profile.");
+    println!("Press Enter to accept defaults.");
+    println!();
+
+    for (index, provider) in LOGIN_PROVIDER_TEMPLATES.iter().enumerate() {
+        println!("  [{}] {}", index + 1, provider.label);
+    }
+    println!(
+        "  [{}] Custom compatible endpoint",
+        LOGIN_PROVIDER_TEMPLATES.len() + 1
+    );
+
+    let choice = read_prompt("Select provider [1]: ")?;
+    let choice = if choice.trim().is_empty() {
+        1
+    } else {
+        choice.trim().parse::<usize>()?
+    };
+
+    let (
+        provider_id,
+        label,
+        provider_type,
+        default_base_url,
+        default_api_key_env,
+        default_models,
+        default_model,
+    ) = if choice == LOGIN_PROVIDER_TEMPLATES.len() + 1 {
+        let id = read_required_prompt("Provider id (e.g. openrouter): ")?;
+        let provider_type = read_prompt(
+            "Provider type [openai-compatible, anthropic-compatible] [openai-compatible]: ",
+        )?;
+        let provider_type = if provider_type.trim().is_empty() {
+            "openai-compatible".to_string()
+        } else {
+            provider_type.trim().to_string()
+        };
+        if !matches!(
+            provider_type.as_str(),
+            "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic"
+        ) {
+            return Err(format!("unsupported provider type: {provider_type}").into());
+        }
+        let base_url = read_required_prompt("Base URL: ")?;
+        let api_key_env = read_prompt("API key env var [OPENAI_API_KEY]: ")?;
+        let model = read_required_prompt("Default model: ")?;
+        (
+            id,
+            "Custom".to_string(),
+            provider_type,
+            base_url,
+            if api_key_env.trim().is_empty() {
+                "OPENAI_API_KEY".to_string()
+            } else {
+                api_key_env.trim().to_string()
+            },
+            vec![model.clone()],
+            model,
+        )
+    } else {
+        let template = LOGIN_PROVIDER_TEMPLATES
+            .get(choice.saturating_sub(1))
+            .ok_or_else(|| format!("invalid provider choice: {choice}"))?;
+        (
+            template.id.to_string(),
+            template.label.to_string(),
+            template.provider_type.to_string(),
+            template.base_url.to_string(),
+            template.api_key_env.to_string(),
+            template
+                .models
+                .iter()
+                .map(|model| (*model).to_string())
+                .collect::<Vec<String>>(),
+            template.default_model.to_string(),
+        )
+    };
+
+    println!();
+    println!("{label}");
+    println!("Provider type: {provider_type}");
+    let base_url = read_prompt(&format!("Base URL [{default_base_url}]: "))?;
+    let base_url = if base_url.trim().is_empty() {
+        default_base_url
+    } else {
+        base_url.trim().to_string()
+    };
+    let api_key_env = read_prompt(&format!("API key env var [{default_api_key_env}]: "))?;
+    let api_key_env = if api_key_env.trim().is_empty() {
+        default_api_key_env
+    } else {
+        api_key_env.trim().to_string()
+    };
+
+    let token = read_prompt("Paste API key / bearer token now, or press Enter to use env var: ")?;
+    let api_key = (!token.trim().is_empty()).then(|| token.trim().to_string());
+
+    println!("Available models: {}", default_models.join(", "));
+    let model = read_prompt(&format!("Default model [{default_model}]: "))?;
+    let model = if model.trim().is_empty() {
+        default_model
+    } else {
+        model.trim().to_string()
+    };
+    let mut models = default_models;
+    if !models.iter().any(|known| known == &model) {
+        models.push(model.clone());
+    }
+
+    save_model_provider_profile(
+        &provider_id,
+        &provider_type,
+        &base_url,
+        &api_key_env,
+        api_key.as_deref(),
+        &models,
+        &model,
+    )?;
+    Ok(Some(format!("{provider_id}/{model}")))
+}
+
+fn read_prompt(prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
+    print!("{prompt}");
+    io::stdout().flush()?;
+    let mut buffer = String::new();
+    io::stdin().read_line(&mut buffer)?;
+    Ok(buffer)
+}
+
+fn read_required_prompt(prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
+    let value = read_prompt(prompt)?;
+    let value = value.trim();
+    if value.is_empty() {
+        return Err(format!("{prompt} is required").into());
+    }
+    Ok(value.to_string())
+}
+
+fn save_model_provider_profile(
+    provider_id: &str,
+    provider_type: &str,
+    base_url: &str,
+    api_key_env: &str,
+    api_key: Option<&str>,
+    models: &[String],
+    default_model: &str,
+) -> Result<(), Box<dyn std::error::Error>> {
+    let cwd = env::current_dir()?;
+    let config_home = ConfigLoader::default_for(&cwd).config_home().to_path_buf();
+    fs::create_dir_all(&config_home)?;
+    let settings_path = config_home.join("settings.json");
+    let mut root = match fs::read_to_string(&settings_path) {
+        Ok(contents) if !contents.trim().is_empty() => serde_json::from_str::<Value>(&contents)?,
+        Ok(_) => Value::Object(Map::new()),
+        Err(error) if error.kind() == io::ErrorKind::NotFound => Value::Object(Map::new()),
+        Err(error) => return Err(error.into()),
+    };
+    if !root.is_object() {
+        root = Value::Object(Map::new());
+    }
+    let root_object = root.as_object_mut().expect("root object initialized");
+    let providers = root_object
+        .entry("modelProviders")
+        .or_insert_with(|| Value::Object(Map::new()));
+    if !providers.is_object() {
+        *providers = Value::Object(Map::new());
+    }
+    let provider_map = providers
+        .as_object_mut()
+        .expect("modelProviders object initialized");
+
+    let mut provider = Map::new();
+    provider.insert("type".to_string(), Value::String(provider_type.to_string()));
+    provider.insert("baseUrl".to_string(), Value::String(base_url.to_string()));
+    provider.insert(
+        "apiKeyEnv".to_string(),
+        Value::String(api_key_env.to_string()),
+    );
+    if let Some(api_key) = api_key {
+        provider.insert("apiKey".to_string(), Value::String(api_key.to_string()));
+    }
+    provider.insert(
+        "models".to_string(),
+        Value::Array(
+            models
+                .iter()
+                .map(|model| Value::String(model.clone()))
+                .collect(),
+        ),
+    );
+    provider.insert(
+        "defaultModel".to_string(),
+        Value::String(default_model.to_string()),
+    );
+    provider_map.insert(provider_id.to_string(), Value::Object(provider));
+    root_object.insert(
+        "model".to_string(),
+        Value::String(format!("{provider_id}/{default_model}")),
+    );
+
+    let serialized = format!("{}\n", serde_json::to_string_pretty(&root)?);
+    fs::write(&settings_path, serialized)?;
+    #[cfg(unix)]
+    {
+        use std::os::unix::fs::PermissionsExt;
+        let mut permissions = fs::metadata(&settings_path)?.permissions();
+        permissions.set_mode(0o600);
+        fs::set_permissions(&settings_path, permissions)?;
+    }
+    Ok(())
+}
+
 /// Validate model syntax at parse time.
 /// Accepts: known aliases (opus, sonnet, haiku) or provider/model pattern.
 /// Rejects: empty, whitespace-only, strings with spaces, or invalid chars.
@@ -1459,6 +2001,12 @@ fn validate_model_syntax(model: &str) -> Result<(), String> { "opus" | "sonnet" | "haiku" => return Ok(()), _ => {} } + if configured_provider_names_for_current_dir() + .iter() + .any(|name| name == trimmed) + { + return Ok(()); + } // Check for spaces (malformed) if trimmed.contains(' ') { return Err(format!( @@ -2117,7 +2665,7 @@ fn check_auth_health() -> DiagnosticCheck { token_set.scopes.join(",") } ), - "Suggested action set ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN; `claw login` is removed" + "Suggested action set ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN for Anthropic, or run `claw login` to configure a compatible provider" .to_string(), ]) .with_data(Map::from_iter([ @@ -3772,6 +4320,21 @@ fn run_repl( enforce_broad_cwd_policy(allow_broad_cwd, CliOutputFormat::Text)?; run_stale_base_preflight(base_commit.as_deref()); let resolved_model = resolve_repl_model(model); + if !check_model_auth_available(&resolved_model)? { + match run_provider_welcome(&resolved_model)? { + Some(new_model) => { + return run_repl( + new_model, + allowed_tools, + permission_mode, + base_commit, + reasoning_effort, + allow_broad_cwd, + ); + } + None => return Ok(()), + } + } let mut cli = LiveCli::new(resolved_model, true, allowed_tools, permission_mode)?; cli.set_reasoning_effort(reasoning_effort); let mut editor = @@ -4713,8 +5276,13 @@ impl LiveCli { println!("{}", format_cost_report(usage)); false } - SlashCommand::Login - | SlashCommand::Logout + SlashCommand::Login => { + if let Some(model) = run_login_wizard()? 
{ + self.set_model(Some(model))?; + } + false + } + SlashCommand::Logout | SlashCommand::Vim | SlashCommand::Upgrade | SlashCommand::Share @@ -7527,6 +8095,10 @@ fn build_runtime_with_plugin_state( plugin_registry, mcp_state, } = runtime_plugin_state; + let configured_provider = configured_provider_for_model(&model)?; + let request_model = configured_provider + .as_ref() + .map_or_else(|| model.clone(), |provider| provider.wire_model.clone()); plugin_registry.initialize()?; let policy = permission_policy(permission_mode, &feature_config, &tool_registry) .map_err(std::io::Error::other)?; @@ -7534,7 +8106,8 @@ fn build_runtime_with_plugin_state( session, AnthropicRuntimeClient::new( session_id, - model, + request_model, + configured_provider, enable_tools, emit_output, allowed_tools.clone(), @@ -7662,6 +8235,7 @@ impl AnthropicRuntimeClient { fn new( session_id: &str, model: String, + configured_provider: Option, enable_tools: bool, emit_output: bool, allowed_tools: Option, @@ -7688,26 +8262,93 @@ impl AnthropicRuntimeClient { // prompt cache is Anthropic-only so non-Anthropic variants // skip it. let resolved_model = api::resolve_model_alias(&model); - let client = match detect_provider_kind(&resolved_model) { - ProviderKind::Anthropic => { - let auth = resolve_cli_auth_source()?; - let inner = AnthropicClient::from_auth(auth) - .with_base_url(api::read_base_url()) - .with_prompt_cache(PromptCache::new(session_id)); - ApiProviderClient::Anthropic(inner) - } - ProviderKind::Xai | ProviderKind::OpenAi => { - // The api crate's `ProviderClient::from_model_with_anthropic_auth` - // with `None` for the anthropic auth routes via - // `detect_provider_kind` and builds an - // `OpenAiCompatClient::from_env` with the matching - // `OpenAiCompatConfig` (openai / xai / dashscope). 
- // That reads the correct API-key env var and BASE_URL - // override internally, so this one call covers OpenAI, - // OpenRouter, xAI, DashScope, Ollama, and any other - // OpenAI-compat endpoint users configure via - // `OPENAI_BASE_URL` / `XAI_BASE_URL` / `DASHSCOPE_BASE_URL`. - ApiProviderClient::from_model_with_anthropic_auth(&resolved_model, None)? + let client = if let Some(provider) = configured_provider { + match provider.provider_type.as_str() { + "anthropic-compatible" | "anthropic" => { + ApiProviderClient::from_anthropic_compatible_profile( + provider.api_key, + provider.base_url, + ) + .with_prompt_cache(PromptCache::new(session_id)) + } + "openai-compatible" | "openai" => { + // Kimi For Coding requires a whitelisted User-Agent. + let user_agent = if provider.base_url.contains("api.kimi.com") { + Some("claude-code/0.1.0") + } else { + None + }; + let apply_ua = |client: ApiProviderClient| { + match user_agent { + Some(ua) => client.with_user_agent(ua), + None => client, + } + }; + // Route to WhamClient when using ChatGPT OAuth (WHAM backend). + if provider.base_url.contains("backend-api/wham") { + if let Ok(Some(token_set)) = runtime::load_provider_oauth("openai") { + let account_id = token_set + .id_token + .as_deref() + .and_then(runtime::extract_chatgpt_account_id) + .or_else(|| { + runtime::extract_chatgpt_account_id(&token_set.access_token) + }); + ApiProviderClient::Wham(api::WhamClient::from_oauth_token_set( + token_set, + account_id, + "https://auth.openai.com/oauth/token", + "app_EMoamEEZ73f0CkXaXp7hrann", + )) + } else { + apply_ua(ApiProviderClient::from_openai_compatible_profile( + provider.api_key, + provider.base_url, + )) + } + } else if let (Some(token_set), Some(token_url), Some(client_id)) = + (provider.oauth_token_set, provider.oauth_token_url, provider.oauth_client_id) + { + // Custom provider with OAuth: use auto-refreshing client. 
+ apply_ua(ApiProviderClient::from_openai_compatible_oauth( + provider.base_url, + token_set, + token_url, + client_id, + )) + } else { + apply_ua(ApiProviderClient::from_openai_compatible_profile( + provider.api_key, + provider.base_url, + )) + } + } + other => { + return Err(format!("unsupported provider type: {other}").into()); + } + } + } else { + match detect_provider_kind(&resolved_model) { + ProviderKind::Anthropic => { + let auth = resolve_cli_auth_source()?; + let inner = AnthropicClient::from_auth(auth) + .with_base_url(api::read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)); + ApiProviderClient::Anthropic(inner) + } + ProviderKind::Xai | ProviderKind::OpenAi => { + // The api crate's `ProviderClient::from_model_with_anthropic_auth` + // with `None` for the anthropic auth routes via + // `detect_provider_kind` and builds an + // `OpenAiCompatClient::from_env` with the matching + // `OpenAiCompatConfig` (openai / xai / dashscope). + // That reads the correct API-key env var and BASE_URL + // override internally, so this one call covers OpenAI, + // OpenRouter, xAI, DashScope, Ollama, and any other + // OpenAI-compat endpoint users configure via + // `OPENAI_BASE_URL` / `XAI_BASE_URL` / `DASHSCOPE_BASE_URL`. + ApiProviderClient::from_model_with_anthropic_auth(&resolved_model, None)? 
+ } } }; Ok(Self { @@ -7737,61 +8378,833 @@ fn resolve_cli_auth_source_for_cwd() -> Result { resolve_startup_auth_source(|| Ok(None)) } -impl ApiClient for AnthropicRuntimeClient { - #[allow(clippy::too_many_lines)] - fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { - if let Some(progress_reporter) = &self.progress_reporter { - progress_reporter.mark_model_phase(); - } - let is_post_tool = request_ends_with_tool_result(&request); - let message_request = MessageRequest { - model: self.model.clone(), - max_tokens: max_tokens_for_model(&self.model), - messages: convert_messages(&request.messages), - system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), - tools: self - .enable_tools - .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), - tool_choice: self.enable_tools.then_some(ToolChoice::Auto), - stream: true, - reasoning_effort: self.reasoning_effort.clone(), - ..Default::default() - }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OAuthFlowType { + Pkce { + authorize_url: &'static str, + token_url: &'static str, + scopes: &'static [&'static str], + }, + Device { + device_auth_url: &'static str, + token_url: &'static str, + scopes: &'static [&'static str], + }, +} - self.runtime.block_on(async { - // When resuming after tool execution, apply a stall timeout on the - // first stream event. If the model does not respond within the - // deadline we drop the stalled connection and re-send the request as - // a continuation nudge (one retry only). 
- let max_attempts: usize = if is_post_tool { 2 } else { 1 }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ProviderOAuthConfig { + client_id: &'static str, + callback_port: u16, + redirect_path: &'static str, + flow: OAuthFlowType, +} - for attempt in 1..=max_attempts { - let result = self - .consume_stream(&message_request, is_post_tool && attempt == 1) - .await; - match result { - Ok(events) => return Ok(events), - Err(error) - if error.to_string().contains("post-tool stall") - && attempt < max_attempts => - { - // Stalled after tool completion — nudge the model by - // re-sending the same request. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct BuiltinProvider { + id: &'static str, + label: &'static str, + env_var: &'static str, + default_model: &'static str, + oauth: Option, +} + +const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ + BuiltinProvider { + id: "anthropic", + label: "Anthropic (Claude)", + env_var: "ANTHROPIC_API_KEY", + default_model: "claude-opus-4-6", + oauth: None, + }, + BuiltinProvider { + id: "openai", + label: "OpenAI", + env_var: "OPENAI_API_KEY", + default_model: "gpt-4o", + oauth: Some(ProviderOAuthConfig { + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + callback_port: 1455, + redirect_path: "/auth/callback", + flow: OAuthFlowType::Pkce { + authorize_url: "https://auth.openai.com/oauth/authorize", + token_url: "https://auth.openai.com/oauth/token", + scopes: &["openid", "profile", "email", "offline_access"], + }, + }), + }, + BuiltinProvider { + id: "xai", + label: "xAI (Grok)", + env_var: "XAI_API_KEY", + default_model: "grok-3", + oauth: None, + }, +]; + +fn check_model_auth_available(model: &str) -> Result> { + let resolved = api::resolve_model_alias(model); + + // If model has a provider/ prefix, check auth for that specific provider. + // This works for ANY provider (built-in, template, or custom from .claw.json). 
+ if let Some((provider_name, _)) = resolved.split_once('/') { + // Generic OAuth check: any provider may have a saved OAuth token. + if runtime::load_provider_oauth(provider_name).ok().flatten().is_some() { + return Ok(true); + } + // Check the provider's config from .claw.json for env var or hardcoded key + if let Ok(cwd) = env::current_dir() { + if let Ok(config) = ConfigLoader::default_for(&cwd).load() { + if let Some(provider) = config.model_providers().get(provider_name) { + if provider.api_key().filter(|k| !k.is_empty()).is_some() { + return Ok(true); } - Err(error) => return Err(error), + if let Some(env_name) = provider.api_key_env() { + if has_api_key(env_name) { + return Ok(true); + } + } + // Configured but no auth available + return Ok(false); } } + } + // Not configured in .claw.json - check built-in provider env vars + return Ok(match provider_name { + "openai" => has_api_key("OPENAI_API_KEY"), + "xai" => has_api_key("XAI_API_KEY"), + "anthropic" => anthropic_has_auth().unwrap_or(false), + _ => false, + }); + } - Err(RuntimeError::new("post-tool continuation nudge exhausted")) - }) + // No prefix - use metadata_for_model for built-in model detection + if let Some(meta) = api::metadata_for_model(&resolved) { + let has_env = has_api_key(meta.auth_env); + // OAuth fallback for bare model names. 
+ let has_oauth = match meta.auth_env { + "OPENAI_API_KEY" => { + runtime::load_provider_oauth("openai").ok().flatten().is_some() + } + "MOONSHOT_API_KEY" => { + runtime::load_provider_oauth("moonshot").ok().flatten().is_some() + } + _ => false, + }; + return Ok(has_env || has_oauth); } + + // Bare model name without recognized prefix - fall back to env sniffing + let provider = detect_provider_kind(&resolved); + let available = match provider { + ProviderKind::Anthropic => anthropic_has_auth().unwrap_or(false), + ProviderKind::Xai => has_api_key("XAI_API_KEY"), + ProviderKind::OpenAi => { + has_api_key("OPENAI_API_KEY") + || has_api_key("DASHSCOPE_API_KEY") + || has_api_key("MOONSHOT_API_KEY") + || runtime::load_provider_oauth("openai").ok().flatten().is_some() + || runtime::load_provider_oauth("moonshot").ok().flatten().is_some() + } + }; + Ok(available) } -impl AnthropicRuntimeClient { - /// Consume a single streaming response, optionally applying a stall - /// timeout on the first event for post-tool continuations. - #[allow(clippy::too_many_lines)] - async fn consume_stream( +fn run_provider_welcome( + default_model: &str, +) -> Result, Box> { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err( + "Authentication required. 
Set one of these environment variables:\n\ + \n\ + ANTHROPIC_API_KEY= # For Anthropic (Claude)\n\ + OPENAI_API_KEY= # For OpenAI\n\ + XAI_API_KEY= # For xAI (Grok)\n\ + DASHSCOPE_API_KEY= # For DashScope/Qwen" + .into(), + ); + } + + println!( + "\n\x1b[1mWelcome to Claw Code\x1b[0m\n\n\ + No API key detected for the selected model.\n\ + Choose a provider to authenticate with:\n" + ); + + let builtin_count = BUILTIN_PROVIDERS.len(); + let template_count = LOGIN_PROVIDER_TEMPLATES.len(); + let total = builtin_count + template_count; + + println!(" Built-in:"); + for (i, provider) in BUILTIN_PROVIDERS.iter().enumerate() { + let oauth_tag = if provider.oauth.is_some() { " [OAuth]" } else { "" }; + println!(" {}. {}{}", i + 1, provider.label, oauth_tag); + } + + println!("\n Additional providers:"); + for (i, template) in LOGIN_PROVIDER_TEMPLATES.iter().enumerate() { + let oauth_tag = if template.oauth.is_some() { " [OAuth]" } else { "" }; + println!( + " {}. {}{}", + builtin_count + i + 1, + template.label, + oauth_tag + ); + } + + let total = builtin_count + template_count; + + print!("\nEnter number (1-{total}): "); + std::io::stdout().flush()?; + let mut choice = String::new(); + std::io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + let index: usize = choice.parse().map_err(|_| "invalid selection")?; + if index == 0 || index > total { + return Err("invalid selection".into()); + } + + // Built-in provider selected + if index <= builtin_count { + let provider = BUILTIN_PROVIDERS.get(index - 1).expect("valid builtin index"); + + // Offer OAuth if available + if let Some(ref oauth) = provider.oauth { + if prompt_oauth_or_api_key(provider.label, true)? 
{ + run_pkce_oauth_flow(provider.id, oauth)?; + return Ok(Some(format!("{}/{}", provider.id, provider.default_model))); + } + } + + print!("Enter {} (or press Enter to cancel): ", provider.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + + std::env::set_var(provider.env_var, key); + + // Optionally save to ~/.claw/settings.json as a simple model config. + if let Some(home) = std::env::var_os("HOME") { + let claw_dir = std::path::PathBuf::from(home).join(".claw"); + if claw_dir.exists() || std::fs::create_dir_all(&claw_dir).is_ok() { + let settings_path = claw_dir.join("settings.json"); + let model_str = format!("{}/{}", provider.id, provider.default_model); + let content = format!("{{\"model\": \"{model_str}\"}}"); + let _ = std::fs::write(&settings_path, content); + } + } + + return Ok(Some(format!("{}/{}", provider.id, provider.default_model))); + } + + // Template provider selected + let template = LOGIN_PROVIDER_TEMPLATES + .get(index - builtin_count - 1) + .expect("valid template index"); + + // Offer OAuth if available + if let Some(ref oauth) = template.oauth { + if prompt_oauth_or_api_key(template.label, true)? { + match oauth.flow { + OAuthFlowType::Pkce { .. } => { + run_pkce_oauth_flow(template.id, oauth)?; + } + OAuthFlowType::Device { .. 
} => { + run_device_oauth_flow(template.id, oauth)?; + } + } + return Ok(Some(format!("{}/{}", template.id, template.default_model))); + } + } + + print!( + "Enter {} (or press Enter to cancel): ", + template.api_key_env + ); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + + std::env::set_var(template.api_key_env, key); + save_model_provider_profile( + template.id, + template.provider_type, + template.base_url, + template.api_key_env, + Some(key), + &template.models.iter().map(|m| (*m).to_string()).collect::>(), + template.default_model, + )?; + + // No custom providers in welcome — only built-in and templates. + Err("Invalid selection".into()) +} + +fn run_model_command(name: Option<&str>) -> Result<(), Box> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let config = loader.load()?; + + if let Some(name) = name { + let trimmed = name.trim(); + if trimmed.is_empty() { + return Err("model name cannot be empty".into()); + } + validate_model_syntax(trimmed)?; + let resolved = resolve_model_alias_with_config(trimmed); + + // Write to user settings.json + let settings_path = loader.config_home().join("settings.json"); + let mut settings: serde_json::Value = if settings_path.exists() { + let contents = fs::read_to_string(&settings_path)?; + serde_json::from_str(&contents).unwrap_or_else(|_| serde_json::json!({})) + } else { + serde_json::json!({}) + }; + if let Some(obj) = settings.as_object_mut() { + obj.insert("model".to_string(), serde_json::Value::String(resolved.clone())); + } + fs::create_dir_all(loader.config_home())?; + fs::write(&settings_path, serde_json::to_string_pretty(&settings)?)?; + + println!("Default model set to: {}", resolved); + return Ok(()); + } + + // No name provided: show current model and list available models + let env_model = 
env::var("ANTHROPIC_MODEL").ok(); + let current_model = config.model() + .or_else(|| env_model.as_deref()) + .unwrap_or(DEFAULT_MODEL); + + println!("Current model: {}", current_model); + println!(); + println!("Available models:"); + + // Built-in providers + println!(" Built-in:"); + for builtin in BUILTIN_PROVIDERS { + let label = match builtin.id { + "anthropic" => "Anthropic (Claude)", + "openai" => "OpenAI", + "xai" => "xAI (Grok)", + other => other, + }; + println!(" {}:", label); + println!(" {}/{}", builtin.id, builtin.default_model); + } + + // Login templates + println!(" Additional providers:"); + for template in LOGIN_PROVIDER_TEMPLATES { + println!(" {}:", template.label); + for model in template.models { + println!(" {}/{}", template.id, model); + } + } + + // Custom providers from config + let custom_providers = config.model_providers(); + if !custom_providers.is_empty() { + println!(" Custom providers:"); + for (name, provider) in custom_providers { + println!(" {}:", name); + if !provider.models().is_empty() { + for model in provider.models() { + println!(" {}/{}", name, model); + } + } else if let Some(default) = provider.default_model() { + println!(" {}/{}", name, default); + } + } + } + + println!(); + println!("Usage:"); + println!(" claw model Set default model"); + println!(" claw --model prompt Use model for one prompt"); + + Ok(()) +} + +fn run_auth_command(provider: Option<&str>) -> Result<(), Box> { + if let Some(provider_id) = provider { + // Try built-in provider first + if let Some(builtin) = BUILTIN_PROVIDERS.iter().find(|p| p.id == provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + builtin.env_var + ) + .into()); + } + + // Offer OAuth if available + if let Some(ref oauth) = builtin.oauth { + if prompt_oauth_or_api_key(builtin.label, true)? 
{ + run_pkce_oauth_flow(builtin.id, oauth)?; + println!("Authenticated with {} via OAuth.", builtin.label); + return Ok(()); + } + } + + print!("Enter {} (or press Enter to cancel): ", builtin.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(builtin.env_var, key); + println!("Authentication set for {}.", builtin.label); + return Ok(()); + } + + // Try template provider + if let Some(template) = LOGIN_PROVIDER_TEMPLATES.iter().find(|p| p.id == provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + template.api_key_env + ) + .into()); + } + + // Offer OAuth if available + if let Some(ref oauth) = template.oauth { + if prompt_oauth_or_api_key(template.label, true)? { + match oauth.flow { + OAuthFlowType::Pkce { .. } => { + run_pkce_oauth_flow(template.id, oauth)?; + } + OAuthFlowType::Device { .. } => { + run_device_oauth_flow(template.id, oauth)?; + } + } + println!( + "Authenticated with {label} via OAuth. Model: {id}/{model}", + label = template.label, + id = template.id, + model = template.default_model + ); + return Ok(()); + } + } + + print!( + "Enter {} (or press Enter to cancel): ", + template.api_key_env + ); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(template.api_key_env, key); + save_model_provider_profile( + template.id, + template.provider_type, + template.base_url, + template.api_key_env, + Some(key), + &template.models.iter().map(|m| (*m).to_string()).collect::>(), + template.default_model, + )?; + println!( + "Authenticated with {label}. 
Model: {id}/{model}", + label = template.label, + id = template.id, + model = template.default_model + ); + return Ok(()); + } + + // Try custom provider from .claw.json / settings.json + if let Ok(cwd) = env::current_dir() { + if let Ok(config) = ConfigLoader::default_for(&cwd).load() { + if let Some(provider) = config.model_providers().get(provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + provider.api_key_env().unwrap_or("API_KEY") + ) + .into()); + } + if let Some(env_name) = provider.api_key_env() { + print!("Enter {} (or press Enter to cancel): ", env_name); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + std::env::set_var(env_name, key); + println!("Authentication set for {provider_id}."); + return Ok(()); + } + return Err(format!( + "provider '{provider_id}' has no apiKeyEnv configured. Add it to your .claw.json." + ) + .into()); + } + } + } + + return Err(format!( + "unknown provider: '{provider_id}'. Run `claw auth` to see available providers." + ) + .into()); + } + + match run_provider_welcome(DEFAULT_MODEL)? { + Some(model) => { + println!("Authentication configured. 
Model set to {model}."); + Ok(()) + } + None => { + println!("Cancelled."); + Ok(()) + } + } +} + +// --------------------------------------------------------------------------- +// OAuth flow implementations +// --------------------------------------------------------------------------- + +fn run_pkce_oauth_flow( + provider_id: &str, + oauth: &ProviderOAuthConfig, +) -> Result> { + use runtime::{ + generate_pkce_pair, generate_state, loopback_redirect_uri_with_path, open_browser, + run_oauth_callback_server, OAuthAuthorizationRequest, OAuthTokenExchangeRequest, + }; + + let pkce = generate_pkce_pair()?; + let state = generate_state()?; + + let (authorize_url, token_url, scopes) = match oauth.flow { + OAuthFlowType::Pkce { + authorize_url, + token_url, + scopes, + } => (authorize_url, token_url, scopes), + _ => return Err("Expected PKCE flow".into()), + }; + + let config = runtime::OAuthConfig { + client_id: oauth.client_id.to_string(), + authorize_url: authorize_url.to_string(), + token_url: token_url.to_string(), + callback_port: Some(oauth.callback_port), + manual_redirect_url: None, + scopes: scopes.iter().map(|s| (*s).to_string()).collect(), + }; + + let redirect_uri = loopback_redirect_uri_with_path(oauth.callback_port, oauth.redirect_path); + + let mut auth_request = + OAuthAuthorizationRequest::from_config(&config, &redirect_uri, &state, &pkce); + + // OpenAI-specific parameters required for drop-in Codex CLI compatibility + if provider_id == "openai" { + auth_request = auth_request + .with_extra_param("id_token_add_organizations", "true") + .with_extra_param("codex_cli_simplified_flow", "true") + .with_extra_param("originator", "codex_cli_rs"); + } + + let auth_url = auth_request.build_url(); + + println!("Opening browser for OAuth authentication..."); + println!("If the browser doesn't open automatically, visit:"); + println!(" {auth_url}"); + open_browser(&auth_url)?; + + println!("Waiting for authentication..."); + let callback = 
run_oauth_callback_server( + oauth.callback_port, + std::time::Duration::from_secs(300), + oauth.redirect_path, + )?; + + if callback.state != state { + return Err("OAuth state mismatch. Possible CSRF attack.".into()); + } + + println!("Exchanging authorization code for tokens..."); + + let rt = tokio::runtime::Runtime::new()?; + let token_set = rt.block_on(async { + let client = reqwest::Client::new(); + let exchange = OAuthTokenExchangeRequest::from_config( + &config, + &callback.code, + &state, + &pkce.verifier, + &redirect_uri, + ); + + let response = client + .post(&config.token_url) + .form(&exchange.form_params()) + .send() + .await + .map_err(|e| format!("Token exchange request failed: {e}"))?; + + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| format!("Failed to read token response: {e}"))?; + + if !status.is_success() { + return Err(format!("Token exchange failed ({status}): {body}").into()); + } + + let json: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| format!("Invalid token response JSON: {e}"))?; + + let access_token = json["access_token"] + .as_str() + .ok_or("access_token missing from response")? + .to_string(); + let refresh_token = json["refresh_token"].as_str().map(String::from); + let id_token = json["id_token"].as_str().map(String::from); + let expires_at = json["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = json["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + + Ok::<_, Box>(runtime::OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + id_token, + }) + })?; + + runtime::save_provider_oauth(provider_id, &token_set)?; + println!("✓ OAuth authentication successful. 
Tokens saved."); + + Ok(token_set) +} + +fn run_device_oauth_flow( + provider_id: &str, + oauth: &ProviderOAuthConfig, +) -> Result> { + use runtime::{open_browser, poll_device_token, DeviceAuthRequest}; + + let (device_auth_url, token_url, scopes) = match oauth.flow { + OAuthFlowType::Device { + device_auth_url, + token_url, + scopes, + } => (device_auth_url, token_url, scopes), + _ => return Err("Expected Device flow".into()), + }; + + let rt = tokio::runtime::Runtime::new()?; + + // Step 1: Request device code + let device_response = rt.block_on(async { + let client = reqwest::Client::new(); + let scope_str = scopes.join(" "); + let params = [ + ("client_id", oauth.client_id), + ("scope", scope_str.as_str()), + ]; + let response = client + .post(device_auth_url) + .form(¶ms) + .send() + .await + .map_err(|e| format!("Device auth request failed: {e}"))?; + + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| format!("Failed to read device auth response: {e}"))?; + + if !status.is_success() { + return Err(format!("Device auth failed ({status}): {body}").into()); + } + + let resp: runtime::DeviceAuthResponse = serde_json::from_str(&body) + .map_err(|e| format!("Invalid device auth response: {e}"))?; + Ok::<_, Box>(resp) + })?; + + println!("\nDevice authentication started."); + println!("User code: {}", device_response.user_code); + println!( + "Please visit: {}", + device_response + .verification_uri_complete + .as_deref() + .unwrap_or(&device_response.verification_uri) + ); + + if let Some(ref complete_uri) = device_response.verification_uri_complete { + open_browser(complete_uri)?; + } else { + open_browser(&device_response.verification_uri)?; + } + + // Step 2: Poll for token + let start = std::time::Instant::now(); + let expires_in = std::time::Duration::from_secs(device_response.expires_in); + let interval = std::time::Duration::from_secs(device_response.interval); + + let token_set = rt.block_on(async { + let client = 
reqwest::Client::new(); + loop { + if start.elapsed() > expires_in { + return Err::<_, Box>( + "Device authorization expired. Please try again.".into(), + ); + } + + tokio::time::sleep(interval).await; + + match poll_device_token( + &client, + &device_response.device_code, + oauth.client_id, + token_url, + ) + .await + { + Ok(Some(token_set)) => return Ok(token_set), + Ok(None) => { + println!("Waiting for authorization..."); + continue; + } + Err(e) => return Err(e.into()), + } + } + })?; + + runtime::save_provider_oauth(provider_id, &token_set)?; + println!("✓ OAuth authentication successful. Tokens saved."); + + Ok(token_set) +} + +fn prompt_oauth_or_api_key(provider_label: &str, has_oauth: bool) -> Result> { + if !has_oauth { + return Ok(false); + } + + println!("\nChoose authentication method for {provider_label}:"); + println!(" 1. Sign in with {provider_label} account (OAuth) — recommended"); + println!(" 2. Enter API key manually"); + print!("Select [1]: "); + std::io::stdout().flush()?; + + let mut choice = String::new(); + std::io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + Ok(choice.is_empty() || choice == "1") +} + +impl ApiClient for AnthropicRuntimeClient { + #[allow(clippy::too_many_lines)] + fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_model_phase(); + } + let is_post_tool = request_ends_with_tool_result(&request); + let message_request = MessageRequest { + model: self.model.clone(), + max_tokens: max_tokens_for_model(&self.model), + messages: convert_messages(&request.messages), + system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), + tools: self + .enable_tools + .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), + tool_choice: self.enable_tools.then_some(ToolChoice::Auto), + stream: true, + reasoning_effort: self.reasoning_effort.clone(), + 
..Default::default() + }; + + self.runtime.block_on(async { + // When resuming after tool execution, apply a stall timeout on the + // first stream event. If the model does not respond within the + // deadline we drop the stalled connection and re-send the request as + // a continuation nudge (one retry only). + let max_attempts: usize = if is_post_tool { 2 } else { 1 }; + + for attempt in 1..=max_attempts { + let result = self + .consume_stream(&message_request, is_post_tool && attempt == 1) + .await; + match result { + Ok(events) => return Ok(events), + Err(error) + if error.to_string().contains("post-tool stall") + && attempt < max_attempts => + { + // Stalled after tool completion — nudge the model by + // re-sending the same request. + } + Err(error) => return Err(error), + } + } + + Err(RuntimeError::new("post-tool continuation nudge exhausted")) + }) + } +} + +impl AnthropicRuntimeClient { + /// Consume a single streaming response, optionally applying a stall + /// timeout on the first event for post-tool continuations. 
+ #[allow(clippy::too_many_lines)] + async fn consume_stream( &self, message_request: &MessageRequest, apply_stall_timeout: bool, @@ -8134,7 +9547,6 @@ fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec io::Result<()> { out, " Diagnose local auth, config, workspace, and sandbox health" )?; + writeln!(out, " claw login")?; + writeln!( + out, + " Configure a compatible model provider in settings.json" + )?; writeln!(out, " claw acp [serve]")?; writeln!( out, @@ -9121,6 +10538,16 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { writeln!(out, " claw skills")?; writeln!(out, " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]")?; writeln!(out, " claw init")?; + writeln!(out, " claw auth [PROVIDER]")?; + writeln!( + out, + " Authenticate with a model provider (anthropic, openai, xai)" + )?; + writeln!(out, " claw model [MODEL]")?; + writeln!( + out, + " Show available models or set the default model" + )?; writeln!( out, " claw export [PATH] [--session SESSION] [--output PATH]" @@ -9261,6 +10688,7 @@ mod tests { SlashCommand, StatusUsage, TmuxPaneSnapshot, DEFAULT_MODEL, LATEST_SESSION_REFERENCE, STUB_COMMANDS, }; + use crate::configured_provider_for_model; use api::{ApiError, MessageResponse, OutputContentBlock, Usage}; use plugins::{ PluginManager, PluginManagerConfig, PluginTool, PluginToolDefinition, PluginToolPermission, @@ -9559,7 +10987,7 @@ mod tests { #[test] fn defaults_to_repl_when_no_args() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); assert_eq!( parse_args(&[]).expect("args should parse"), CliAction::Repl { @@ -9659,6 +11087,7 @@ mod tests { refresh_token: Some("refresh-token".to_string()), expires_at: Some(0), scopes: vec!["org:create_api_key".to_string(), "user:profile".to_string()], + id_token: None, }) .expect("save expired oauth credentials"); @@ -9685,7 +11114,7 @@ mod tests { #[test] fn 
parses_prompt_subcommand() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "prompt".to_string(), "hello".to_string(), @@ -9707,6 +11136,54 @@ mod tests { ); } + #[test] + fn parse_args_auth_without_provider() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["auth".to_string()]).expect("args should parse"), + CliAction::Auth { provider: None } + ); + } + + #[test] + fn parse_args_auth_with_provider() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["auth".to_string(), "openai".to_string()]).expect("args should parse"), + CliAction::Auth { + provider: Some("openai".to_string()), + } + ); + } + + #[test] + fn parse_args_model_without_name() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["model".to_string()]).expect("args should parse"), + CliAction::Model { name: None } + ); + assert_eq!( + parse_args(&["models".to_string()]).expect("args should parse"), + CliAction::Model { name: None } + ); + } + + #[test] + fn parse_args_model_with_name() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["model".to_string(), "moonshot/kimi-k2.6".to_string()]).expect("args should parse"), + CliAction::Model { + name: Some("moonshot/kimi-k2.6".to_string()), + } + ); + } + #[test] fn merge_prompt_with_stdin_returns_prompt_unchanged_when_no_pipe() { // given @@ -9774,7 +11251,7 @@ mod tests { #[test] fn parses_bare_prompt_and_json_output_flag() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--output-format=json".to_string(), "--model".to_string(), @@ -9802,7 +11279,7 @@ mod 
tests { fn parses_compact_flag_for_prompt_mode() { // given a bare prompt invocation that includes the --compact flag let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--compact".to_string(), "summarize".to_string(), @@ -9849,7 +11326,7 @@ mod tests { #[test] fn resolves_model_aliases_in_args() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--model".to_string(), "opus".to_string(), @@ -9919,6 +11396,80 @@ mod tests { assert_eq!(builtin, "claude-haiku-4-5-20251213"); } + #[test] + fn configured_model_provider_default_resolves_to_provider_model_ref() { + // given + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"modelProviders":{"zai":{"type":"openai-compatible","baseUrl":"https://api.z.ai/api/paas/v4","apiKeyEnv":"Z_AI_API_KEY","models":["glm-5.1","glm-4.6"],"defaultModel":"glm-5.1"}}}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + + // when + let resolved = with_current_dir(&cwd, || resolve_model_alias_with_config("zai")); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + // then + assert_eq!(resolved, "zai/glm-5.1"); + } + + #[test] + fn configured_model_provider_resolves_runtime_connection_details() { 
+ // given + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"modelProviders":{"minimax":{"baseUrl":"https://api.minimax.io/v1","apiKeyEnv":"MINIMAX_API_KEY","models":["MiniMax-M2.7-highspeed"]}}}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_key = std::env::var("MINIMAX_API_KEY").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::set_var("MINIMAX_API_KEY", "test-minimax-key"); + + // when + let provider = with_current_dir(&cwd, || { + configured_provider_for_model("minimax/MiniMax-M2.7-highspeed") + }) + .expect("provider lookup should succeed") + .expect("provider should exist"); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_key { + Some(value) => std::env::set_var("MINIMAX_API_KEY", value), + None => std::env::remove_var("MINIMAX_API_KEY"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + // then + assert_eq!(provider.wire_model, "MiniMax-M2.7-highspeed"); + assert_eq!(provider.api_key, "test-minimax-key"); + assert_eq!(provider.base_url, "https://api.minimax.io/v1"); + } + #[test] fn parses_version_flags_without_initializing_prompt_mode() { assert_eq!( @@ -10005,7 +11556,7 @@ mod tests { #[test] fn parses_allowed_tools_flags_with_aliases_and_lists() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let args = vec![ "--allowedTools".to_string(), "read,glob".to_string(), @@ -10068,9 
+11619,11 @@ mod tests { } #[test] - fn removed_login_and_logout_subcommands_error_helpfully() { - let login = parse_args(&["login".to_string()]).expect_err("login should be removed"); - assert!(login.contains("ANTHROPIC_API_KEY")); + fn login_subcommand_parses_and_logout_errors_helpfully() { + assert_eq!( + parse_args(&["login".to_string()]).expect("login should parse"), + CliAction::Login + ); let logout = parse_args(&["logout".to_string()]).expect_err("logout should be removed"); assert!(logout.contains("ANTHROPIC_AUTH_TOKEN")); assert_eq!( @@ -10647,7 +12200,7 @@ mod tests { #[test] fn parses_single_word_command_aliases_without_falling_back_to_prompt_mode() { let _guard = env_lock(); - std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); assert_eq!( parse_args(&["help".to_string()]).expect("help should parse"), CliAction::Help { @@ -11313,6 +12866,8 @@ mod tests { #[test] fn prompt_subcommand_allows_literal_typo_word() { + let _guard = env_lock(); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); assert_eq!( parse_args(&["prompt".to_string(), "doctorr".to_string()]) .expect("explicit prompt subcommand should allow literal typo word"), @@ -11336,6 +12891,7 @@ mod tests { // doesn't pick up a stale .claw/settings.json from other tests that // may have set `permissionMode: acceptEdits` in a shared cwd. 
let _guard = env_lock(); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "danger-full-access"); let root = temp_dir(); let cwd = root.join("project"); std::fs::create_dir_all(&cwd).expect("project dir should exist"); @@ -11790,7 +13346,7 @@ mod tests { assert!(help.contains("claw /skills")); assert!(help.contains("ultraworkers/claw-code")); assert!(help.contains("cargo install claw-code")); - assert!(!help.contains("claw login")); + assert!(help.contains("claw login")); assert!(!help.contains("claw logout")); } @@ -13619,3 +15175,127 @@ mod dump_manifests_tests { let _ = fs::remove_dir_all(&root); } } + + +#[cfg(test)] +mod auth_tests { + use std::sync::{Mutex, OnceLock}; + use super::{parse_args, check_model_auth_available, BUILTIN_PROVIDERS, LOGIN_PROVIDER_TEMPLATES, CliAction}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + #[test] + fn parse_args_auth_without_provider() { + let args = vec!["auth".to_string()]; + let action = parse_args(&args).expect("should parse"); + match action { + CliAction::Auth { provider } => assert!(provider.is_none()), + other => panic!("expected CliAction::Auth, got {other:?}"), + } + } + + #[test] + fn parse_args_auth_with_provider() { + let args = vec!["auth".to_string(), "openai".to_string()]; + let action = parse_args(&args).expect("should parse"); + match action { + CliAction::Auth { provider } => assert_eq!(provider, Some("openai".to_string())), + other => panic!("expected CliAction::Auth, got {other:?}"), + } + } + + #[test] + fn openai_builtin_provider_has_oauth_config() { + let openai = BUILTIN_PROVIDERS + .iter() + .find(|p| p.id == "openai") + .expect("openai provider exists"); + assert!(openai.oauth.is_some(), "openai should have OAuth config"); + let oauth = openai.oauth.unwrap(); + assert_eq!(oauth.client_id, "app_EMoamEEZ73f0CkXaXp7hrann"); + 
assert_eq!(oauth.callback_port, 1455); + } + + #[test] + fn anthropic_and_xai_have_no_oauth() { + for provider in BUILTIN_PROVIDERS.iter() { + if provider.id == "anthropic" || provider.id == "xai" { + assert!(provider.oauth.is_none(), "{} should not have OAuth", provider.id); + } + } + } + + #[test] + fn moonshot_template_has_device_oauth() { + let moonshot = LOGIN_PROVIDER_TEMPLATES + .iter() + .find(|p| p.id == "moonshot") + .expect("moonshot template exists"); + assert!(moonshot.oauth.is_some(), "moonshot should have OAuth config"); + let oauth = moonshot.oauth.unwrap(); + assert_eq!(oauth.client_id, "17e5f671-d194-4dfb-9706-5516cb48c098"); + } + + #[test] + fn zai_and_minimax_have_no_oauth() { + for template in LOGIN_PROVIDER_TEMPLATES.iter() { + if template.id == "zai" || template.id == "minimax-coding-plan" { + assert!( + template.oauth.is_none(), + "{} should not have OAuth", + template.id + ); + } + } + } + + #[test] + fn check_model_auth_available_detects_saved_moonshot_oauth() { + let _guard = env_lock(); + let config_home = std::env::temp_dir().join(format!( + "claw-oauth-auth-test-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::fs::create_dir_all(&config_home).expect("create config home"); + + // Ensure no env var is set + std::env::remove_var("MOONSHOT_API_KEY"); + std::env::remove_var("DASHSCOPE_API_KEY"); + + // Without saved OAuth, auth should not be available + assert!( + !check_model_auth_available("moonshot/moonshot-v1-8k").expect("check auth"), + "auth should be unavailable without env or saved tokens" + ); + + // Save an OAuth token for moonshot + let token_set = runtime::OAuthTokenSet { + access_token: "test-access-token".to_string(), + refresh_token: Some("test-refresh".to_string()), + expires_at: Some(9999999999), + scopes: vec!["openid".to_string()], + id_token: None, + }; + 
runtime::save_provider_oauth("moonshot", &token_set).expect("save token"); + + // With saved OAuth, auth should be available + assert!( + check_model_auth_available("moonshot/moonshot-v1-8k").expect("check auth with oauth"), + "auth should be available with saved OAuth token" + ); + + // Clean up + std::env::remove_var("CLAW_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).ok(); + } +}