diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 740147e78e..78b9feeff4 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1364,6 +1364,7 @@ dependencies = [ "glob", "plugins", "regex", + "reqwest", "serde", "serde_json", "sha2", @@ -1456,6 +1457,7 @@ dependencies = [ "mock-anthropic-service", "plugins", "pulldown-cmark", + "reqwest", "runtime", "rustyline", "serde", diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 6e68fd2e2c..d9cb895596 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -41,7 +41,14 @@ impl ProviderClient { } _ => OpenAiCompatConfig::openai(), }; - Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + // Try OAuth for OpenAI if env var is not set + if config.provider_name == "OpenAI" { + Ok(Self::OpenAi(OpenAiCompatClient::from_env_or_oauth( + config, "openai", + )?)) + } else { + Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + } } } } diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 40da29f140..0f04905e7a 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -18,9 +18,9 @@ pub use prompt_cache::{ CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, PromptCacheStats, }; -pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource}; +pub use providers::anthropic::{has_auth_from_env_or_saved as anthropic_has_auth, AnthropicClient, AnthropicClient as ApiClient, AuthSource}; pub use providers::openai_compat::{ - build_chat_completion_request, flatten_tool_result_content, is_reasoning_model, + build_chat_completion_request, flatten_tool_result_content, has_api_key, is_reasoning_model, model_rejects_is_error_field, translate_message, OpenAiCompatClient, OpenAiCompatConfig, }; pub use providers::{ diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index a810502e66..42eaf7fb9e 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -139,6 +139,31 @@ impl OpenAiCompatClient { Ok(Self::new(api_key, config)) } + /// Create a client using an OAuth access token instead of an API key. + /// The token is sent as `Authorization: Bearer {token}`. + #[must_use] + pub fn from_oauth_token(token: impl Into, config: OpenAiCompatConfig) -> Self { + Self::new(token, config) + } + + /// Try env var first, then fall back to saved OAuth token for the provider. + /// `provider_id` is the key used in `~/.claw/credentials.json` under `oauth_providers`. + pub fn from_env_or_oauth( + config: OpenAiCompatConfig, + provider_id: &str, + ) -> Result { + if let Some(api_key) = read_env_non_empty(config.api_key_env)? { + return Ok(Self::new(api_key, config)); + } + if let Ok(Some(token_set)) = runtime::load_provider_oauth(provider_id) { + return Ok(Self::from_oauth_token(token_set.access_token, config)); + } + Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )) + } + #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index b1bd04f374..382e38fc29 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -10,6 +10,7 @@ sha2 = "0.10" glob = "0.3" plugins = { path = "../plugins" } regex = "1" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } serde = { version = "1", features = ["derive"] } serde_json.workspace = true telemetry = { path = "../telemetry" } diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index c7d87091fa..365e7ad6c8 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -112,11 +112,13 @@ pub use mcp_stdio::{ UnsupportedMcpServer, }; pub use oauth::{ - clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, - generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query, - parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest, - OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet, - PkceChallengeMethod, PkceCodePair, + clear_oauth_credentials, clear_provider_oauth, code_challenge_s256, credentials_path, + generate_pkce_pair, generate_state, load_oauth_credentials, load_provider_oauth, + loopback_redirect_uri, open_browser, parse_oauth_callback_query, + parse_oauth_callback_request_target, poll_device_token, run_oauth_callback_server, + save_oauth_credentials, save_provider_oauth, DeviceAuthRequest, DeviceAuthResponse, + OAuthAuthorizationRequest, OAuthCallbackParams, OAuthCallbackResult, OAuthRefreshRequest, + OAuthTokenExchangeRequest, OAuthTokenSet, PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index aa3ca158c7..87572a44c8 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -298,6 +298,343 @@ pub fn clear_oauth_credentials() -> io::Result<()> { write_credentials_root(&path, &root) } +// --------------------------------------------------------------------------- +// Per-provider OAuth token storage +// --------------------------------------------------------------------------- + +/// Load OAuth credentials for a specific provider from `credentials.json`. +/// Credentials are stored under the `oauth_providers.{provider_id}` key. +pub fn load_provider_oauth(provider_id: &str) -> io::Result> { + let path = credentials_path()?; + let root = read_credentials_root(&path)?; + let Some(oauth_providers) = root.get("oauth_providers") else { + return Ok(None); + }; + let Some(provider_value) = oauth_providers.get(provider_id) else { + return Ok(None); + }; + if provider_value.is_null() { + return Ok(None); + } + let stored = serde_json::from_value::(provider_value.clone()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + Ok(Some(stored.into())) +} + +/// Save OAuth credentials for a specific provider to `credentials.json`. +/// Preserves other providers and the legacy `"oauth"` key. +pub fn save_provider_oauth(provider_id: &str, token_set: &OAuthTokenSet) -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + let oauth_providers = root + .entry("oauth_providers") + .or_insert_with(|| Value::Object(Map::new())); + let provider_map = oauth_providers + .as_object_mut() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "oauth_providers must be an object"))?; + provider_map.insert( + provider_id.to_string(), + serde_json::to_value(StoredOAuthCredentials::from(token_set.clone())) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, + ); + write_credentials_root(&path, &root) +} + +/// Clear OAuth credentials for a specific provider. +pub fn clear_provider_oauth(provider_id: &str) -> io::Result<()> { + let path = credentials_path()?; + let mut root = read_credentials_root(&path)?; + let Some(oauth_providers) = root.get_mut("oauth_providers") else { + return Ok(()); + }; + let Some(provider_map) = oauth_providers.as_object_mut() else { + return Ok(()); + }; + provider_map.remove(provider_id); + if provider_map.is_empty() { + root.remove("oauth_providers"); + } + write_credentials_root(&path, &root) +} + +// --------------------------------------------------------------------------- +// Browser launcher +// --------------------------------------------------------------------------- + +/// Open a URL in the user's default browser. +/// Falls back to printing the URL if the platform command fails. +pub fn open_browser(url: &str) -> io::Result<()> { + let (cmd, args): (&str, Vec<&str>) = if cfg!(target_os = "macos") { + ("open", vec![url]) + } else if cfg!(target_os = "linux") { + ("xdg-open", vec![url]) + } else if cfg!(target_os = "windows") { + ("cmd", vec!["/C", "start", "", url]) + } else { + eprintln!("Please open this URL in your browser:"); + eprintln!(" {url}"); + return Ok(()); + }; + match std::process::Command::new(cmd).args(&args).output() { + Ok(output) if output.status.success() => Ok(()), + _ => { + eprintln!("Could not open browser automatically. Please open this URL:"); + eprintln!(" {url}"); + Ok(()) + } + } +} + +// --------------------------------------------------------------------------- +// Local HTTP callback server (blocking, single-request) +// --------------------------------------------------------------------------- + +/// Result of a successful OAuth callback. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthCallbackResult { + pub code: String, + pub state: String, +} + +/// Run a blocking local HTTP server that waits for a single `/callback` request. +/// Returns the authorization code and state on success. +/// Times out after `timeout` duration. +pub fn run_oauth_callback_server( + port: u16, + timeout: std::time::Duration, +) -> io::Result { + use std::io::{BufRead, BufReader, Write}; + use std::net::{SocketAddr, TcpListener}; + + let addr: SocketAddr = format!("127.0.0.1:{port}").parse().map_err(|e| { + io::Error::new(io::ErrorKind::InvalidInput, format!("invalid address: {e}")) + })?; + let listener = TcpListener::bind(addr)?; + listener.set_nonblocking(false)?; + + let start = std::time::Instant::now(); + + loop { + if start.elapsed() > timeout { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "OAuth callback timed out waiting for browser redirect", + )); + } + + let (mut stream, _) = match listener.accept() { + Ok(conn) => conn, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + Err(e) => return Err(e), + }; + + let mut reader = BufReader::new(&mut stream); + let mut first_line = String::new(); + reader.read_line(&mut first_line)?; + + // Parse "GET /callback?code=...&state=... HTTP/1.1" + let target = first_line + .split_whitespace() + .nth(1) + .unwrap_or(""); + + // Consume remaining headers so browser doesn't reset connection + loop { + let mut line = String::new(); + if reader.read_line(&mut line)? == 0 { + break; + } + if line == "\r\n" || line == "\n" { + break; + } + } + + match parse_oauth_callback_request_target(target) { + Ok(params) => { + if let (Some(code), Some(state)) = (¶ms.code, ¶ms.state) { + // Success page + let body = r#" +Authentication Successful + +

✅ Authentication Successful

+

You can close this tab and return to the terminal.

+"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Ok(OAuthCallbackResult { + code: code.clone(), + state: state.clone(), + }); + } + if let Some(error) = ¶ms.error { + let err_desc = params.error_description.as_deref().unwrap_or(error); + let body = format!( + r#" +Authentication Failed + +

❌ Authentication Failed

+

{}

+

You can close this tab and return to the terminal.

+"#, + html_escape(err_desc) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Err(io::Error::new( + io::ErrorKind::Other, + format!("OAuth error: {error} - {err_desc}"), + )); + } + } + Err(e) => { + let body = format!( + r#" +Authentication Failed + +

❌ Invalid Callback

+

{}

+"#, + html_escape(&e) + ); + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + return Err(io::Error::new(io::ErrorKind::Other, e)); + } + } + } +} + +#[must_use] +fn html_escape(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) +} + +// --------------------------------------------------------------------------- +// Device Authorization Flow (RFC 8628) +// --------------------------------------------------------------------------- + +/// Request body for starting a device authorization flow. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DeviceAuthRequest { + pub client_id: String, + pub scope: String, +} + +/// Response from a device authorization endpoint. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct DeviceAuthResponse { + pub device_code: String, + pub user_code: String, + pub verification_uri: String, + #[serde(default)] + pub verification_uri_complete: Option, + pub expires_in: u64, + pub interval: u64, +} + +/// Poll a device token endpoint until the user authorizes or the flow expires. +/// Returns `Ok(None)` if the user hasn't authorized yet but we should keep polling. +/// Returns `Ok(Some(token_set))` on success. +/// Returns `Err` on fatal errors. +pub async fn poll_device_token( + client: &reqwest::Client, + device_code: &str, + client_id: &str, + token_url: &str, +) -> io::Result> { + let params = [ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("device_code", device_code), + ("client_id", client_id), + ]; + + let response = client + .post(token_url) + .form(¶ms) + .send() + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("HTTP request failed: {e}")))?; + let status = response.status(); + let body = response.text().await.map_err(|e| { + io::Error::new(io::ErrorKind::Other, format!("Failed to read response body: {e}")) + })?; + + if status.is_success() { + let token: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let access_token = token["access_token"] + .as_str() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "access_token missing from device token response", + ) + })? + .to_string(); + let refresh_token = token["refresh_token"].as_str().map(String::from); + let expires_at = token["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = token["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + return Ok(Some(OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + })); + } + + // Parse OAuth error response + let error_json: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); + let error = error_json["error"].as_str().unwrap_or("unknown"); + + match error { + "authorization_pending" => Ok(None), + "slow_down" => { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + Ok(None) + } + "expired_token" => Err(io::Error::new( + io::ErrorKind::Other, + "Device authorization expired. Please try again.", + )), + "access_denied" => Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "User denied authorization.", + )), + _ => Err(io::Error::new( + io::ErrorKind::Other, + format!("Device token error: {error}: {body}"), + )), + } +} + pub fn parse_oauth_callback_request_target(target: &str) -> Result { let (path, query) = target .split_once('?') diff --git a/rust/crates/rusty-claude-cli/Cargo.toml b/rust/crates/rusty-claude-cli/Cargo.toml index 635fdb32f7..a1e8fba543 100644 --- a/rust/crates/rusty-claude-cli/Cargo.toml +++ b/rust/crates/rusty-claude-cli/Cargo.toml @@ -15,6 +15,7 @@ commands = { path = "../commands" } compat-harness = { path = "../compat-harness" } crossterm = "0.28" pulldown-cmark = "0.13" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } rustyline = "15" runtime = { path = "../runtime" } plugins = { path = "../plugins" } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index dbdbd07b64..b8c033bd86 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -24,10 +24,11 @@ use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant, UNIX_EPOCH}; use api::{ - detect_provider_kind, resolve_startup_auth_source, AnthropicClient, AuthSource, - ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, MessageResponse, - OutputContentBlock, PromptCache, ProviderClient as ApiProviderClient, ProviderKind, - StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, + anthropic_has_auth, detect_provider_kind, has_api_key, resolve_startup_auth_source, + AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, PromptCache, + ProviderClient as ApiProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, + ToolDefinition, ToolResultContentBlock, }; use commands::{ @@ -407,7 +408,16 @@ fn run() -> Result<(), Box> { None }; let effective_prompt = merge_prompt_with_stdin(&prompt, stdin_context.as_deref()); - let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; + let resolved_model = resolve_model_alias_with_config(&model); + let final_model = if !check_model_auth_available(&resolved_model)? { + match run_provider_welcome(&resolved_model)? { + Some(new_model) => new_model, + None => return Ok(()), + } + } else { + resolved_model + }; + let mut cli = LiveCli::new(final_model, true, allowed_tools, permission_mode)?; cli.set_reasoning_effort(reasoning_effort); cli.run_turn_with_output(&effective_prompt, output_format, compact)?; } @@ -464,6 +474,7 @@ fn run() -> Result<(), Box> { reasoning_effort, allow_broad_cwd, )?, + CliAction::Auth { provider } => run_auth_command(provider.as_deref())?, CliAction::HelpTopic(topic) => print_help_topic(topic), CliAction::Help { output_format } => print_help(output_format)?, } @@ -567,6 +578,9 @@ enum CliAction { reasoning_effort: Option, allow_broad_cwd: bool, }, + Auth { + provider: Option, + }, HelpTopic(LocalHelpTopic), // prompt-mode formatting is only supported for non-interactive runs Help { @@ -949,6 +963,20 @@ fn parse_args(args: &[String]) -> Result { "system-prompt" => parse_system_prompt_args(&rest[1..], output_format), "acp" => parse_acp_args(&rest[1..], output_format), "login" | "logout" => Err(removed_auth_surface_error(rest[0].as_str())), + "auth" => { + if rest.len() == 2 && is_help_flag(&rest[1]) { + return Ok(CliAction::Help { output_format }); + } + let provider = rest.get(1).cloned(); + if rest.len() > 2 { + return Err(format!( + "unexpected extra arguments after `claw auth {}`: {}", + provider.as_deref().unwrap_or(""), + rest[2..].join(" ") + )); + } + Ok(CliAction::Auth { provider }) + } "init" => Ok(CliAction::Init { output_format }), "export" => parse_export_args(&rest[1..], output_format), "prompt" => { @@ -1339,6 +1367,7 @@ fn suggest_similar_subcommand(input: &str) -> Option> { "init", "export", "prompt", + "auth", ]; let normalized_input = input.to_ascii_lowercase(); @@ -1446,6 +1475,372 @@ fn resolve_model_alias_with_config(model: &str) -> String { resolve_model_alias(trimmed).to_string() } +struct LoginProviderTemplate { + id: &'static str, + label: &'static str, + provider_type: &'static str, + base_url: &'static str, + api_key_env: &'static str, + models: &'static [&'static str], + default_model: &'static str, + oauth: Option, +} + +const LOGIN_PROVIDER_TEMPLATES: &[LoginProviderTemplate] = &[ + LoginProviderTemplate { + id: "zai", + label: "Z.AI", + provider_type: "openai-compatible", + base_url: "https://api.z.ai/api/paas/v4", + api_key_env: "Z_AI_API_KEY", + models: &[ + "glm-5.1", + "glm-5", + "glm-5-turbo", + "glm-4.7", + "glm-4.7-flashx", + "glm-4.7-flash", + "glm-4.6", + "glm-4.5", + "glm-4.5-x", + "glm-4.5-air", + "glm-4.5-airx", + "glm-4.5-flash", + "glm-4-32b-0414-128k", + ], + default_model: "glm-5.1", + oauth: None, + }, + LoginProviderTemplate { + id: "zai-coding-plan", + label: "Z.AI Coding Plan", + provider_type: "openai-compatible", + base_url: "https://api.z.ai/api/coding/paas/v4", + api_key_env: "Z_AI_API_KEY", + models: &[ + "glm-4.5-air", + "glm-4.7", + "glm-5-turbo", + "glm-5.1", + "glm-5v-turbo", + ], + default_model: "glm-5.1", + oauth: None, + }, + LoginProviderTemplate { + id: "minimax-coding-plan", + label: "MiniMax Coding Plan", + provider_type: "anthropic-compatible", + base_url: "https://api.minimax.io/anthropic/v1", + api_key_env: "MINIMAX_API_KEY", + models: &[ + "MiniMax-M2", + "MiniMax-M2.1", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + ], + default_model: "MiniMax-M2.7-highspeed", + oauth: None, + }, + LoginProviderTemplate { + id: "openai", + label: "OpenAI", + provider_type: "openai-compatible", + base_url: "https://api.openai.com/v1", + api_key_env: "OPENAI_API_KEY", + models: &[ + "gpt-5-codex", + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.3-codex", + "gpt-5.3-codex-spark", + "gpt-5.4", + "gpt-5.4-fast", + "gpt-5.4-mini", + "gpt-5.4-mini-fast", + "gpt-5.5", + "gpt-5.5-fast", + "gpt-5.5-pro", + ], + default_model: "gpt-5.5", + oauth: Some(ProviderOAuthConfig { + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + callback_port: 1455, + flow: OAuthFlowType::Pkce { + authorize_url: "https://auth.openai.com/oauth/authorize", + token_url: "https://auth.openai.com/oauth/token", + scopes: &["openid", "profile", "email", "offline_access"], + }, + }), + }, + LoginProviderTemplate { + id: "kimi-for-coding", + label: "Kimi For Coding", + provider_type: "anthropic-compatible", + base_url: "https://api.kimi.com/coding/v1", + api_key_env: "KIMI_API_KEY", + models: &["k2p5", "k2p6", "kimi-k2-thinking"], + default_model: "k2p6", + oauth: None, + }, + LoginProviderTemplate { + id: "moonshot", + label: "Moonshot / Kimi", + provider_type: "openai-compatible", + base_url: "https://api.moonshot.ai/v1", + api_key_env: "MOONSHOT_API_KEY", + models: &[ + "kimi-k2.6", + "kimi-k2.5", + "kimi-k2-0905-preview", + "kimi-k2-0711-preview", + "kimi-k2-turbo-preview", + "kimi-k2-thinking", + "kimi-k2-thinking-turbo", + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "moonshot-v1-8k-vision-preview", + "moonshot-v1-32k-vision-preview", + "moonshot-v1-128k-vision-preview", + ], + default_model: "kimi-k2.6", + oauth: Some(ProviderOAuthConfig { + client_id: "17e5f671-d194-4dfb-9706-5516cb48c098", + callback_port: 4546, + flow: OAuthFlowType::Device { + device_auth_url: "https://auth.kimi.com/api/oauth/device_authorization", + token_url: "https://auth.kimi.com/api/oauth/token", + scopes: &["openid", "profile", "email"], + }, + }), + }, +]; + +fn run_login_wizard() -> Result, Box> { + if !io::stdin().is_terminal() { + return Err("login requires an interactive terminal".into()); + } + + println!(); + println!("Claw provider login"); + println!("Configure a model provider profile."); + println!("Press Enter to accept defaults."); + println!(); + + for (index, provider) in LOGIN_PROVIDER_TEMPLATES.iter().enumerate() { + println!(" [{}] {}", index + 1, provider.label); + } + println!( + " [{}] Custom compatible endpoint", + LOGIN_PROVIDER_TEMPLATES.len() + 1 + ); + + let choice = read_prompt("Select provider [1]: ")?; + let choice = if choice.trim().is_empty() { + 1 + } else { + choice.trim().parse::()? + }; + + let ( + provider_id, + label, + provider_type, + default_base_url, + default_api_key_env, + default_models, + default_model, + ) = if choice == LOGIN_PROVIDER_TEMPLATES.len() + 1 { + let id = read_required_prompt("Provider id (e.g. openrouter): ")?; + let provider_type = read_prompt( + "Provider type [openai-compatible, anthropic-compatible] [openai-compatible]: ", + )?; + let provider_type = if provider_type.trim().is_empty() { + "openai-compatible".to_string() + } else { + provider_type.trim().to_string() + }; + if !matches!( + provider_type.as_str(), + "openai-compatible" | "openai" | "anthropic-compatible" | "anthropic" + ) { + return Err(format!("unsupported provider type: {provider_type}").into()); + } + let base_url = read_required_prompt("Base URL: ")?; + let api_key_env = read_prompt("API key env var [OPENAI_API_KEY]: ")?; + let model = read_required_prompt("Default model: ")?; + ( + id, + "Custom".to_string(), + provider_type, + base_url, + if api_key_env.trim().is_empty() { + "OPENAI_API_KEY".to_string() + } else { + api_key_env.trim().to_string() + }, + vec![model.clone()], + model, + ) + } else { + let template = LOGIN_PROVIDER_TEMPLATES + .get(choice.saturating_sub(1)) + .ok_or_else(|| format!("invalid provider choice: {choice}"))?; + ( + template.id.to_string(), + template.label.to_string(), + template.provider_type.to_string(), + template.base_url.to_string(), + template.api_key_env.to_string(), + template + .models + .iter() + .map(|model| (*model).to_string()) + .collect::>(), + template.default_model.to_string(), + ) + }; + + println!(); + println!("{label}"); + println!("Provider type: {provider_type}"); + let base_url = read_prompt(&format!("Base URL [{default_base_url}]: "))?; + let base_url = if base_url.trim().is_empty() { + default_base_url + } else { + base_url.trim().to_string() + }; + let api_key_env = read_prompt(&format!("API key env var [{default_api_key_env}]: "))?; + let api_key_env = if api_key_env.trim().is_empty() { + default_api_key_env + } else { + api_key_env.trim().to_string() + }; + + let token = read_prompt("Paste API key / bearer token now, or press Enter to use env var: ")?; + let api_key = (!token.trim().is_empty()).then(|| token.trim().to_string()); + + println!("Available models: {}", default_models.join(", ")); + let model = read_prompt(&format!("Default model [{default_model}]: "))?; + let model = if model.trim().is_empty() { + default_model + } else { + model.trim().to_string() + }; + let mut models = default_models; + if !models.iter().any(|known| known == &model) { + models.push(model.clone()); + } + + save_model_provider_profile( + &provider_id, + &provider_type, + &base_url, + &api_key_env, + api_key.as_deref(), + &models, + &model, + )?; + Ok(Some(format!("{provider_id}/{model}"))) +} + +fn read_prompt(prompt: &str) -> Result> { + print!("{prompt}"); + io::stdout().flush()?; + let mut buffer = String::new(); + io::stdin().read_line(&mut buffer)?; + Ok(buffer) +} + +fn read_required_prompt(prompt: &str) -> Result> { + let value = read_prompt(prompt)?; + let value = value.trim(); + if value.is_empty() { + return Err(format!("{prompt} is required").into()); + } + Ok(value.to_string()) +} + +fn save_model_provider_profile( + provider_id: &str, + provider_type: &str, + base_url: &str, + api_key_env: &str, + api_key: Option<&str>, + models: &[String], + default_model: &str, +) -> Result<(), Box> { + let cwd = env::current_dir()?; + let config_home = ConfigLoader::default_for(&cwd).config_home().to_path_buf(); + fs::create_dir_all(&config_home)?; + let settings_path = config_home.join("settings.json"); + let mut root = match fs::read_to_string(&settings_path) { + Ok(contents) if !contents.trim().is_empty() => serde_json::from_str::(&contents)?, + Ok(_) => Value::Object(Map::new()), + Err(error) if error.kind() == io::ErrorKind::NotFound => Value::Object(Map::new()), + Err(error) => return Err(error.into()), + }; + if !root.is_object() { + root = Value::Object(Map::new()); + } + let root_object = root.as_object_mut().expect("root object initialized"); + let providers = root_object + .entry("modelProviders") + .or_insert_with(|| Value::Object(Map::new())); + if !providers.is_object() { + *providers = Value::Object(Map::new()); + } + let provider_map = providers + .as_object_mut() + .expect("modelProviders object initialized"); + + let mut provider = Map::new(); + provider.insert("type".to_string(), Value::String(provider_type.to_string())); + provider.insert("baseUrl".to_string(), Value::String(base_url.to_string())); + provider.insert( + "apiKeyEnv".to_string(), + Value::String(api_key_env.to_string()), + ); + if let Some(api_key) = api_key { + provider.insert("apiKey".to_string(), Value::String(api_key.to_string())); + } + provider.insert( + "models".to_string(), + Value::Array( + models + .iter() + .map(|model| Value::String(model.clone())) + .collect(), + ), + ); + provider.insert( + "defaultModel".to_string(), + Value::String(default_model.to_string()), + ); + provider_map.insert(provider_id.to_string(), Value::Object(provider)); + root_object.insert( + "model".to_string(), + Value::String(format!("{provider_id}/{default_model}")), + ); + + let serialized = format!("{}\n", serde_json::to_string_pretty(&root)?); + fs::write(&settings_path, serialized)?; + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut permissions = fs::metadata(&settings_path)?.permissions(); + permissions.set_mode(0o600); + fs::set_permissions(&settings_path, permissions)?; + } + Ok(()) +} + /// Validate model syntax at parse time. /// Accepts: known aliases (opus, sonnet, haiku) or provider/model pattern. /// Rejects: empty, whitespace-only, strings with spaces, or invalid chars. @@ -3772,6 +4167,21 @@ fn run_repl( enforce_broad_cwd_policy(allow_broad_cwd, CliOutputFormat::Text)?; run_stale_base_preflight(base_commit.as_deref()); let resolved_model = resolve_repl_model(model); + if !check_model_auth_available(&resolved_model)? { + match run_provider_welcome(&resolved_model)? { + Some(new_model) => { + return run_repl( + new_model, + allowed_tools, + permission_mode, + base_commit, + reasoning_effort, + allow_broad_cwd, + ); + } + None => return Ok(()), + } + } let mut cli = LiveCli::new(resolved_model, true, allowed_tools, permission_mode)?; cli.set_reasoning_effort(reasoning_effort); let mut editor = @@ -7737,6 +8147,592 @@ fn resolve_cli_auth_source_for_cwd() -> Result { resolve_startup_auth_source(|| Ok(None)) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OAuthFlowType { + Pkce { + authorize_url: &'static str, + token_url: &'static str, + scopes: &'static [&'static str], + }, + Device { + device_auth_url: &'static str, + token_url: &'static str, + scopes: &'static [&'static str], + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ProviderOAuthConfig { + client_id: &'static str, + callback_port: u16, + flow: OAuthFlowType, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct BuiltinProvider { + id: &'static str, + label: &'static str, + env_var: &'static str, + default_model: &'static str, + oauth: Option, +} + +const BUILTIN_PROVIDERS: &[BuiltinProvider] = &[ + BuiltinProvider { + id: "anthropic", + label: "Anthropic (Claude)", + env_var: "ANTHROPIC_API_KEY", + default_model: "claude-opus-4-6", + oauth: None, + }, + BuiltinProvider { + id: "openai", + label: "OpenAI", + env_var: "OPENAI_API_KEY", + default_model: "gpt-4o", + oauth: Some(ProviderOAuthConfig { + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + callback_port: 1455, + flow: OAuthFlowType::Pkce { + authorize_url: "https://auth.openai.com/oauth/authorize", + token_url: "https://auth.openai.com/oauth/token", + scopes: &["openid", "profile", "email", "offline_access"], + }, + }), + }, + BuiltinProvider { + id: "xai", + label: "xAI (Grok)", + env_var: "XAI_API_KEY", + default_model: "grok-3", + oauth: None, + }, +]; + +fn check_model_auth_available(model: &str) -> Result> { + let resolved = api::resolve_model_alias(model); + let provider = detect_provider_kind(&resolved); + let available = match provider { + ProviderKind::Anthropic => anthropic_has_auth().unwrap_or(false), + ProviderKind::Xai => { + has_api_key("XAI_API_KEY") + } + ProviderKind::OpenAi => { + has_api_key("OPENAI_API_KEY") + || has_api_key("DASHSCOPE_API_KEY") + || runtime::load_provider_oauth("openai").ok().flatten().is_some() + } + }; + Ok(available) +} + +fn run_provider_welcome( + default_model: &str, +) -> Result, Box> { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err( + "Authentication required. Set one of these environment variables:\n\ + \n\ + ANTHROPIC_API_KEY= # For Anthropic (Claude)\n\ + OPENAI_API_KEY= # For OpenAI\n\ + XAI_API_KEY= # For xAI (Grok)\n\ + Z_AI_API_KEY= # For Z.AI\n\ + MINIMAX_API_KEY= # For MiniMax\n\ + KIMI_API_KEY= # For Kimi For Coding\n\ + MOONSHOT_API_KEY= # For Moonshot / Kimi\n\ + DASHSCOPE_API_KEY= # For DashScope/Qwen" + .into(), + ); + } + + println!( + "\n\x1b[1mWelcome to Claw Code\x1b[0m\n\n\ + No API key detected for the selected model ({default_model}).\n\ + Choose a provider to authenticate with:\n" + ); + + let builtin_count = BUILTIN_PROVIDERS.len(); + let template_count = LOGIN_PROVIDER_TEMPLATES.len(); + let total = builtin_count + template_count; + + println!(" Built-in:"); + for (i, provider) in BUILTIN_PROVIDERS.iter().enumerate() { + let oauth_tag = if provider.oauth.is_some() { " [OAuth]" } else { "" }; + println!(" {}. {}{}", i + 1, provider.label, oauth_tag); + } + + println!("\n Additional providers:"); + for (i, template) in LOGIN_PROVIDER_TEMPLATES.iter().enumerate() { + let oauth_tag = if template.oauth.is_some() { " [OAuth]" } else { "" }; + println!( + " {}. {}{}", + builtin_count + i + 1, + template.label, + oauth_tag + ); + } + + print!("\nEnter number (1-{total}): "); + std::io::stdout().flush()?; + let mut choice = String::new(); + std::io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + let index: usize = choice.parse().map_err(|_| "invalid selection")?; + if index == 0 || index > total { + return Err("invalid selection".into()); + } + + // Built-in provider selected + if index <= builtin_count { + let provider = BUILTIN_PROVIDERS.get(index - 1).expect("valid builtin index"); + + // Offer OAuth if available + if let Some(ref oauth) = provider.oauth { + if prompt_oauth_or_api_key(provider.label, true)? { + run_pkce_oauth_flow(provider.id, oauth)?; + return Ok(Some(format!("{}/{}", provider.id, provider.default_model))); + } + } + + print!("Enter {} (or press Enter to cancel): ", provider.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + + std::env::set_var(provider.env_var, key); + + // Optionally save to ~/.claw/settings.json as a simple model config. + if let Some(home) = std::env::var_os("HOME") { + let claw_dir = std::path::PathBuf::from(home).join(".claw"); + if claw_dir.exists() || std::fs::create_dir_all(&claw_dir).is_ok() { + let settings_path = claw_dir.join("settings.json"); + let model_str = format!("{}/{}", provider.id, provider.default_model); + let content = format!("{{\"model\": \"{model_str}\"}}"); + let _ = std::fs::write(&settings_path, content); + } + } + + return Ok(Some(format!("{}/{}", provider.id, provider.default_model))); + } + + // Template provider selected + let template = LOGIN_PROVIDER_TEMPLATES + .get(index - builtin_count - 1) + .expect("valid template index"); + + // Offer OAuth if available + if let Some(ref oauth) = template.oauth { + if prompt_oauth_or_api_key(template.label, true)? { + match oauth.flow { + OAuthFlowType::Pkce { .. } => { + run_pkce_oauth_flow(template.id, oauth)?; + } + OAuthFlowType::Device { .. } => { + run_device_oauth_flow(template.id, oauth)?; + } + } + return Ok(Some(format!("{}/{}", template.id, template.default_model))); + } + } + + print!( + "Enter {} (or press Enter to cancel): ", + template.api_key_env + ); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(None); + } + + std::env::set_var(template.api_key_env, key); + + save_model_provider_profile( + template.id, + template.provider_type, + template.base_url, + template.api_key_env, + Some(key), + &template.models.iter().map(|m| (*m).to_string()).collect::>(), + template.default_model, + )?; + + println!("✓ Provider profile saved for {}.", template.label); + Ok(Some(format!( + "{}/{}", + template.id, template.default_model + ))) +} + +fn run_auth_command(provider: Option<&str>) -> Result<(), Box> { + if let Some(provider_id) = provider { + // Try built-in provider first + if let Some(builtin) = BUILTIN_PROVIDERS.iter().find(|p| p.id == provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + builtin.env_var + ) + .into()); + } + + // Offer OAuth if available + if let Some(ref oauth) = builtin.oauth { + if prompt_oauth_or_api_key(builtin.label, true)? { + run_pkce_oauth_flow(builtin.id, oauth)?; + println!("Authenticated with {} via OAuth.", builtin.label); + return Ok(()); + } + } + + print!("Enter {} (or press Enter to cancel): ", builtin.env_var); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(builtin.env_var, key); + println!("Authentication set for {}.", builtin.label); + return Ok(()); + } + + // Try template provider + if let Some(template) = LOGIN_PROVIDER_TEMPLATES.iter().find(|p| p.id == provider_id) { + if !std::io::stdin().is_terminal() || !std::io::stdout().is_terminal() { + return Err(format!( + "Authentication requires a terminal. Set the environment variable instead:\n\ + \n\ + export {}=", + template.api_key_env + ) + .into()); + } + + // Offer OAuth if available + if let Some(ref oauth) = template.oauth { + if prompt_oauth_or_api_key(template.label, true)? { + match oauth.flow { + OAuthFlowType::Pkce { .. } => { + run_pkce_oauth_flow(template.id, oauth)?; + } + OAuthFlowType::Device { .. } => { + run_device_oauth_flow(template.id, oauth)?; + } + } + println!( + "Authenticated with {label} via OAuth. Model: {id}/{model}", + label = template.label, + id = template.id, + model = template.default_model + ); + return Ok(()); + } + } + + print!( + "Enter {} (or press Enter to cancel): ", + template.api_key_env + ); + std::io::stdout().flush()?; + let mut key = String::new(); + std::io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + println!("Cancelled."); + return Ok(()); + } + + std::env::set_var(template.api_key_env, key); + save_model_provider_profile( + template.id, + template.provider_type, + template.base_url, + template.api_key_env, + Some(key), + &template.models.iter().map(|m| (*m).to_string()).collect::>(), + template.default_model, + )?; + println!( + "Authenticated with {label}. Model: {id}/{model}", + label = template.label, + id = template.id, + model = template.default_model + ); + return Ok(()); + } + + return Err(format!( + "unknown provider: '{provider_id}'. Run `claw auth` to see available providers." + ) + .into()); + } + + match run_provider_welcome(DEFAULT_MODEL)? { + Some(model) => { + println!("Authentication configured. Model set to {model}."); + Ok(()) + } + None => { + println!("Cancelled."); + Ok(()) + } + } +} + +// --------------------------------------------------------------------------- +// OAuth flow implementations +// --------------------------------------------------------------------------- + +fn run_pkce_oauth_flow( + provider_id: &str, + oauth: &ProviderOAuthConfig, +) -> Result> { + use runtime::{ + generate_pkce_pair, generate_state, loopback_redirect_uri, open_browser, + run_oauth_callback_server, OAuthAuthorizationRequest, OAuthTokenExchangeRequest, + }; + + let pkce = generate_pkce_pair()?; + let state = generate_state()?; + + let (authorize_url, token_url, scopes) = match oauth.flow { + OAuthFlowType::Pkce { + authorize_url, + token_url, + scopes, + } => (authorize_url, token_url, scopes), + _ => return Err("Expected PKCE flow".into()), + }; + + let config = runtime::OAuthConfig { + client_id: oauth.client_id.to_string(), + authorize_url: authorize_url.to_string(), + token_url: token_url.to_string(), + callback_port: Some(oauth.callback_port), + manual_redirect_url: None, + scopes: scopes.iter().map(|s| (*s).to_string()).collect(), + }; + + let auth_request = + OAuthAuthorizationRequest::from_config(&config, loopback_redirect_uri(oauth.callback_port), &state, &pkce); + let auth_url = auth_request.build_url(); + + println!("Opening browser for OAuth authentication..."); + println!("If the browser doesn't open automatically, visit:"); + println!(" {auth_url}"); + open_browser(&auth_url)?; + + println!("Waiting for authentication..."); + let callback = run_oauth_callback_server(oauth.callback_port, std::time::Duration::from_secs(300))?; + + if callback.state != state { + return Err("OAuth state mismatch. Possible CSRF attack.".into()); + } + + println!("Exchanging authorization code for tokens..."); + + let rt = tokio::runtime::Runtime::new()?; + let token_set = rt.block_on(async { + let client = reqwest::Client::new(); + let exchange = OAuthTokenExchangeRequest::from_config( + &config, + &callback.code, + &state, + &pkce.verifier, + &loopback_redirect_uri(oauth.callback_port), + ); + + let response = client + .post(&config.token_url) + .form(&exchange.form_params()) + .send() + .await + .map_err(|e| format!("Token exchange request failed: {e}"))?; + + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| format!("Failed to read token response: {e}"))?; + + if !status.is_success() { + return Err(format!("Token exchange failed ({status}): {body}").into()); + } + + let json: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| format!("Invalid token response JSON: {e}"))?; + + let access_token = json["access_token"] + .as_str() + .ok_or("access_token missing from response")? + .to_string(); + let refresh_token = json["refresh_token"].as_str().map(String::from); + let expires_at = json["expires_in"].as_u64().map(|secs| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + secs + }); + let scopes = json["scope"] + .as_str() + .map(|s| s.split(' ').map(String::from).collect()) + .unwrap_or_default(); + + Ok::<_, Box>(runtime::OAuthTokenSet { + access_token, + refresh_token, + expires_at, + scopes, + }) + })?; + + runtime::save_provider_oauth(provider_id, &token_set)?; + println!("✓ OAuth authentication successful. Tokens saved."); + + Ok(token_set) +} + +fn run_device_oauth_flow( + provider_id: &str, + oauth: &ProviderOAuthConfig, +) -> Result> { + use runtime::{open_browser, poll_device_token, DeviceAuthRequest}; + + let (device_auth_url, token_url, scopes) = match oauth.flow { + OAuthFlowType::Device { + device_auth_url, + token_url, + scopes, + } => (device_auth_url, token_url, scopes), + _ => return Err("Expected Device flow".into()), + }; + + let rt = tokio::runtime::Runtime::new()?; + + // Step 1: Request device code + let device_response = rt.block_on(async { + let client = reqwest::Client::new(); + let scope_str = scopes.join(" "); + let params = [ + ("client_id", oauth.client_id), + ("scope", scope_str.as_str()), + ]; + let response = client + .post(device_auth_url) + .form(¶ms) + .send() + .await + .map_err(|e| format!("Device auth request failed: {e}"))?; + + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| format!("Failed to read device auth response: {e}"))?; + + if !status.is_success() { + return Err(format!("Device auth failed ({status}): {body}").into()); + } + + let resp: runtime::DeviceAuthResponse = serde_json::from_str(&body) + .map_err(|e| format!("Invalid device auth response: {e}"))?; + Ok::<_, Box>(resp) + })?; + + println!("\nDevice authentication started."); + println!("User code: {}", device_response.user_code); + println!( + "Please visit: {}", + device_response + .verification_uri_complete + .as_deref() + .unwrap_or(&device_response.verification_uri) + ); + + if let Some(ref complete_uri) = device_response.verification_uri_complete { + open_browser(complete_uri)?; + } else { + open_browser(&device_response.verification_uri)?; + } + + // Step 2: Poll for token + let start = std::time::Instant::now(); + let expires_in = std::time::Duration::from_secs(device_response.expires_in); + let interval = std::time::Duration::from_secs(device_response.interval); + + let token_set = rt.block_on(async { + let client = reqwest::Client::new(); + loop { + if start.elapsed() > expires_in { + return Err::<_, Box>( + "Device authorization expired. Please try again.".into(), + ); + } + + tokio::time::sleep(interval).await; + + match poll_device_token( + &client, + &device_response.device_code, + oauth.client_id, + token_url, + ) + .await + { + Ok(Some(token_set)) => return Ok(token_set), + Ok(None) => { + println!("Waiting for authorization..."); + continue; + } + Err(e) => return Err(e.into()), + } + } + })?; + + runtime::save_provider_oauth(provider_id, &token_set)?; + println!("✓ OAuth authentication successful. Tokens saved."); + + Ok(token_set) +} + +fn prompt_oauth_or_api_key(provider_label: &str, has_oauth: bool) -> Result> { + if !has_oauth { + return Ok(false); + } + + println!("\nChoose authentication method for {provider_label}:"); + println!(" 1. Sign in with {provider_label} account (OAuth) — recommended"); + println!(" 2. Enter API key manually"); + print!("Select [1]: "); + std::io::stdout().flush()?; + + let mut choice = String::new(); + std::io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + Ok(choice.is_empty() || choice == "1") +} + impl ApiClient for AnthropicRuntimeClient { #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { @@ -9121,6 +10117,11 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { writeln!(out, " claw skills")?; writeln!(out, " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]")?; writeln!(out, " claw init")?; + writeln!(out, " claw auth [PROVIDER]")?; + writeln!( + out, + " Authenticate with a model provider (anthropic, openai, xai)" + )?; writeln!( out, " claw export [PATH] [--session SESSION] [--output PATH]" @@ -9707,6 +10708,28 @@ mod tests { ); } + #[test] + fn parse_args_auth_without_provider() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["auth".to_string()]).expect("args should parse"), + CliAction::Auth { provider: None } + ); + } + + #[test] + fn parse_args_auth_with_provider() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["auth".to_string(), "openai".to_string()]).expect("args should parse"), + CliAction::Auth { + provider: Some("openai".to_string()), + } + ); + } + #[test] fn merge_prompt_with_stdin_returns_prompt_unchanged_when_no_pipe() { // given @@ -13619,3 +14642,29 @@ mod dump_manifests_tests { let _ = fs::remove_dir_all(&root); } } + + +#[cfg(test)] +mod auth_tests { + use super::{parse_args, CliAction}; + + #[test] + fn parse_args_auth_without_provider() { + let args = vec!["auth".to_string()]; + let action = parse_args(&args).expect("should parse"); + match action { + CliAction::Auth { provider } => assert!(provider.is_none()), + other => panic!("expected CliAction::Auth, got {other:?}"), + } + } + + #[test] + fn parse_args_auth_with_provider() { + let args = vec!["auth".to_string(), "openai".to_string()]; + let action = parse_args(&args).expect("should parse"); + match action { + CliAction::Auth { provider } => assert_eq!(provider, Some("openai".to_string())), + other => panic!("expected CliAction::Auth, got {other:?}"), + } + } +}