diff --git a/Cargo.lock b/Cargo.lock index 205e55b..df7488f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -558,6 +558,39 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "attestation" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" +dependencies = [ + "anyhow", + "az-tdx-vtpm", + "base64 0.22.1", + "configfs-tsm", + "dcap-qvl 0.3.12 (git+https://github.com/Phala-Network/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1)", + "hex", + "http", + "num-bigint", + "once_cell", + "openssl", + "parity-scale-codec", + "pccs", + "pem-rfc7468", + "rand_core 0.6.4", + "reqwest", + "rustls-webpki", + "serde", + "serde_json", + "tdx-quote", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-rustls", + "tracing", + "tss-esapi", + "x509-parser 0.18.1", +] + [[package]] name = "attestation" version = "0.0.1" @@ -612,7 +645,7 @@ version = "0.0.1" dependencies = [ "alloy-rpc-client", "alloy-transport-http", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "bytes", "futures-util", "http", @@ -638,10 +671,10 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#a96ec2d9096f491e652624c53d3df2b1526ef9f2" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" dependencies = [ "anyhow", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)", "ra-tls", "rcgen 0.14.7", "rustls", @@ -659,8 +692,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attestation", - "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)", "axum", "bytes", "clap", @@ -671,6 +704,7 @@ dependencies = [ "jsonrpsee", "nested-tls", "p256", + "pccs", "pem-rfc7468", "pin-project-lite", "pkcs1", @@ -710,6 +744,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec6fb3fe69024a75fa7e1bfb48aa6cf59706a101658ea01bfd33b2b248a038f" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f50037ee5e1e41e7b8f9d161680a725bd1626cb6f8c7e901f91f942850852fe7" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.8.6" @@ -1066,6 +1122,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37521ac7aabe3d13122dc382493e20c9416f299d2ccd5b3a5340a2570cdeb0f3" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -1148,6 +1206,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "codicon" version = "3.0.0" @@ -1472,6 +1539,42 @@ dependencies = [ "x509-cert", ] +[[package]] +name = "dcap-qvl" +version = "0.3.12" +source = "git+https://github.com/Phala-Network/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1#f1dcc65371e941a7b83e3234833d23a1fb232ab1" +dependencies = [ + "anyhow", + "asn1_der", + "base64 0.22.1", + "borsh", + "byteorder", + "chrono", + "const-oid", + "dcap-qvl-webpki", + "der", + "derive_more 2.1.1", + "futures", + "hex", + "log", + "p256", + "parity-scale-codec", + "pem", + "reqwest", + "ring", + "rustls-pki-types", + "scale-info", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "signature", + "tracing", + "urlencoding", + "wasm-bindgen-futures", + "x509-cert", +] + [[package]] name = "dcap-qvl-webpki" version = "0.103.4+dcap.1" @@ -2009,6 +2112,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -2635,6 +2744,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.85" @@ -2976,7 +3095,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#a96ec2d9096f491e652624c53d3df2b1526ef9f2" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" dependencies = [ "rustls", "tokio", @@ -3333,6 +3452,24 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pccs" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" +dependencies = [ + "anyhow", + "dcap-qvl 0.3.12 (git+https://github.com/Phala-Network/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1)", + "hex", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.17", + "time", + "tokio", + "tracing", + "x509-parser 0.18.1", +] + [[package]] name = "pem" version = "3.0.6" @@ -4101,6 +4238,7 @@ version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "aws-lc-rs", "brotli", "brotli-decompressor", "once_cell", @@ -4136,6 +4274,7 @@ version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", diff --git a/Cargo.toml b/Cargo.toml index 9d6d341..e8cdf9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,11 +11,12 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main" } +pccs = { git = "https://github.com/flashbots/attested-tls", branch = "main" } tokio = { version = "1.50.0", features = ["full"] } -tokio-rustls = { version = "0.26.4", default-features = false } +tokio-rustls = { version = "0.26.4", default-features = false, features = ["aws_lc_rs"] } x509-parser = { version = "0.18.0", features = ["verify"] } thiserror = "2.0.17" clap = { version = "4.5.51", features = ["derive", "env"] } @@ -47,7 +48,7 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main", features = ["mock"] } tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } diff --git a/attestation-provider-server/src/main.rs b/attestation-provider-server/src/main.rs index 4a91487..661ca2f 100644 --- a/attestation-provider-server/src/main.rs +++ b/attestation-provider-server/src/main.rs @@ -99,6 +99,7 @@ async fn main() -> anyhow::Result<()> { pccs_url: None, dump_dcap_quotes: cli.log_dcap_quote, override_azure_outdated_tcb: false, + internal_pccs: None, }; let attestation_message = diff --git a/attested-tls/Cargo.toml b/attested-tls/Cargo.toml index d40e7b6..60f74c0 100644 --- a/attested-tls/Cargo.toml +++ b/attested-tls/Cargo.toml @@ -9,7 +9,7 @@ keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] tokio = { version = "1.48.0", features = ["full"] } -tokio-rustls = { version = "0.26.4", default-features = false } +tokio-rustls = { version = "0.26.4", default-features = false, features = ["aws_lc_rs"] } sha2 = "0.10.9" x509-parser = "0.18.0" thiserror = "2.0.17" diff --git a/attested-tls/src/test_helpers.rs b/attested-tls/src/test_helpers.rs index 9b218a3..2ba933d 100644 --- a/attested-tls/src/test_helpers.rs +++ b/attested-tls/src/test_helpers.rs @@ -10,6 +10,10 @@ use crate::SUPPORTED_ALPN_PROTOCOL_VERSIONS; pub use attestation::measurements::mock_dcap_measurements; +fn install_crypto_provider() { + let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default(); +} + /// Helper to generate a self-signed certificate for testing pub fn generate_certificate_chain( ip: IpAddr, @@ -36,6 +40,8 @@ pub fn generate_tls_config( certificate_chain: Vec>, key: PrivateKeyDer<'static>, ) -> (ServerConfig, ClientConfig) { + install_crypto_provider(); + let supported_protocols: Vec<_> = SUPPORTED_ALPN_PROTOCOL_VERSIONS .into_iter() .map(|p| p.to_vec()) @@ -67,6 +73,8 @@ pub fn generate_tls_config_with_client_auth( bob_certificate_chain: Vec>, bob_key: PrivateKeyDer<'static>, ) -> ((ServerConfig, ClientConfig), (ServerConfig, ClientConfig)) { + install_crypto_provider(); + let supported_protocols: Vec<_> = SUPPORTED_ALPN_PROTOCOL_VERSIONS .into_iter() .map(|p| p.to_vec()) diff --git a/src/lib.rs b/src/lib.rs index 0f1d15a..dd9c17f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ mod http_version; #[cfg(test)] mod test_helpers; -use attestation::{AttestationError, AttestationVerifier}; +use attestation::{AttestationError, AttestationExchangeMessage, AttestationVerifier}; use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier, AttestedTlsError}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -36,6 +36,12 @@ use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; +/// The header name for giving attestation type +const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; + +/// The header name for giving measurements +const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; + /// The header name for giving the forwarded for IP static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); @@ -399,8 +405,31 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted connection"); + // Get attestation from the remote certificate from the inner session, if present. + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => remote_cert_chain + .first() + .and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) + { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!( + "Failed to extract remote attestation from inner-session certificate: {err}" + ); + None + } + } + }), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn handle_inner_connection( @@ -410,8 +439,26 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted inner-only connection"); + // Get attestation from the remote certificate, if present + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => remote_cert_chain.first().and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!("Failed to extract remote attestation from certificate: {err}"); + None + } + } + }), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn serve_tls_stream( @@ -419,10 +466,25 @@ impl ProxyServer { http_version: HttpVersion, target: String, client_addr: SocketAddr, + attestation: Option, ) -> Result<(), ProxyError> where IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { + let (remote_attestation_type, measurements) = match attestation { + Some(attestation) => ( + Some(attestation.attestation_type), + match attestation.get_measurements() { + Ok(measurements) => measurements, + Err(err) => { + warn!("Failed to extract measurements from peer attestation: {err}"); + None + } + }, + ), + None => (None, None), + }; + // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -447,6 +509,34 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + // Strip any caller-provided attestation metadata before injecting authenticated values. + headers.remove(ATTESTATION_TYPE_HEADER); + headers.remove(MEASUREMENT_HEADER); + + // If we have measurements, from the remote peer, add them to the request header + let measurements = measurements.clone(); + + if let Some(measurements) = measurements { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + if let Some(remote_attestation_type) = remote_attestation_type { + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + } + let target = target.clone(); async move { match Self::handle_http_request(req, target).await { @@ -648,7 +738,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn) = + let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -678,6 +768,9 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); + let remote_attestation_type = attestation.attestation_type; + let measurements = attestation.get_measurements().ok().flatten(); + tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -690,8 +783,28 @@ impl ProxyClient { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server let (response, should_reconnect) = match sender.send_request(req).await { - Ok(resp) => { + Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + let headers = resp.headers_mut(); + headers.remove(MEASUREMENT_HEADER); + + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + (Ok(resp.map(|b| b.boxed())), false) } Err(e) => { @@ -799,7 +912,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -828,15 +941,28 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; let tls_stream = nesting_tls_connector .connect(domain, outbound_stream) .await?; + debug!("[proxy-client] Connected to proxy server"); + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; + + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )? + }; + // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -860,8 +986,7 @@ impl ProxyClient { } }; - // Return the HTTP client, as well as remote measurements - Ok((sender, conn)) + Ok((sender, conn, attestation)) } // Handle a request from the source client to the proxy server @@ -984,10 +1109,14 @@ async fn build_attested_cert_resolver( attestation_generator: AttestationGenerator, certificate_name: String, ) -> Result { - Ok( - AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) - .await?, + Ok(AttestedCertificateResolver::new( + attestation_generator, + None, + certificate_name, + vec![], + Duration::from_secs(30 * 60), ) + .await?) } async fn build_inner_server_config( @@ -1056,6 +1185,11 @@ where #[cfg(test)] mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; + use std::collections::HashMap; + use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }; use tokio_rustls::TlsConnector; use super::*; @@ -1064,6 +1198,93 @@ mod tests { generate_tls_config_with_client_auth, init_tracing, }; + fn expected_mock_measurements() -> HashMap { + let zero_measurement = "0".repeat(96); + HashMap::from([ + ("0".to_string(), zero_measurement.clone()), + ("1".to_string(), zero_measurement.clone()), + ("2".to_string(), zero_measurement.clone()), + ("3".to_string(), zero_measurement.clone()), + ("4".to_string(), zero_measurement), + ]) + } + + fn assert_mock_measurements(body: &str) { + let parsed: HashMap = serde_json::from_str(body).unwrap(); + assert_eq!(parsed, expected_mock_measurements()); + } + + fn assert_mock_measurements_header(headers: &http::HeaderMap) { + let body = headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap(); + assert_mock_measurements(body); + } + + fn assert_attestation_type_header(headers: &http::HeaderMap, expected: &str) { + assert_eq!( + headers + .get(ATTESTATION_TYPE_HEADER) + .and_then(|v| v.to_str().ok()), + Some(expected) + ); + } + + fn assert_no_measurements_header(headers: &http::HeaderMap) { + assert!(headers.get(MEASUREMENT_HEADER).is_none()); + } + + /// Test service that echoes attestation-related request headers as JSON. + async fn request_header_echo_service() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get(|headers: http::HeaderMap| async move { + axum::Json(serde_json::json!({ + "measurement": headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()), + "attestation_type": headers + .get(ATTESTATION_TYPE_HEADER) + .and_then(|v| v.to_str().ok()), + })) + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + addr + } + + /// Test service that deliberately returns a spoofed measurement header. + async fn spoofed_response_measurement_service() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get(|| async move { + let mut response = http::Response::new("ok".to_string()); + response.headers_mut().insert( + MEASUREMENT_HEADER, + HeaderValue::from_static("{\"spoofed\":\"value\"}"), + ); + response + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + addr + } + #[test] fn proxy_alpn_protocols_prefer_http2() { let mut protocols = Vec::new(); @@ -1142,7 +1363,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn inner_only_listener_negotiates_http2_by_default() { - let _ = rustls::crypto::ring::default_provider().install_default(); + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let target_addr = example_http_service().await; let proxy_server = ProxyServer::new( @@ -1230,7 +1451,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn) = ProxyClient::setup_connection( + let (sender, conn, _attestation) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) @@ -1294,6 +1515,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -1362,8 +1586,11 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert_mock_measurements(&res_body); } // Server has no attestation, client has mock DCAP but no client auth @@ -1423,7 +1650,74 @@ mod tests { .await .unwrap(); - let _res_body = res.text().await.unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_strips_spoofed_request_attestation_headers() { + let target_addr = request_header_echo_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = + generate_tls_config(server_cert_chain.clone(), server_private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::Client::new() + .get(format!("http://{}", proxy_client_addr)) + .header(MEASUREMENT_HEADER, "{\"spoofed\":\"request\"}") + .header(ATTESTATION_TYPE_HEADER, "dcap-tdx") + .send() + .await + .unwrap(); + + let echoed: serde_json::Value = + serde_json::from_slice(&res.bytes().await.unwrap()).unwrap(); + assert!(echoed["measurement"].is_null()); + assert!(echoed["attestation_type"].is_null()); } // Server has mock DCAP, client has mock DCAP and client auth @@ -1490,12 +1784,16 @@ mod tests { let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); } // Server has mock DCAP, client no attestation - just get the server certificate @@ -1646,6 +1944,7 @@ mod tests { pccs_url: None, dump_dcap_quotes: false, override_azure_outdated_tcb: false, + internal_pccs: None, }; let proxy_client_result = ProxyClient::new_with_tls_config( @@ -1692,6 +1991,7 @@ mod tests { // This is used to trigger a dropped connection to the proxy server let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); tokio::spawn(async move { let connection_handle = proxy_server.accept().await.unwrap(); @@ -1703,6 +2003,7 @@ mod tests { // Now accept another connection proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); }); let proxy_client = ProxyClient::new_with_tls_config( @@ -1723,22 +2024,208 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) + let initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_attestation_type_header(initial_response.headers(), "dcap-tdx"); + assert_mock_measurements_header(initial_response.headers()); // Now break the connection connection_breaker_tx.send(()).unwrap(); + reconnected_rx.await.unwrap(); // Make another request let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_does_not_retry_failed_request() { + init_tracing(); + + let request_count = Arc::new(AtomicUsize::new(0)); + let request_seen = Arc::new(tokio::sync::Notify::new()); + let (release_tx, release_rx) = tokio::sync::watch::channel(false); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target_addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get({ + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let release_rx = release_rx.clone(); + + move || { + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let mut release_rx = release_rx.clone(); + + async move { + request_count.fetch_add(1, Ordering::SeqCst); + request_seen.notify_waiters(); + + if !*release_rx.borrow() { + release_rx.changed().await.unwrap(); + } + + "ok" + } + } + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); + + tokio::spawn(async move { + let connection_handle = proxy_server.accept().await.unwrap(); + connection_breaker_rx.await.unwrap(); + connection_handle.abort(); + proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let request_url = format!("http://{}", proxy_client_addr); + let failed_request = tokio::spawn(async move { reqwest::get(request_url).await.unwrap() }); + + loop { + if request_count.load(Ordering::SeqCst) > 0 { + break; + } + + request_seen.notified().await; + } + + connection_breaker_tx.send(()).unwrap(); + release_tx.send(true).unwrap(); + + let failed_response = failed_request.await.unwrap(); + assert_eq!(failed_response.status(), hyper::StatusCode::BAD_GATEWAY); + assert_eq!(request_count.load(Ordering::SeqCst), 1); + + reconnected_rx.await.unwrap(); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "ok"); + assert_eq!(request_count.load(Ordering::SeqCst), 2); + } + + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_strips_spoofed_response_measurement_header() { + let target_addr = spoofed_response_measurement_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = + generate_tls_config(server_cert_chain.clone(), server_private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "ok"); + } + // Use HTTP 1.1 #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_http1() { @@ -1794,6 +2281,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } diff --git a/src/main.rs b/src/main.rs index dcb23d5..f4be61c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, ensure}; use attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}; use clap::{Parser, Subcommand}; +use pccs::Pccs; use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio::io::AsyncWriteExt; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; @@ -28,7 +29,7 @@ struct Cli { /// If no measurements file is specified, a single attestion type to allow #[arg(long, global = true)] allowed_remote_attestation_type: Option, - /// The URL of a PCCS to use when verifying DCAP attestations. Defaults to Intel PCS. + /// The URL of a PCCS to use when verifying DCAP attestations. Defaults to an internal PCCS. #[arg(long, global = true)] pccs_url: Option, /// Log debug messages @@ -159,6 +160,8 @@ enum CliCommand { #[tokio::main] async fn main() -> anyhow::Result<()> { + let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default(); + let cli = Cli::parse(); ensure!( @@ -216,6 +219,7 @@ async fn main() -> anyhow::Result<()> { pccs_url: cli.pccs_url, dump_dcap_quotes: cli.log_dcap_quote, override_azure_outdated_tcb: cli.override_azure_outdated_tcb, + internal_pccs: Some(Pccs::new(None)), }; match cli.command { diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 431c5f8..9448f7f 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -12,12 +12,24 @@ use tokio_rustls::rustls::{ }; use tracing_subscriber::{EnvFilter, fmt}; +use crate::MEASUREMENT_HEADER; + static INIT: Once = Once::new(); +pub fn install_crypto_provider() { + static CRYPTO_PROVIDER_INIT: Once = Once::new(); + + CRYPTO_PROVIDER_INIT.call_once(|| { + let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default(); + }); +} + /// Helper to generate a self-signed certificate for testing with a DNS subject name pub fn generate_certificate_chain_for_host( host: &str, ) -> (Vec>, PrivateKeyDer<'static>) { + install_crypto_provider(); + let mut params = rcgen::CertificateParams::new(vec![host.to_string()]).unwrap(); params .subject_alt_names @@ -42,6 +54,8 @@ pub fn generate_tls_config( certificate_chain: Vec>, key: PrivateKeyDer<'static>, ) -> (ServerConfig, ClientConfig) { + install_crypto_provider(); + let server_config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(certificate_chain.clone(), key) @@ -64,6 +78,8 @@ pub fn generate_tls_config_with_client_auth( bob_certificate_chain: Vec>, bob_key: PrivateKeyDer<'static>, ) -> ((ServerConfig, ClientConfig), (ServerConfig, ClientConfig)) { + install_crypto_provider(); + let (alice_client_verifier, alice_root_store) = client_verifier_from_remote_cert(bob_certificate_chain[0].clone()); @@ -115,6 +131,8 @@ fn client_verifier_from_remote_cert( /// Simple http server used in tests which returns in the response the measurement header from the /// request pub async fn example_http_service() -> SocketAddr { + install_crypto_provider(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -127,16 +145,17 @@ pub async fn example_http_service() -> SocketAddr { addr } -async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { - // headers - // .get(MEASUREMENT_HEADER) - // .and_then(|v| v.to_str().ok()) - // .unwrap_or("No measurements") - // .to_string() - "No measurements".to_string() +async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { + headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap_or("No measurements") + .to_string() } pub fn init_tracing() { + install_crypto_provider(); + INIT.call_once(|| { let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));