diff --git a/Cargo.lock b/Cargo.lock index a1c390646..28d86ea93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3633,6 +3633,7 @@ dependencies = [ "metrics", "metrics-exporter-prometheus", "miette", + "nix", "openshell-core", "openshell-driver-docker", "openshell-driver-kubernetes", diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 095b7d020..52beda472 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -23,7 +23,13 @@ Each runtime receives a sandbox spec from the gateway and is responsible for: | Docker | Local development with Docker available. | Container plus nested sandbox namespace. | Uses host networking so loopback gateway endpoints work from the supervisor. | | Podman | Rootless or single-machine deployments. | Container plus nested sandbox namespace. | Uses the Podman REST API, OCI image volumes, and CDI GPU devices when available. | | Kubernetes | Cluster deployment through Helm. | Pod plus nested sandbox namespace. | Uses Kubernetes API objects, service accounts, secrets, PVC-backed workspace storage, and GPU resources. | -| VM | Experimental microVM isolation. | Per-sandbox libkrun VM. | Gateway spawns `openshell-driver-vm` as a subprocess over a Unix socket. | +| VM | Experimental microVM isolation. | Per-sandbox libkrun VM. | Gateway spawns `openshell-driver-vm` as a subprocess over a private, state-local Unix socket. | + +VM runtime state paths are derived only from driver-validated sandbox IDs +matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a +private `run/` directory plus Unix peer UID/PID checks. Standalone +unauthenticated TCP mode is disabled unless explicitly enabled for local +development. 
Runtime-specific implementation notes belong in the driver crate README: diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index a3bdf9822..0a11ceb0a 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -42,7 +42,7 @@ By default `mise run gateway:vm`: - Listens on plaintext HTTP at `127.0.0.1:18081`. - Registers the CLI gateway `vm-dev` by writing `~/.config/openshell/gateways/vm-dev/metadata.json`. It does not modify the workspace `.env`. - Persists the gateway SQLite DB under `.cache/gateway-vm/gateway.db`. -- Places the VM driver state (per-sandbox rootfs + `compute-driver.sock`) under `/tmp/openshell-vm-driver-$USER-vm-dev/` so the AF_UNIX socket path stays under macOS `SUN_LEN`. +- Places the VM driver state (per-sandbox rootfs plus `run/compute-driver.sock`) under `/tmp/openshell-vm-driver-$USER-vm-dev/` so the AF_UNIX socket path stays under macOS `SUN_LEN`. - Passes `--driver-dir $PWD/target/debug` so the freshly built `openshell-driver-vm` is used instead of an older installed copy from `~/.local/libexec/openshell`, `/usr/libexec/openshell`, or `/usr/local/libexec`. For GPU passthrough (VFIO), pass `-- --gpu` and run with root privileges: @@ -124,7 +124,7 @@ The gateway resolves `openshell-driver-vm` in this order: `--driver-dir`, conven |---|---|---|---| | `--drivers vm` | `OPENSHELL_DRIVERS` | `kubernetes` | Select the VM compute driver. | | `--grpc-endpoint URL` | `OPENSHELL_GRPC_ENDPOINT` | — | Required. URL the sandbox guest dials to reach the gateway. Use `http://host.containers.internal:<port>` (or `host.docker.internal` / `host.openshell.internal`) so traffic flows through gvproxy's host-loopback NAT (HostIP `192.168.127.254` → host `127.0.0.1`). Loopback URLs like `http://127.0.0.1:<port>` are rewritten automatically by the driver. The bare gateway IP (`192.168.127.1`) only carries gvproxy's own services and will not reach host-bound ports. 
| -| `--vm-driver-state-dir DIR` | `OPENSHELL_VM_DRIVER_STATE_DIR` | `target/openshell-vm-driver` | Per-sandbox rootfs, console logs, and the `compute-driver.sock` UDS. | +| `--vm-driver-state-dir DIR` | `OPENSHELL_VM_DRIVER_STATE_DIR` | `target/openshell-vm-driver` | Per-sandbox rootfs, console logs, image cache, and private `run/compute-driver.sock` UDS. | | `--driver-dir DIR` | `OPENSHELL_DRIVER_DIR` | unset | Override the directory searched for `openshell-driver-vm`. | | `--vm-driver-vcpus N` | `OPENSHELL_VM_DRIVER_VCPUS` | `2` | vCPUs per sandbox. | | `--vm-driver-mem-mib N` | `OPENSHELL_VM_DRIVER_MEM_MIB` | `2048` | Memory per sandbox, in MiB. | @@ -156,7 +156,7 @@ RUST_LOG=openshell_server=debug,openshell_driver_vm=debug \ mise run gateway:vm ``` -The VM guest's serial console is appended to `<state-dir>/<sandbox-id>/console.log`. The `compute-driver.sock` lives at `<state-dir>/compute-driver.sock`; the gateway removes it on clean shutdown via `ManagedDriverProcess::drop`. +The VM guest's serial console is appended to `<state-dir>/<sandbox-id>/console.log`. Sandbox IDs must match `[A-Za-z0-9._-]{1,128}` before the driver uses them in host paths. The gateway-owned compute-driver socket lives at `<state-dir>/run/compute-driver.sock`; OpenShell creates `run/` with owner-only permissions, removes same-owner stale sockets, and the gateway removes the socket on clean shutdown via `ManagedDriverProcess::drop`. UDS clients must match the driver UID and provide the expected gateway process PID by default. Standalone same-UID UDS mode requires the explicit `--allow-same-uid-peer` development flag. TCP mode is disabled by default because it is unauthenticated; use `--allow-unauthenticated-tcp --bind-address 127.0.0.1:50061` only for local development. 
## Prerequisites diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 92cab23af..b797f4835 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -38,7 +38,7 @@ use std::fs; use std::io::Read; use std::net::Ipv4Addr; use std::os::unix::fs::PermissionsExt; -use std::path::{Path, PathBuf}; +use std::path::{Component, Path, PathBuf}; use std::pin::Pin; use std::process::Stdio; use std::sync::Arc; @@ -362,7 +362,7 @@ impl VmDriver { let is_gpu = spec.is_some_and(|s| s.gpu); let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); - let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id); + let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id)?; let rootfs = state_dir.join("rootfs"); let image_ref = self.resolved_sandbox_image(sandbox).ok_or_else(|| { Status::failed_precondition( @@ -620,6 +620,10 @@ impl VmDriver { sandbox_id: &str, sandbox_name: &str, ) -> Result { + if !sandbox_id.is_empty() { + validate_sandbox_id(sandbox_id)?; + } + let record = { let registry = self.registry.lock().await; if let Some((id, record)) = registry.get_key_value(sandbox_id) { @@ -670,13 +674,7 @@ impl VmDriver { self.release_gpu_and_subnet(&record_id); } - if let Err(err) = tokio::fs::remove_dir_all(&state_dir).await - && err.kind() != std::io::ErrorKind::NotFound - { - return Err(Status::internal(format!( - "failed to remove state dir: {err}" - ))); - } + remove_sandbox_state_dir(&self.config.state_dir, &state_dir).await?; { let mut registry = self.registry.lock().await; @@ -692,6 +690,10 @@ impl VmDriver { sandbox_id: &str, sandbox_name: &str, ) -> Result, Status> { + if !sandbox_id.is_empty() { + validate_sandbox_id(sandbox_id)?; + } + let registry = self.registry.lock().await; let sandbox = if sandbox_id.is_empty() { registry @@ -1452,6 +1454,8 @@ fn check_gpu_privileges() -> Result<(), String> { // gRPC API surface, so boxing here would diverge from 
every other handler. #[allow(clippy::result_large_err)] fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { + validate_sandbox_id(&sandbox.id)?; + let spec = sandbox .spec .as_ref() @@ -1487,6 +1491,32 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu Ok(()) } +#[allow(clippy::result_large_err)] +fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { + if sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox id is required")); + } + if sandbox_id.len() > 128 { + return Err(Status::invalid_argument( + "sandbox id exceeds maximum length (128 bytes)", + )); + } + if matches!(sandbox_id, "." | "..") { + return Err(Status::invalid_argument( + "sandbox id must match [A-Za-z0-9._-]{1,128}", + )); + } + if !sandbox_id + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-')) + { + return Err(Status::invalid_argument( + "sandbox id must match [A-Za-z0-9._-]{1,128}", + )); + } + Ok(()) +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { @@ -2236,8 +2266,71 @@ fn sandboxes_root_dir(root: &Path) -> PathBuf { root.join("sandboxes") } -fn sandbox_state_dir(root: &Path, sandbox_id: &str) -> PathBuf { - sandboxes_root_dir(root).join(sandbox_id) +#[allow(clippy::result_large_err)] +fn sandbox_state_dir(root: &Path, sandbox_id: &str) -> Result { + validate_sandbox_id(sandbox_id)?; + Ok(sandboxes_root_dir(root).join(sandbox_id)) +} + +#[allow(clippy::result_large_err)] +fn validate_sandbox_state_dir(root: &Path, state_dir: &Path) -> Result<(), Status> { + let sandboxes_root = sandboxes_root_dir(root); + let relative = state_dir.strip_prefix(&sandboxes_root).map_err(|_| { + Status::internal(format!( + "refusing to use sandbox state path outside vm state root: {}", + state_dir.display() + )) + })?; + + let mut components = relative.components(); + match components.next() { + 
Some(Component::Normal(_)) => {} + _ => { + return Err(Status::internal(format!( + "refusing to use malformed sandbox state path: {}", + state_dir.display() + ))); + } + } + if components.next().is_some() { + return Err(Status::internal(format!( + "refusing to use nested sandbox state path: {}", + state_dir.display() + ))); + } + + Ok(()) +} + +async fn remove_sandbox_state_dir(root: &Path, state_dir: &Path) -> Result<(), Status> { + validate_sandbox_state_dir(root, state_dir)?; + + let metadata = match tokio::fs::symlink_metadata(state_dir).await { + Ok(metadata) => metadata, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Status::internal(format!( + "failed to stat sandbox state dir: {err}" + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Status::internal(format!( + "refusing to remove symlinked sandbox state dir: {}", + state_dir.display() + ))); + } + if !file_type.is_dir() { + return Err(Status::internal(format!( + "sandbox state path is not a directory: {}", + state_dir.display() + ))); + } + + tokio::fs::remove_dir_all(state_dir) + .await + .map_err(|err| Status::internal(format!("failed to remove state dir: {err}"))) } fn image_cache_root_dir(root: &Path) -> PathBuf { @@ -2430,6 +2523,7 @@ mod tests { }; use prost_types::{Struct, Value, value::Kind}; use std::fs; + use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use tonic::Code; @@ -2437,6 +2531,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_gpu_when_not_enabled() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: true, ..Default::default() @@ -2452,6 +2547,7 @@ mod tests { #[test] fn validate_vm_sandbox_accepts_gpu_when_enabled() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: true, ..Default::default() @@ -2464,6 +2560,7 @@ mod tests { #[test] fn 
validate_vm_sandbox_rejects_gpu_device_without_gpu() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: false, gpu_device: "0000:2d:00.0".to_string(), @@ -2480,6 +2577,7 @@ mod tests { #[test] fn validate_vm_sandbox_rejects_platform_config() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { template: Some(SandboxTemplate { platform_config: Some(Struct { @@ -2506,6 +2604,7 @@ mod tests { #[test] fn validate_vm_sandbox_accepts_template_image() { let sandbox = Sandbox { + id: "sandbox-123".to_string(), spec: Some(SandboxSpec { template: Some(SandboxTemplate { image: "ghcr.io/example/sandbox:latest".to_string(), @@ -2518,6 +2617,51 @@ mod tests { validate_vm_sandbox(&sandbox, false).expect("template.image should be accepted"); } + #[test] + fn validate_vm_sandbox_rejects_path_unsafe_ids() { + let mut unsafe_ids = [ + "", + ".", + "..", + "../escape", + "/tmp/escape", + "nested/path", + "nested\\path", + "bad\nid", + "bad id", + "unicodé", + ] + .into_iter() + .map(str::to_string) + .collect::>(); + unsafe_ids.push("a".repeat(129)); + + for sandbox_id in unsafe_ids { + let sandbox = Sandbox { + id: sandbox_id.clone(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + image: "ghcr.io/example/sandbox:latest".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, false) + .expect_err("path-unsafe sandbox id should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument, "id={sandbox_id:?}"); + assert!(err.message().contains("sandbox id"), "id={sandbox_id:?}"); + } + } + + #[test] + fn sandbox_state_dir_rejects_path_unsafe_ids() { + let err = sandbox_state_dir(Path::new("/tmp/openshell-vm"), "../escape") + .expect_err("path traversal should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + } + #[test] fn capabilities_report_configured_default_image() { let driver = VmDriver { @@ 
-2919,9 +3063,14 @@ mod tests { #[tokio::test] async fn delete_sandbox_keeps_registry_entry_when_cleanup_fails() { + let base = unique_temp_dir(); + let driver_state = base.join("driver-state"); let (events, _) = broadcast::channel(WATCH_BUFFER); let driver = VmDriver { - config: VmDriverConfig::default(), + config: VmDriverConfig { + state_dir: driver_state.clone(), + ..Default::default() + }, launcher_bin: PathBuf::from("openshell-driver-vm"), registry: Arc::new(Mutex::new(HashMap::new())), image_cache_lock: Arc::new(Mutex::new(())), @@ -2933,9 +3082,8 @@ mod tests { ))), }; - let base = unique_temp_dir(); - std::fs::create_dir_all(&base).unwrap(); - let state_file = base.join("state-file"); + let state_file = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); + std::fs::create_dir_all(state_file.parent().unwrap()).unwrap(); std::fs::write(&state_file, "not a directory").unwrap(); insert_test_record( @@ -2950,10 +3098,11 @@ mod tests { .delete_sandbox("sandbox-123", "sandbox-123") .await .expect_err("state dir cleanup should fail for a file path"); - assert!(err.message().contains("failed to remove state dir")); + assert!(err.message().contains("not a directory")); assert!(driver.registry.lock().await.contains_key("sandbox-123")); - let retry_state_dir = base.join("state-dir"); + std::fs::remove_file(&state_file).unwrap(); + let retry_state_dir = sandbox_state_dir(&driver_state, "sandbox-123").unwrap(); std::fs::create_dir_all(&retry_state_dir).unwrap(); { let mut registry = driver.registry.lock().await; @@ -2975,6 +3124,40 @@ mod tests { let _ = std::fs::remove_dir_all(base); } + #[tokio::test] + async fn remove_sandbox_state_dir_rejects_paths_outside_state_root() { + let base = unique_temp_dir(); + let state_root = base.join("driver-state"); + let outside = base.join("outside"); + std::fs::create_dir_all(&outside).unwrap(); + + let err = remove_sandbox_state_dir(&state_root, &outside) + .await + .expect_err("outside state paths should be rejected"); + 
assert!(err.message().contains("outside vm state root")); + + let _ = std::fs::remove_dir_all(base); + } + + #[cfg(unix)] + #[tokio::test] + async fn remove_sandbox_state_dir_rejects_symlinked_state_dir() { + let base = unique_temp_dir(); + let state_root = base.join("driver-state"); + let target = base.join("target"); + let state_dir = sandbox_state_dir(&state_root, "sandbox-123").unwrap(); + std::fs::create_dir_all(&target).unwrap(); + std::fs::create_dir_all(state_dir.parent().unwrap()).unwrap(); + std::os::unix::fs::symlink(&target, &state_dir).unwrap(); + + let err = remove_sandbox_state_dir(&state_root, &state_dir) + .await + .expect_err("symlinked state dir should be rejected"); + assert!(err.message().contains("symlinked sandbox state dir")); + + let _ = std::fs::remove_dir_all(base); + } + #[test] fn validate_openshell_endpoint_accepts_loopback_hosts() { validate_openshell_endpoint("http://127.0.0.1:8080") diff --git a/crates/openshell-driver-vm/src/main.rs b/crates/openshell-driver-vm/src/main.rs index 596e6c88d..ed9967f4a 100644 --- a/crates/openshell-driver-vm/src/main.rs +++ b/crates/openshell-driver-vm/src/main.rs @@ -2,22 +2,27 @@ // SPDX-License-Identifier: Apache-2.0 use clap::Parser; +use futures::Stream; use miette::{IntoDiagnostic, Result}; use openshell_core::VERSION; use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; #[cfg(target_os = "macos")] use openshell_driver_vm::{VM_RUNTIME_DIR_ENV, configured_runtime_dir}; use openshell_driver_vm::{VmBackend, VmDriver, VmDriverConfig, VmLaunchConfig, procguard, run_vm}; +use std::io; use std::net::SocketAddr; -use std::path::PathBuf; -use tokio::net::UnixListener; -use tokio_stream::wrappers::UnixListenerStream; +use std::os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::{UnixListener, UnixStream}; use tracing::info; use tracing_subscriber::EnvFilter; 
#[derive(Parser, Debug)] #[command(name = "openshell-driver-vm")] #[command(version = VERSION)] +#[allow(clippy::struct_excessive_bools)] struct Args { #[arg(long, hide = true, default_value_t = false)] internal_run_vm: bool, @@ -46,15 +51,28 @@ struct Args { #[arg(long, hide = true, default_value_t = 1)] vm_krun_log_level: u32, + #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_BIND")] + bind_address: Option, + + #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_SOCKET")] + bind_socket: Option, + + #[arg(long, hide = true)] + expected_peer_pid: Option, + #[arg( long, - env = "OPENSHELL_COMPUTE_DRIVER_BIND", - default_value = "127.0.0.1:50061" + env = "OPENSHELL_COMPUTE_DRIVER_ALLOW_UNAUTHENTICATED_TCP", + default_value_t = false )] - bind_address: SocketAddr, + allow_unauthenticated_tcp: bool, - #[arg(long, env = "OPENSHELL_COMPUTE_DRIVER_SOCKET")] - bind_socket: Option, + #[arg( + long, + env = "OPENSHELL_COMPUTE_DRIVER_ALLOW_SAME_UID_PEER", + default_value_t = false + )] + allow_same_uid_peer: bool, #[arg(long, env = "OPENSHELL_LOG_LEVEL", default_value = "info")] log_level: String, @@ -154,6 +172,8 @@ async fn main() -> Result<()> { ) .init(); + let listen_mode = compute_driver_listen_mode(&args).map_err(|err| miette::miette!("{err}"))?; + // Arm procguard so that if the gateway is killed (SIGKILL or crash) // we also die. Without this the driver is reparented to init and // keeps its per-sandbox VM launchers alive forever. 
Launchers have @@ -170,18 +190,18 @@ async fn main() -> Result<()> { openshell_endpoint: args .openshell_endpoint .ok_or_else(|| miette::miette!("OPENSHELL_GRPC_ENDPOINT is required"))?, - state_dir: args.state_dir, + state_dir: args.state_dir.clone(), launcher_bin: None, - default_image: args.default_image, - ssh_handshake_secret: args.ssh_handshake_secret.unwrap_or_default(), + default_image: args.default_image.clone(), + ssh_handshake_secret: args.ssh_handshake_secret.clone().unwrap_or_default(), ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, - log_level: args.log_level, + log_level: args.log_level.clone(), krun_log_level: args.krun_log_level, vcpus: args.vcpus, mem_mib: args.mem_mib, - guest_tls_ca: args.guest_tls_ca, - guest_tls_cert: args.guest_tls_cert, - guest_tls_key: args.guest_tls_key, + guest_tls_ca: args.guest_tls_ca.clone(), + guest_tls_cert: args.guest_tls_cert.clone(), + guest_tls_key: args.guest_tls_key.clone(), gpu_enabled: args.gpu, gpu_mem_mib: args.gpu_mem_mib, gpu_vcpus: args.gpu_vcpus, @@ -189,32 +209,241 @@ async fn main() -> Result<()> { .await .map_err(|err| miette::miette!("{err}"))?; - if let Some(socket_path) = args.bind_socket { - if let Some(parent) = socket_path.parent() { - std::fs::create_dir_all(parent).into_diagnostic()?; + match listen_mode { + ComputeDriverListenMode::Unix { + socket_path, + expected_peer_pid, + } => { + prepare_compute_driver_socket(&socket_path).map_err(|err| miette::miette!("{err}"))?; + + info!(socket = %socket_path.display(), "Starting vm compute driver"); + let listener = UnixListener::bind(&socket_path).into_diagnostic()?; + restrict_socket_permissions(&socket_path).map_err(|err| miette::miette!("{err}"))?; + let result = tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(driver)) + .serve_with_incoming(AuthenticatedUnixIncoming::new(listener, expected_peer_pid)) + .await + .into_diagnostic(); + let _ = std::fs::remove_file(&socket_path); + result + } + 
ComputeDriverListenMode::Tcp(bind_address) => { + info!(address = %bind_address, "Starting unauthenticated dev vm compute driver"); + tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(driver)) + .serve(bind_address) + .await + .into_diagnostic() } - match std::fs::remove_file(&socket_path) { - Ok(()) => {} - Err(err) if err.kind() == std::io::ErrorKind::NotFound => {} - Err(err) => return Err(err).into_diagnostic(), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ComputeDriverListenMode { + Unix { + socket_path: PathBuf, + expected_peer_pid: Option, + }, + Tcp(SocketAddr), +} + +fn compute_driver_listen_mode(args: &Args) -> std::result::Result { + if let Some(socket_path) = args.bind_socket.clone() { + if args.expected_peer_pid.is_none() && !args.allow_same_uid_peer { + return Err( + "--expected-peer-pid is required with --bind-socket; use --allow-same-uid-peer only for local development" + .to_string(), + ); + } + return Ok(ComputeDriverListenMode::Unix { + socket_path, + expected_peer_pid: args.expected_peer_pid, + }); + } + + if !args.allow_unauthenticated_tcp { + return Err( + "--bind-socket is required; unauthenticated TCP mode is disabled unless --allow-unauthenticated-tcp is set for local development" + .to_string(), + ); + } + + let Some(bind_address) = args.bind_address else { + return Err("--bind-address is required with --allow-unauthenticated-tcp".to_string()); + }; + + Ok(ComputeDriverListenMode::Tcp(bind_address)) +} + +fn prepare_compute_driver_socket(socket_path: &Path) -> std::result::Result<(), String> { + let Some(parent) = socket_path.parent() else { + return Err(format!( + "vm compute driver socket path '{}' has no parent directory", + socket_path.display() + )); + }; + let expected_uid = current_euid(); + prepare_private_socket_dir(parent, expected_uid)?; + remove_stale_socket(socket_path, expected_uid) +} + +fn current_euid() -> u32 { + nix::unistd::Uid::effective().as_raw() +} + +fn 
prepare_private_socket_dir( + socket_dir: &Path, + expected_uid: u32, +) -> std::result::Result<(), String> { + std::fs::create_dir_all(socket_dir) + .map_err(|err| format!("create socket dir {}: {err}", socket_dir.display()))?; + let metadata = std::fs::symlink_metadata(socket_dir) + .map_err(|err| format!("stat socket dir {}: {err}", socket_dir.display()))?; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(format!( + "socket dir {} is a symlink; refusing to use it", + socket_dir.display() + )); + } + if !file_type.is_dir() { + return Err(format!( + "socket dir {} is not a directory", + socket_dir.display() + )); + } + if metadata.uid() != expected_uid { + return Err(format!( + "socket dir {} is owned by uid {} but current euid is {}", + socket_dir.display(), + metadata.uid(), + expected_uid + )); + } + std::fs::set_permissions(socket_dir, std::fs::Permissions::from_mode(0o700)) + .map_err(|err| format!("chmod socket dir {}: {err}", socket_dir.display())) +} + +fn remove_stale_socket(socket_path: &Path, expected_uid: u32) -> std::result::Result<(), String> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(()), + Err(err) => return Err(format!("stat socket {}: {err}", socket_path.display())), + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(format!( + "socket {} is a symlink; refusing to remove it", + socket_path.display() + )); + } + if metadata.uid() != expected_uid { + return Err(format!( + "socket {} is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + )); + } + if !file_type.is_socket() { + return Err(format!( + "socket path {} exists but is not a Unix socket", + socket_path.display() + )); + } + std::fs::remove_file(socket_path) + .map_err(|err| format!("remove stale socket {}: {err}", socket_path.display())) +} + +fn 
restrict_socket_permissions(socket_path: &Path) -> std::result::Result<(), String> { + std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600)) + .map_err(|err| format!("chmod socket {}: {err}", socket_path.display())) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct PeerCredentials { + uid: u32, + pid: Option, +} + +fn peer_credentials(stream: &UnixStream) -> std::result::Result { + let credentials = stream + .peer_cred() + .map_err(|err| format!("read peer credentials: {err}"))?; + Ok(PeerCredentials { + uid: credentials.uid(), + pid: credentials.pid(), + }) +} + +fn authorize_peer_credentials( + peer: PeerCredentials, + driver_uid: u32, + gateway_pid: Option, +) -> std::result::Result<(), String> { + if peer.uid != driver_uid { + return Err(format!( + "peer uid {} does not match current euid {}", + peer.uid, driver_uid + )); + } + let Some(gateway_pid) = gateway_pid else { + return Ok(()); + }; + let Some(peer_process_id) = peer.pid.and_then(|pid| u32::try_from(pid).ok()) else { + return Err(format!( + "peer pid is unavailable; expected gateway pid {gateway_pid}" + )); + }; + if peer_process_id != gateway_pid { + return Err(format!( + "peer pid {peer_process_id} does not match expected gateway pid {gateway_pid}" + )); + } + Ok(()) +} + +struct AuthenticatedUnixIncoming { + listener: UnixListener, + expected_uid: u32, + expected_peer_pid: Option, +} + +impl AuthenticatedUnixIncoming { + fn new(listener: UnixListener, expected_peer_pid: Option) -> Self { + Self { + listener, + expected_uid: current_euid(), + expected_peer_pid, } + } +} - info!(socket = %socket_path.display(), "Starting vm compute driver"); - let listener = UnixListener::bind(&socket_path).into_diagnostic()?; - let result = tonic::transport::Server::builder() - .add_service(ComputeDriverServer::new(driver)) - .serve_with_incoming(UnixListenerStream::new(listener)) - .await - .into_diagnostic(); - let _ = std::fs::remove_file(&socket_path); - result - } else { - 
info!(address = %args.bind_address, "Starting vm compute driver"); - tonic::transport::Server::builder() - .add_service(ComputeDriverServer::new(driver)) - .serve(args.bind_address) - .await - .into_diagnostic() +impl Stream for AuthenticatedUnixIncoming { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + loop { + match this.listener.poll_accept(cx) { + Poll::Ready(Ok((stream, _addr))) => { + let authorized = peer_credentials(&stream).and_then(|peer| { + authorize_peer_credentials(peer, this.expected_uid, this.expected_peer_pid) + }); + match authorized { + Ok(()) => return Poll::Ready(Some(Ok(stream))), + Err(err) => { + tracing::warn!( + error = %err, + "rejected vm compute driver UDS client" + ); + } + } + } + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), + Poll::Pending => return Poll::Pending, + } + } } } @@ -310,3 +539,165 @@ fn maybe_reexec_internal_vm_with_runtime_env() -> Result<()> { fn maybe_reexec_internal_vm_with_runtime_env() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::{ + Args, ComputeDriverListenMode, PeerCredentials, authorize_peer_credentials, + compute_driver_listen_mode, + }; + use clap::Parser; + use std::path::PathBuf; + + #[test] + fn peer_authorization_accepts_matching_uid_and_pid() { + authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: Some(42), + }, + 1000, + Some(42), + ) + .unwrap(); + } + + #[test] + fn peer_authorization_rejects_wrong_pid() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: Some(7), + }, + 1000, + Some(42), + ) + .expect_err("wrong pid should be rejected"); + assert!(err.contains("does not match expected gateway pid")); + } + + #[test] + fn peer_authorization_rejects_wrong_uid() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1001, + pid: Some(42), + }, + 1000, + Some(42), + ) + .expect_err("wrong uid should be rejected"); + 
assert!(err.contains("does not match current euid")); + } + + #[test] + fn peer_authorization_rejects_missing_pid_when_expected() { + let err = authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: None, + }, + 1000, + Some(42), + ) + .expect_err("missing pid should be rejected"); + assert!(err.contains("peer pid is unavailable")); + } + + #[test] + fn peer_authorization_accepts_matching_uid_without_expected_pid() { + authorize_peer_credentials( + PeerCredentials { + uid: 1000, + pid: None, + }, + 1000, + None, + ) + .unwrap(); + } + + #[test] + fn listen_mode_rejects_default_tcp() { + let args = Args::parse_from(["openshell-driver-vm"]); + let err = compute_driver_listen_mode(&args).expect_err("default TCP should be disabled"); + assert!(err.contains("--bind-socket is required")); + } + + #[test] + fn listen_mode_rejects_bind_address_without_tcp_opt_in() { + let args = Args::parse_from(["openshell-driver-vm", "--bind-address", "127.0.0.1:50061"]); + let err = + compute_driver_listen_mode(&args).expect_err("TCP bind should require explicit opt-in"); + assert!(err.contains("--allow-unauthenticated-tcp")); + } + + #[test] + fn listen_mode_requires_bind_address_with_tcp_opt_in() { + let args = Args::parse_from(["openshell-driver-vm", "--allow-unauthenticated-tcp"]); + let err = + compute_driver_listen_mode(&args).expect_err("TCP opt-in should require an address"); + assert!(err.contains("--bind-address is required")); + } + + #[test] + fn listen_mode_accepts_explicit_unauthenticated_tcp() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--allow-unauthenticated-tcp", + "--bind-address", + "127.0.0.1:50061", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Tcp("127.0.0.1:50061".parse().unwrap()) + ); + } + + #[test] + fn listen_mode_requires_expected_peer_pid_for_uds() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + ]); + let err = 
compute_driver_listen_mode(&args) + .expect_err("UDS should require gateway peer pid by default"); + assert!(err.contains("--expected-peer-pid is required")); + } + + #[test] + fn listen_mode_accepts_uds_with_expected_peer_pid() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + "--expected-peer-pid", + "42", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Unix { + socket_path: PathBuf::from("/tmp/compute-driver.sock"), + expected_peer_pid: Some(42), + } + ); + } + + #[test] + fn listen_mode_accepts_explicit_same_uid_uds_dev_mode() { + let args = Args::parse_from([ + "openshell-driver-vm", + "--bind-socket", + "/tmp/compute-driver.sock", + "--allow-same-uid-peer", + ]); + assert_eq!( + compute_driver_listen_mode(&args).unwrap(), + ComputeDriverListenMode::Unix { + socket_path: PathBuf::from("/tmp/compute-driver.sock"), + expected_peer_pid: None, + } + ); + } +} diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index cb6561f3e..fab20186c 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -82,6 +82,7 @@ rand = { workspace = true } petname = "2" ipnet = "2" tempfile = "3" +nix = { workspace = true } [features] dev-settings = ["openshell-core/dev-settings"] diff --git a/crates/openshell-server/src/compute/vm.rs b/crates/openshell-server/src/compute/vm.rs index e5b974f74..1e62d4942 100644 --- a/crates/openshell-server/src/compute/vm.rs +++ b/crates/openshell-server/src/compute/vm.rs @@ -38,6 +38,10 @@ use openshell_core::proto::compute::v1::{ GetCapabilitiesRequest, compute_driver_client::ComputeDriverClient, }; use openshell_core::{Config, Error, Result}; +#[cfg(unix)] +use std::os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}; +#[cfg(unix)] +use std::path::Path; use std::path::PathBuf; #[cfg(unix)] use std::{io::ErrorKind, process::Stdio, sync::Arc, time::Duration}; @@ -52,6 +56,8 @@ 
use tonic::transport::Endpoint; use tower::service_fn; const DRIVER_BIN_NAME: &str = "openshell-driver-vm"; +const COMPUTE_DRIVER_SOCKET_RUN_DIR: &str = "run"; +const COMPUTE_DRIVER_SOCKET_NAME: &str = "compute-driver.sock"; /// Configuration for launching and talking to the VM compute driver. #[derive(Debug, Clone)] @@ -210,7 +216,145 @@ fn push_unique_path(paths: &mut Vec, path: PathBuf) { /// Path of the Unix domain socket the driver will listen on. pub fn compute_driver_socket_path(vm_config: &VmComputeConfig) -> PathBuf { - vm_config.state_dir.join("compute-driver.sock") + vm_config + .state_dir + .join(COMPUTE_DRIVER_SOCKET_RUN_DIR) + .join(COMPUTE_DRIVER_SOCKET_NAME) +} + +#[cfg(unix)] +fn prepare_compute_driver_socket_path( + vm_config: &VmComputeConfig, + socket_path: &Path, +) -> Result<()> { + let expected_uid = current_euid(); + prepare_vm_state_dir(&vm_config.state_dir, expected_uid)?; + let parent = socket_path.parent().ok_or_else(|| { + Error::execution(format!( + "vm compute driver socket path '{}' has no parent directory", + socket_path.display() + )) + })?; + prepare_private_socket_dir(parent, expected_uid)?; + remove_stale_socket(socket_path, expected_uid) +} + +#[cfg(unix)] +fn current_euid() -> u32 { + nix::unistd::Uid::effective().as_raw() +} + +#[cfg(unix)] +fn prepare_vm_state_dir(state_dir: &Path, expected_uid: u32) -> Result<()> { + std::fs::create_dir_all(state_dir).map_err(|err| { + Error::execution(format!( + "failed to create vm driver state dir '{}': {err}", + state_dir.display() + )) + })?; + let metadata = checked_directory_metadata(state_dir, expected_uid, "vm driver state dir")?; + let mode = metadata.permissions().mode() & 0o777; + if mode & 0o022 != 0 { + return Err(Error::execution(format!( + "vm driver state dir '{}' must not be group/world-writable (mode {mode:03o})", + state_dir.display() + ))); + } + Ok(()) +} + +#[cfg(unix)] +fn prepare_private_socket_dir(socket_dir: &Path, expected_uid: u32) -> Result<()> { + 
std::fs::create_dir_all(socket_dir).map_err(|err| { + Error::execution(format!( + "failed to create vm compute driver socket dir '{}': {err}", + socket_dir.display() + )) + })?; + let _ = checked_directory_metadata(socket_dir, expected_uid, "vm compute driver socket dir")?; + std::fs::set_permissions(socket_dir, std::fs::Permissions::from_mode(0o700)).map_err(|err| { + Error::execution(format!( + "failed to restrict vm compute driver socket dir '{}': {err}", + socket_dir.display() + )) + }) +} + +#[cfg(unix)] +fn checked_directory_metadata( + path: &Path, + expected_uid: u32, + label: &str, +) -> Result { + let metadata = std::fs::symlink_metadata(path).map_err(|err| { + Error::execution(format!( + "failed to stat {label} '{}': {err}", + path.display() + )) + })?; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "{label} '{}' is a symlink; refusing to use it", + path.display() + ))); + } + if !file_type.is_dir() { + return Err(Error::execution(format!( + "{label} '{}' is not a directory", + path.display() + ))); + } + if metadata.uid() != expected_uid { + return Err(Error::execution(format!( + "{label} '{}' is owned by uid {} but current euid is {}", + path.display(), + metadata.uid(), + expected_uid + ))); + } + Ok(metadata) +} + +#[cfg(unix)] +fn remove_stale_socket(socket_path: &Path, expected_uid: u32) -> Result<()> { + let metadata = match std::fs::symlink_metadata(socket_path) { + Ok(metadata) => metadata, + Err(err) if err.kind() == ErrorKind::NotFound => return Ok(()), + Err(err) => { + return Err(Error::execution(format!( + "failed to stat vm compute driver socket '{}': {err}", + socket_path.display() + ))); + } + }; + let file_type = metadata.file_type(); + if file_type.is_symlink() { + return Err(Error::execution(format!( + "vm compute driver socket '{}' is a symlink; refusing to remove it", + socket_path.display() + ))); + } + if metadata.uid() != expected_uid { + return 
Err(Error::execution(format!( + "vm compute driver socket '{}' is owned by uid {} but current euid is {}", + socket_path.display(), + metadata.uid(), + expected_uid + ))); + } + if !file_type.is_socket() { + return Err(Error::execution(format!( + "vm compute driver socket path '{}' exists but is not a Unix socket", + socket_path.display() + ))); + } + std::fs::remove_file(socket_path).map_err(|err| { + Error::execution(format!( + "failed to remove stale vm compute driver socket '{}': {err}", + socket_path.display() + )) + }) } #[cfg(unix)] @@ -278,24 +422,7 @@ pub async fn spawn( let driver_bin = resolve_compute_driver_bin(vm_config)?; let socket_path = compute_driver_socket_path(vm_config); let guest_tls_paths = compute_driver_guest_tls_paths(config, vm_config)?; - if let Some(parent) = socket_path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - Error::execution(format!( - "failed to create vm compute driver socket dir '{}': {e}", - parent.display() - )) - })?; - } - match std::fs::remove_file(&socket_path) { - Ok(()) => {} - Err(err) if err.kind() == ErrorKind::NotFound => {} - Err(err) => { - return Err(Error::execution(format!( - "failed to remove stale vm compute driver socket '{}': {err}", - socket_path.display() - ))); - } - } + prepare_compute_driver_socket_path(vm_config, &socket_path)?; let mut command = Command::new(&driver_bin); command.kill_on_drop(true); @@ -303,6 +430,9 @@ pub async fn spawn( command.stdout(Stdio::inherit()); command.stderr(Stdio::inherit()); command.arg("--bind-socket").arg(&socket_path); + command + .arg("--expected-peer-pid") + .arg(std::process::id().to_string()); command.arg("--log-level").arg(&config.log_level); command .arg("--openshell-endpoint") @@ -356,7 +486,7 @@ pub async fn spawn( #[cfg(unix)] async fn wait_for_compute_driver( - socket_path: &std::path::Path, + socket_path: &Path, child: &mut tokio::process::Child, ) -> Result { let mut last_error: Option = None; @@ -395,7 +525,7 @@ async fn 
wait_for_compute_driver( } #[cfg(unix)] -async fn connect_compute_driver(socket_path: &std::path::Path) -> Result { +async fn connect_compute_driver(socket_path: &Path) -> Result { let socket_path = socket_path.to_path_buf(); let display_path = socket_path.clone(); Endpoint::from_static("http://[::]:50051") @@ -415,11 +545,13 @@ async fn connect_compute_driver(socket_path: &std::path::Path) -> Result` | `OPENSHELL_DRIVER_DIR` | Search a custom directory for `openshell-driver-vm`. | -| `--vm-driver-state-dir ` | `OPENSHELL_VM_DRIVER_STATE_DIR` | Store VM rootfs, console logs, runtime state, and image-rootfs cache under this directory. | +| `--vm-driver-state-dir ` | `OPENSHELL_VM_DRIVER_STATE_DIR` | Store VM rootfs, console logs, runtime state, image-rootfs cache, and the private `run/compute-driver.sock` socket under this directory. | | `--vm-driver-vcpus ` | `OPENSHELL_VM_DRIVER_VCPUS` | Set the default vCPU count for VM sandboxes. | | `--vm-driver-mem-mib ` | `OPENSHELL_VM_DRIVER_MEM_MIB` | Set the default memory allocation for VM sandboxes in MiB. | | `--vm-krun-log-level ` | `OPENSHELL_VM_KRUN_LOG_LEVEL` | Set the libkrun log level for VM helper processes. | | `--vm-tls-ca`, `--vm-tls-cert`, `--vm-tls-key` | `OPENSHELL_VM_TLS_CA`, `OPENSHELL_VM_TLS_CERT`, `OPENSHELL_VM_TLS_KEY` | Copy sandbox client TLS materials into VM guests for mTLS callback to the gateway. | +The gateway starts `openshell-driver-vm` over a private Unix socket and passes its process ID so the driver can reject unexpected local clients. The driver's standalone TCP listener is disabled unless `--allow-unauthenticated-tcp` is set for local development. + ## Kubernetes Driver Kubernetes-backed sandboxes run as pods in the configured sandbox namespace. Use Kubernetes for shared clusters, remote compute, GPU scheduling, and operator-managed environments.