From 86eabbaa1abb640ba89f8be9e8d51bb53c84288c Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 10:04:11 +0100 Subject: [PATCH 01/10] Add measurement header injection --- Cargo.lock | 56 ++++++++++--- Cargo.toml | 8 +- src/http_version.rs | 15 ++-- src/lib.rs | 190 +++++++++++++++++++++++++++++++++++++------- 4 files changed, 220 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a169090..45f4bd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -590,6 +590,38 @@ dependencies = [ "x509-parser 0.18.1", ] +[[package]] +name = "attestation" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +dependencies = [ + "anyhow", + "az-tdx-vtpm", + "base64 0.22.1", + "configfs-tsm", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", + "hex", + "http", + "num-bigint", + "once_cell", + "openssl", + "parity-scale-codec", + "pem-rfc7468", + "rand_core 0.6.4", + "reqwest", + "rustls-webpki", + "serde", + "serde_json", + "tdx-quote", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-rustls", + "tracing", + "tss-esapi", + "x509-parser 0.18.1", +] + [[package]] name = "attestation-provider-server" version = "0.1.0" @@ -612,7 +644,7 @@ version = "0.0.1" dependencies = [ "alloy-rpc-client", "alloy-transport-http", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "bytes", "futures-util", "http", @@ -638,10 +670,10 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" dependencies = [ "anyhow", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", "ra-tls", "rcgen 0.14.7", "rustls", @@ -659,8 +691,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attestation", - "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", "axum", "bytes", "clap", @@ -1072,7 +1104,7 @@ dependencies = [ [[package]] name = "cc-eventlog" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "digest 0.10.7", @@ -1661,7 +1693,7 @@ dependencies = [ [[package]] name = "dstack-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "cc-eventlog", @@ -1687,7 +1719,7 @@ dependencies = [ [[package]] name = "dstack-types" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "parity-scale-codec", "serde", @@ -2976,7 +3008,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" dependencies = [ "rustls", "tokio", @@ -3673,7 +3705,7 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ra-tls" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "bon", @@ -4480,7 +4512,7 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "size-parser" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "serde", @@ -4671,7 +4703,7 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tdx-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "cc-eventlog", diff --git a/Cargo.toml b/Cargo.toml index c285277..17b5749 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,9 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } tokio = { version = "1.50.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false } x509-parser = { version = "0.18.0", features = ["verify"] } @@ -47,7 +47,7 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier", features = ["mock"] } tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } diff --git a/src/http_version.rs b/src/http_version.rs index 901df66..d2f0af2 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -1,6 +1,7 @@ //! HTTP Version support and negotiation use hyper::Response; use hyper_util::rt::TokioIo; +use bytes::Bytes; use std::pin::Pin; use std::task::{Context, Poll}; @@ -55,15 +56,18 @@ impl HttpVersion { } } -type Http1Sender = hyper::client::conn::http1::SendRequest; -type Http2Sender = hyper::client::conn::http2::SendRequest; +type Http1Sender = hyper::client::conn::http1::SendRequest>; +type Http2Sender = hyper::client::conn::http2::SendRequest>; type Http1Connection = - hyper::client::conn::http1::Connection, hyper::body::Incoming>; + hyper::client::conn::http1::Connection< + TokioIo, + http_body_util::Full, + >; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, - hyper::body::Incoming, + http_body_util::Full, crate::TokioExecutor, >; @@ -88,8 +92,9 @@ impl From for HttpSender { impl HttpSender { pub async fn send_request( &mut self, - request: http::Request, + request: http::Request, ) -> Result, hyper::Error> { + let request = request.map(http_body_util::Full::new); match self { Self::Http1(sender) => sender.send_request(request).await, Self::Http2(sender) => sender.send_request(request).await, diff --git a/src/lib.rs b/src/lib.rs index 88aa200..8e44005 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ mod http_version; #[cfg(test)] mod test_helpers; -use attestation::{AttestationError, AttestationVerifier}; +use attestation::{AttestationError, AttestationExchangeMessage, AttestationVerifier}; use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier, AttestedTlsError}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -36,6 +36,12 @@ use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; +/// The header name for giving attestation type +const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; + +/// The header name for giving measurements +const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; + /// The header name for giving the forwarded for IP static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); @@ -48,7 +54,7 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; type RequestWithResponseSender = ( - http::Request, + http::Request, oneshot::Sender>, hyper::Error>>, ); @@ -399,8 +405,22 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted connection"); + // Get attestation from the remote certificate from the inner session, if present. + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => Some( + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )?, + ), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn handle_inner_connection( @@ -410,8 +430,22 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted inner-only connection"); + // Get attestation from the remote certificate, if present + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => Some( + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )?, + ), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn serve_tls_stream( @@ -419,10 +453,19 @@ impl ProxyServer { http_version: HttpVersion, target: String, client_addr: SocketAddr, + attestation: Option, ) -> Result<(), ProxyError> where IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { + let (remote_attestation_type, measurements) = match attestation { + Some(attestation) => ( + Some(attestation.attestation_type), + attestation.get_measurements()?, + ), + None => (None, None), + }; + // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -447,6 +490,30 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + // If we have measurements, from the remote peer, add them to the request header + let measurements = measurements.clone(); + + if let Some(measurements) = measurements { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + if let Some(remote_attestation_type) = remote_attestation_type { + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + } + let target = target.clone(); async move { match Self::handle_http_request(req, target).await { @@ -635,7 +702,7 @@ impl ProxyClient { // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( - http::Request, + http::Request, oneshot::Sender< Result>, hyper::Error>, >, @@ -648,7 +715,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn) = + let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -678,6 +745,9 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); + let mut remote_attestation_type = attestation.attestation_type; + let mut measurements = attestation.get_measurements().ok().flatten(); + tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -689,17 +759,69 @@ impl ProxyClient { if let Some((req, response_tx)) = incoming_req_option { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server - let (response, should_reconnect) = match sender.send_request(req).await { - Ok(resp) => { + let response = loop { + match sender.send_request(req.clone()).await { + Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - (Ok(resp.map(|b| b.boxed())), false) - } - Err(e) => { + // If we have measurements from the proxy-server, inject them into the + // response header + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + break Ok(resp.map(|b| b.boxed())); + } + Err(e) => { warn!("Failed to send request to proxy-server: {e}"); - let mut resp = Response::new(full(format!("Request failed: {e}"))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - - (Ok(resp), true) + match Self::setup_connection_with_backoff( + &target, + &nesting_tls_connector, + false, + ) + .await + { + Ok((new_sender, new_conn, new_attestation)) => { + sender = new_sender; + remote_attestation_type = new_attestation.attestation_type; + measurements = new_attestation.get_measurements().ok().flatten(); + + let (new_conn_done_tx, new_conn_done_rx) = + tokio::sync::watch::channel::>(None); + conn_done_rx = new_conn_done_rx; + + tokio::spawn(async move { + let res = new_conn.await; + let _ = new_conn_done_tx.send(res.err()); + }); + + warn!("Reconnected to proxy-server, retrying request"); + continue; + } + Err(reconnect_err) => { + warn!("Reconnect after request failure failed: {reconnect_err}"); + let mut resp = Response::new(full(format!( + "Request failed: {e}" + ))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + break Ok(resp); + } + } + } } }; @@ -707,12 +829,6 @@ impl ProxyClient { if response_tx.send(response).is_err() { warn!("Failed to forward response to source client, probably they dropped the connection"); } - - if should_reconnect { - // Leave the inner loop and continue on the reconnect loop - warn!("Reconnecting to proxy-server due to failed request"); - break; - } } else { // The request sender was dropped - so no more incoming requests debug!("Request sender dropped - leaving connection handler loop"); @@ -799,7 +915,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -828,15 +944,29 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; let tls_stream = nesting_tls_connector .connect(domain, outbound_stream) .await?; + debug!("[proxy-client] Connected to proxy server"); + // Get attestation from session + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; + + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )? + }; + // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -848,20 +978,20 @@ impl ProxyClient { .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) .keep_alive_while_idle(true) - .handshake::<_, hyper::body::Incoming>(outbound_io) + .handshake::<_, http_body_util::Full>(outbound_io) .await?; (sender.into(), conn.into()) } HttpVersion::Http1 => { let (sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake::<_, hyper::body::Incoming>(outbound_io) + .handshake::<_, http_body_util::Full>(outbound_io) .await?; (sender.into(), conn.into()) } }; - // Return the HTTP client, as well as remote measurements - Ok((sender, conn)) + // Return the HTTP client, as well as remote attestation + Ok((sender, conn, attestation)) } // Handle a request from the source client to the proxy server @@ -869,6 +999,10 @@ impl ProxyClient { req: hyper::Request, requests_tx: mpsc::Sender, ) -> Result>, ProxyError> { + let (parts, body) = req.into_parts(); + let body = body.collect().await?.to_bytes(); + let req = http::Request::from_parts(parts, body); + let (response_tx, response_rx) = oneshot::channel(); requests_tx.send((req, response_tx)).await?; Ok(response_rx.await??) @@ -1230,7 +1364,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn) = ProxyClient::setup_connection( + let (sender, conn, _attestation) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) From d45a353b90ffa2e713caa5e0258a97825179d9eb Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 10:37:16 +0100 Subject: [PATCH 02/10] Fix re-connection bug --- src/http_version.rs | 11 ++-- src/lib.rs | 131 +++++++++++++++++++++++++------------------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/src/http_version.rs b/src/http_version.rs index d2f0af2..91948c3 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -1,7 +1,7 @@ //! HTTP Version support and negotiation +use bytes::Bytes; use hyper::Response; use hyper_util::rt::TokioIo; -use bytes::Bytes; use std::pin::Pin; use std::task::{Context, Poll}; @@ -59,11 +59,10 @@ impl HttpVersion { type Http1Sender = hyper::client::conn::http1::SendRequest>; type Http2Sender = hyper::client::conn::http2::SendRequest>; -type Http1Connection = - hyper::client::conn::http1::Connection< - TokioIo, - http_body_util::Full, - >; +type Http1Connection = hyper::client::conn::http1::Connection< + TokioIo, + http_body_util::Full, +>; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, diff --git a/src/lib.rs b/src/lib.rs index 8e44005..acf25f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -410,11 +410,20 @@ impl ProxyServer { let (_io, server_connection) = tls_stream.get_ref(); match server_connection.peer_certificates() { - Some(remote_cert_chain) => Some( - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )?, - ), + Some(remote_cert_chain) => remote_cert_chain + .first() + .and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) + { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!( + "Failed to extract remote attestation from inner-session certificate: {err}" + ); + None + } + } + }), None => None, } }; @@ -435,11 +444,15 @@ impl ProxyServer { let (_io, server_connection) = tls_stream.get_ref(); match server_connection.peer_certificates() { - Some(remote_cert_chain) => Some( - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )?, - ), + Some(remote_cert_chain) => remote_cert_chain.first().and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!("Failed to extract remote attestation from certificate: {err}"); + None + } + } + }), None => None, } }; @@ -461,7 +474,13 @@ impl ProxyServer { let (remote_attestation_type, measurements) = match attestation { Some(attestation) => ( Some(attestation.attestation_type), - attestation.get_measurements()?, + match attestation.get_measurements() { + Ok(measurements) => measurements, + Err(err) => { + warn!("Failed to extract measurements from peer attestation: {err}"); + None + } + }, ), None => (None, None), }; @@ -715,7 +734,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn, attestation) = + let (mut sender, conn) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -745,9 +764,6 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); - let mut remote_attestation_type = attestation.attestation_type; - let mut measurements = attestation.get_measurements().ok().flatten(); - tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -760,45 +776,60 @@ impl ProxyClient { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server let response = loop { - match sender.send_request(req.clone()).await { - Ok(mut resp) => { - debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - // If we have measurements from the proxy-server, inject them into the - // response header - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); + let send_result = tokio::select! { + result = sender.send_request(req.clone()) => result, + _ = conn_done_rx.changed() => { + warn!("Connection dropped while request was in flight"); + match Self::setup_connection_with_backoff( + &target, + &nesting_tls_connector, + true, + ) + .await + { + Ok((new_sender, new_conn)) => { + sender = new_sender; + + let (new_conn_done_tx, new_conn_done_rx) = + tokio::sync::watch::channel::>(None); + conn_done_rx = new_conn_done_rx; + + tokio::spawn(async move { + let res = new_conn.await; + let _ = new_conn_done_tx.send(res.err()); + }); + + warn!("Reconnected to proxy-server, retrying request"); + continue; } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); + Err(reconnect_err) => { + warn!("Reconnect after in-flight drop failed: {reconnect_err}"); + let mut resp = Response::new(full( + "Request failed: connection to proxy-server dropped", + )); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + break Ok(resp); } } } + }; - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); + match send_result { + Ok(resp) => { + debug!("[proxy-client] Read response from proxy-server: {resp:?}"); break Ok(resp.map(|b| b.boxed())); } Err(e) => { - warn!("Failed to send request to proxy-server: {e}"); + warn!("Failed to send request to proxy-server: {e}"); match Self::setup_connection_with_backoff( &target, &nesting_tls_connector, - false, + true, ) .await { - Ok((new_sender, new_conn, new_attestation)) => { + Ok((new_sender, new_conn)) => { sender = new_sender; - remote_attestation_type = new_attestation.attestation_type; - measurements = new_attestation.get_measurements().ok().flatten(); let (new_conn_done_tx, new_conn_done_rx) = tokio::sync::watch::channel::>(None); @@ -915,7 +946,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { + ) -> Result<(HttpSender, HttpConnection), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -944,7 +975,7 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { + ) -> Result<(HttpSender, HttpConnection), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; @@ -954,19 +985,6 @@ impl ProxyClient { debug!("[proxy-client] Connected to proxy server"); - // Get attestation from session - let attestation = { - let (_io, server_connection) = tls_stream.get_ref(); - - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)?; - - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )? - }; - // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -990,8 +1008,7 @@ impl ProxyClient { } }; - // Return the HTTP client, as well as remote attestation - Ok((sender, conn, attestation)) + Ok((sender, conn)) } // Handle a request from the source client to the proxy server @@ -1364,7 +1381,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn, _attestation) = ProxyClient::setup_connection( + let (sender, conn) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) From 7ed43a9a9e060f704ca9ab75bc3f4ef435110f26 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 11:30:00 +0100 Subject: [PATCH 03/10] Fully restore measurement header injection --- Cargo.lock | 6 +-- src/lib.rs | 123 +++++++++++++++++++++++++++++++++++++++----- src/test_helpers.rs | 15 +++--- 3 files changed, 121 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 45f4bd6..f04e874 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,7 +593,7 @@ dependencies = [ [[package]] name = "attestation" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "anyhow", "az-tdx-vtpm", @@ -670,7 +670,7 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "anyhow", "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", @@ -3008,7 +3008,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#ed0a0c5125562b45631a2522d1212b4ece143393" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "rustls", "tokio", diff --git a/src/lib.rs b/src/lib.rs index acf25f4..8d60d37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -734,7 +734,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn) = + let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -764,6 +764,9 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); + let mut remote_attestation_type = attestation.attestation_type; + let mut measurements = attestation.get_measurements().ok().flatten(); + tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -787,8 +790,10 @@ impl ProxyClient { ) .await { - Ok((new_sender, new_conn)) => { + Ok((new_sender, new_conn, new_attestation)) => { sender = new_sender; + remote_attestation_type = new_attestation.attestation_type; + measurements = new_attestation.get_measurements().ok().flatten(); let (new_conn_done_tx, new_conn_done_rx) = tokio::sync::watch::channel::>(None); @@ -815,8 +820,26 @@ impl ProxyClient { }; match send_result { - Ok(resp) => { + Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + break Ok(resp.map(|b| b.boxed())); } Err(e) => { @@ -828,8 +851,10 @@ impl ProxyClient { ) .await { - Ok((new_sender, new_conn)) => { + Ok((new_sender, new_conn, new_attestation)) => { sender = new_sender; + remote_attestation_type = new_attestation.attestation_type; + measurements = new_attestation.get_measurements().ok().flatten(); let (new_conn_done_tx, new_conn_done_rx) = tokio::sync::watch::channel::>(None); @@ -946,7 +971,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -975,7 +1000,7 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; @@ -985,6 +1010,18 @@ impl ProxyClient { debug!("[proxy-client] Connected to proxy server"); + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; + + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )? + }; + // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -1008,7 +1045,7 @@ impl ProxyClient { } }; - Ok((sender, conn)) + Ok((sender, conn, attestation)) } // Handle a request from the source client to the proxy server @@ -1207,6 +1244,7 @@ where #[cfg(test)] mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; + use std::collections::HashMap; use tokio_rustls::TlsConnector; use super::*; @@ -1215,6 +1253,43 @@ mod tests { generate_tls_config_with_client_auth, init_tracing, }; + fn expected_mock_measurements() -> HashMap { + let zero_measurement = "0".repeat(96); + HashMap::from([ + ("0".to_string(), zero_measurement.clone()), + ("1".to_string(), zero_measurement.clone()), + ("2".to_string(), zero_measurement.clone()), + ("3".to_string(), zero_measurement.clone()), + ("4".to_string(), zero_measurement), + ]) + } + + fn assert_mock_measurements(body: &str) { + let parsed: HashMap = serde_json::from_str(body).unwrap(); + assert_eq!(parsed, expected_mock_measurements()); + } + + fn assert_mock_measurements_header(headers: &http::HeaderMap) { + let body = headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap(); + assert_mock_measurements(body); + } + + fn assert_attestation_type_header(headers: &http::HeaderMap, expected: &str) { + assert_eq!( + headers + .get(ATTESTATION_TYPE_HEADER) + .and_then(|v| v.to_str().ok()), + Some(expected) + ); + } + + fn assert_no_measurements_header(headers: &http::HeaderMap) { + assert!(headers.get(MEASUREMENT_HEADER).is_none()); + } + #[test] fn proxy_alpn_protocols_prefer_http2() { let mut protocols = Vec::new(); @@ -1381,7 +1456,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn) = ProxyClient::setup_connection( + let (sender, conn, _attestation) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) @@ -1445,6 +1520,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -1513,8 +1591,11 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert_mock_measurements(&res_body); } // Server has no attestation, client has mock DCAP but no client auth @@ -1574,7 +1655,11 @@ mod tests { .await .unwrap(); - let _res_body = res.text().await.unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); } // Server has mock DCAP, client has mock DCAP and client auth @@ -1641,12 +1726,16 @@ mod tests { let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); } // Server has mock DCAP, client no attestation - just get the server certificate @@ -1874,9 +1963,11 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) + let initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_attestation_type_header(initial_response.headers(), "dcap-tdx"); + assert_mock_measurements_header(initial_response.headers()); // Now break the connection connection_breaker_tx.send(()).unwrap(); @@ -1886,6 +1977,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -1945,6 +2039,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 431c5f8..b8509c0 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -12,6 +12,8 @@ use tokio_rustls::rustls::{ }; use tracing_subscriber::{EnvFilter, fmt}; +use crate::MEASUREMENT_HEADER; + static INIT: Once = Once::new(); /// Helper to generate a self-signed certificate for testing with a DNS subject name @@ -127,13 +129,12 @@ pub async fn example_http_service() -> SocketAddr { addr } -async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { - // headers - // .get(MEASUREMENT_HEADER) - // .and_then(|v| v.to_str().ok()) - // .unwrap_or("No measurements") - // .to_string() - "No measurements".to_string() +async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { + headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap_or("No measurements") + .to_string() } pub fn init_tracing() { From a0478373522ec99bd95a39d6d901dd1c167988f8 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 12:44:34 +0100 Subject: [PATCH 04/10] Use the same reconnect behavior as before --- src/http_version.rs | 16 +-- src/lib.rs | 268 +++++++++++++++++++++++++++----------------- 2 files changed, 170 insertions(+), 114 deletions(-) diff --git a/src/http_version.rs b/src/http_version.rs index 91948c3..901df66 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -1,5 +1,4 @@ //! HTTP Version support and negotiation -use bytes::Bytes; use hyper::Response; use hyper_util::rt::TokioIo; use std::pin::Pin; @@ -56,17 +55,15 @@ impl HttpVersion { } } -type Http1Sender = hyper::client::conn::http1::SendRequest>; -type Http2Sender = hyper::client::conn::http2::SendRequest>; +type Http1Sender = hyper::client::conn::http1::SendRequest; +type Http2Sender = hyper::client::conn::http2::SendRequest; -type Http1Connection = hyper::client::conn::http1::Connection< - TokioIo, - http_body_util::Full, ->; +type Http1Connection = + hyper::client::conn::http1::Connection, hyper::body::Incoming>; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, - http_body_util::Full, + hyper::body::Incoming, crate::TokioExecutor, >; @@ -91,9 +88,8 @@ impl From for HttpSender { impl HttpSender { pub async fn send_request( &mut self, - request: http::Request, + request: http::Request, ) -> Result, hyper::Error> { - let request = request.map(http_body_util::Full::new); match self { Self::Http1(sender) => sender.send_request(request).await, Self::Http2(sender) => sender.send_request(request).await, diff --git a/src/lib.rs b/src/lib.rs index 8d60d37..e7e012a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,7 @@ const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; const KEEP_ALIVE_INTERVAL: u64 = 30; const KEEP_ALIVE_TIMEOUT: u64 = 10; type RequestWithResponseSender = ( - http::Request, + http::Request, oneshot::Sender>, hyper::Error>>, ); @@ -721,7 +721,7 @@ impl ProxyClient { // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( - http::Request, + http::Request, oneshot::Sender< Result>, hyper::Error>, >, @@ -764,8 +764,8 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); - let mut remote_attestation_type = attestation.attestation_type; - let mut measurements = attestation.get_measurements().ok().flatten(); + let remote_attestation_type = attestation.attestation_type; + let measurements = attestation.get_measurements().ok().flatten(); tokio::spawn(async move { let res = conn.await; @@ -778,106 +778,35 @@ impl ProxyClient { if let Some((req, response_tx)) = incoming_req_option { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server - let response = loop { - let send_result = tokio::select! { - result = sender.send_request(req.clone()) => result, - _ = conn_done_rx.changed() => { - warn!("Connection dropped while request was in flight"); - match Self::setup_connection_with_backoff( - &target, - &nesting_tls_connector, - true, - ) - .await - { - Ok((new_sender, new_conn, new_attestation)) => { - sender = new_sender; - remote_attestation_type = new_attestation.attestation_type; - measurements = new_attestation.get_measurements().ok().flatten(); - - let (new_conn_done_tx, new_conn_done_rx) = - tokio::sync::watch::channel::>(None); - conn_done_rx = new_conn_done_rx; - - tokio::spawn(async move { - let res = new_conn.await; - let _ = new_conn_done_tx.send(res.err()); - }); - - warn!("Reconnected to proxy-server, retrying request"); - continue; + let (response, should_reconnect) = match sender.send_request(req).await { + Ok(mut resp) => { + debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); } - Err(reconnect_err) => { - warn!("Reconnect after in-flight drop failed: {reconnect_err}"); - let mut resp = Response::new(full( - "Request failed: connection to proxy-server dropped", - )); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - break Ok(resp); + Err(e) => { + error!("Failed to encode measurement values: {e}"); } } } - }; - - match send_result { - Ok(mut resp) => { - debug!("[proxy-client] Read response from proxy-server: {resp:?}"); - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - error!("Failed to encode measurement values: {e}"); - } - } - } - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); - break Ok(resp.map(|b| b.boxed())); - } - Err(e) => { - warn!("Failed to send request to proxy-server: {e}"); - match Self::setup_connection_with_backoff( - &target, - &nesting_tls_connector, - true, - ) - .await - { - Ok((new_sender, new_conn, new_attestation)) => { - sender = new_sender; - remote_attestation_type = new_attestation.attestation_type; - measurements = new_attestation.get_measurements().ok().flatten(); - - let (new_conn_done_tx, new_conn_done_rx) = - tokio::sync::watch::channel::>(None); - conn_done_rx = new_conn_done_rx; - - tokio::spawn(async move { - let res = new_conn.await; - let _ = new_conn_done_tx.send(res.err()); - }); - - warn!("Reconnected to proxy-server, retrying request"); - continue; - } - Err(reconnect_err) => { - warn!("Reconnect after request failure failed: {reconnect_err}"); - let mut resp = Response::new(full(format!( - "Request failed: {e}" - ))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - break Ok(resp); - } - } - } + (Ok(resp.map(|b| b.boxed())), false) + } + Err(e) => { + warn!("Failed to send request to proxy-server: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + + (Ok(resp), true) } }; @@ -885,6 +814,12 @@ impl ProxyClient { if response_tx.send(response).is_err() { warn!("Failed to forward response to source client, probably they dropped the connection"); } + + if should_reconnect { + // Leave the inner loop and continue on the reconnect loop + warn!("Reconnecting to proxy-server due to failed request"); + break; + } } else { // The request sender was dropped - so no more incoming requests debug!("Request sender dropped - leaving connection handler loop"); @@ -1033,13 +968,13 @@ impl ProxyClient { .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) .keep_alive_while_idle(true) - .handshake::<_, http_body_util::Full>(outbound_io) + .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; (sender.into(), conn.into()) } HttpVersion::Http1 => { let (sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake::<_, http_body_util::Full>(outbound_io) + .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; (sender.into(), conn.into()) } @@ -1053,10 +988,6 @@ impl ProxyClient { req: hyper::Request, requests_tx: mpsc::Sender, ) -> Result>, ProxyError> { - let (parts, body) = req.into_parts(); - let body = body.collect().await?.to_bytes(); - let req = http::Request::from_parts(parts, body); - let (response_tx, response_rx) = oneshot::channel(); requests_tx.send((req, response_tx)).await?; Ok(response_rx.await??) @@ -1245,6 +1176,10 @@ where mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; use std::collections::HashMap; + use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }; use tokio_rustls::TlsConnector; use super::*; @@ -1932,6 +1867,7 @@ mod tests { // This is used to trigger a dropped connection to the proxy server let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); tokio::spawn(async move { let connection_handle = proxy_server.accept().await.unwrap(); @@ -1943,6 +1879,7 @@ mod tests { // Now accept another connection proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); }); let proxy_client = ProxyClient::new_with_tls_config( @@ -1971,6 +1908,7 @@ mod tests { // Now break the connection connection_breaker_tx.send(()).unwrap(); + reconnected_rx.await.unwrap(); // Make another request let res = reqwest::get(format!("http://{}", proxy_client_addr)) @@ -1984,6 +1922,128 @@ mod tests { assert_eq!(res_body, "No measurements"); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_does_not_retry_failed_request() { + init_tracing(); + + let request_count = Arc::new(AtomicUsize::new(0)); + let request_seen = Arc::new(tokio::sync::Notify::new()); + let (release_tx, release_rx) = tokio::sync::watch::channel(false); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target_addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get({ + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let release_rx = release_rx.clone(); + + move || { + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let mut release_rx = release_rx.clone(); + + async move { + request_count.fetch_add(1, Ordering::SeqCst); + request_seen.notify_waiters(); + + if !*release_rx.borrow() { + release_rx.changed().await.unwrap(); + } + + "ok" + } + } + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); + + tokio::spawn(async move { + let connection_handle = proxy_server.accept().await.unwrap(); + connection_breaker_rx.await.unwrap(); + connection_handle.abort(); + proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let request_url = format!("http://{}", proxy_client_addr); + let failed_request = tokio::spawn(async move { reqwest::get(request_url).await.unwrap() }); + + loop { + if request_count.load(Ordering::SeqCst) > 0 { + break; + } + + request_seen.notified().await; + } + + connection_breaker_tx.send(()).unwrap(); + release_tx.send(true).unwrap(); + + let failed_response = failed_request.await.unwrap(); + assert_eq!(failed_response.status(), hyper::StatusCode::BAD_GATEWAY); + assert_eq!(request_count.load(Ordering::SeqCst), 1); + + reconnected_rx.await.unwrap(); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "ok"); + assert_eq!(request_count.load(Ordering::SeqCst), 2); + } + // Use HTTP 1.1 #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_http1() { From c1ed5f038665cdb20c306d2ad9401598a97f1bae Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 23 Mar 2026 15:13:41 +0100 Subject: [PATCH 05/10] Strip measurements headers to avoid them being spoofed --- src/lib.rs | 176 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index e7e012a..25f0e19 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -509,6 +509,10 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + // Strip any caller-provided attestation metadata before injecting authenticated values. + headers.remove(ATTESTATION_TYPE_HEADER); + headers.remove(MEASUREMENT_HEADER); + // If we have measurements, from the remote peer, add them to the request header let measurements = measurements.clone(); @@ -782,6 +786,8 @@ impl ProxyClient { Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); let headers = resp.headers_mut(); + headers.remove(MEASUREMENT_HEADER); + if let Some(measurements) = measurements.clone() { match measurements.to_header_format() { Ok(header_value) => { @@ -1225,6 +1231,56 @@ mod tests { assert!(headers.get(MEASUREMENT_HEADER).is_none()); } + /// Test service that echoes attestation-related request headers as JSON. + async fn request_header_echo_service() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get(|headers: http::HeaderMap| async move { + axum::Json(serde_json::json!({ + "measurement": headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()), + "attestation_type": headers + .get(ATTESTATION_TYPE_HEADER) + .and_then(|v| v.to_str().ok()), + })) + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + addr + } + + /// Test service that deliberately returns a spoofed measurement header. + async fn spoofed_response_measurement_service() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get(|| async move { + let mut response = http::Response::new("ok".to_string()); + response.headers_mut().insert( + MEASUREMENT_HEADER, + HeaderValue::from_static("{\"spoofed\":\"value\"}"), + ); + response + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + addr + } + #[test] fn proxy_alpn_protocols_prefer_http2() { let mut protocols = Vec::new(); @@ -1597,6 +1653,68 @@ mod tests { assert_eq!(res_body, "No measurements"); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_strips_spoofed_request_attestation_headers() { + let target_addr = request_header_echo_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = + generate_tls_config(server_cert_chain.clone(), server_private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::Client::new() + .get(format!("http://{}", proxy_client_addr)) + .header(MEASUREMENT_HEADER, "{\"spoofed\":\"request\"}") + .header(ATTESTATION_TYPE_HEADER, "dcap-tdx") + .send() + .await + .unwrap(); + + let echoed: serde_json::Value = res.json().await.unwrap(); + assert!(echoed["measurement"].is_null()); + assert!(echoed["attestation_type"].is_null()); + } + // Server has mock DCAP, client has mock DCAP and client auth #[tokio::test(flavor = "multi_thread")] async fn http_proxy_mutual_attestation() { @@ -2044,6 +2162,64 @@ mod tests { assert_eq!(request_count.load(Ordering::SeqCst), 2); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_strips_spoofed_response_measurement_header() { + let target_addr = spoofed_response_measurement_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = + generate_tls_config(server_cert_chain.clone(), server_private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "ok"); + } + // Use HTTP 1.1 #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_http1() { From 1e409971c31e5ddb7fee2eba29bc73dc540c0bd9 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 1 Apr 2026 10:31:21 +0200 Subject: [PATCH 06/10] Fmt --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 9544d95..b45c682 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1710,7 +1710,8 @@ mod tests { .await .unwrap(); - let echoed: serde_json::Value = serde_json::from_slice(&res.bytes().await.unwrap()).unwrap(); + let echoed: serde_json::Value = + serde_json::from_slice(&res.bytes().await.unwrap()).unwrap(); assert!(echoed["measurement"].is_null()); assert!(echoed["attestation_type"].is_null()); } From 312ba833c5ec0d1f227376be0bca24e5f91b5ff8 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 1 Apr 2026 10:58:32 +0200 Subject: [PATCH 07/10] Update field name following merged dep --- Cargo.lock | 56 ++++++++++------------------------------------------- src/lib.rs | 2 +- src/main.rs | 2 +- 3 files changed, 12 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e9d5ed..efca042 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,13 +593,13 @@ dependencies = [ [[package]] name = "attestation" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#f977f11b3a9275fb29430a2d84a8672ded4bc73b" dependencies = [ "anyhow", "az-tdx-vtpm", "base64 0.22.1", "configfs-tsm", - "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1)", "hex", "http", "num-bigint", @@ -670,7 +670,7 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#f977f11b3a9275fb29430a2d84a8672ded4bc73b" dependencies = [ "anyhow", "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", @@ -1104,7 +1104,7 @@ dependencies = [ [[package]] name = "cc-eventlog" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#01c26b7878890f765b32b2bde5b4a0086e55fb24" +source = "git+https://github.com/Dstack-TEE/dstack.git?rev=4f602dddc0542cd34da031c90ac0b3a560f316ed#4f602dddc0542cd34da031c90ac0b3a560f316ed" dependencies = [ "anyhow", "digest 0.10.7", @@ -1468,42 +1468,6 @@ dependencies = [ "x509-cert", ] -[[package]] -name = "dcap-qvl" -version = "0.3.12" -source = "git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override#e38818f0b7b600ceadad1ec3efd9e681bcbdc1e5" -dependencies = [ - "anyhow", - "asn1_der", - "base64 0.22.1", - "borsh", - "byteorder", - "chrono", - "const-oid", - "dcap-qvl-webpki", - "der", - "derive_more 2.1.1", - "futures", - "hex", - "log", - "p256", - "parity-scale-codec", - "pem", - "reqwest", - "ring", - "rustls-pki-types", - "scale-info", - "serde", - "serde-human-bytes", - "serde_json", - "sha2", - "signature", - "tracing", - "urlencoding", - "wasm-bindgen-futures", - "x509-cert", -] - [[package]] name = "dcap-qvl" version = "0.3.12" @@ -1729,7 +1693,7 @@ dependencies = [ [[package]] name = "dstack-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#01c26b7878890f765b32b2bde5b4a0086e55fb24" +source = "git+https://github.com/Dstack-TEE/dstack.git?rev=4f602dddc0542cd34da031c90ac0b3a560f316ed#4f602dddc0542cd34da031c90ac0b3a560f316ed" dependencies = [ "anyhow", "cc-eventlog", @@ -1755,7 +1719,7 @@ dependencies = [ [[package]] name = "dstack-types" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#01c26b7878890f765b32b2bde5b4a0086e55fb24" +source = "git+https://github.com/Dstack-TEE/dstack.git?rev=4f602dddc0542cd34da031c90ac0b3a560f316ed#4f602dddc0542cd34da031c90ac0b3a560f316ed" dependencies = [ "parity-scale-codec", "serde", @@ -3044,7 +3008,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#f977f11b3a9275fb29430a2d84a8672ded4bc73b" dependencies = [ "rustls", "tokio", @@ -3741,7 +3705,7 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ra-tls" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#01c26b7878890f765b32b2bde5b4a0086e55fb24" +source = "git+https://github.com/Dstack-TEE/dstack.git?rev=4f602dddc0542cd34da031c90ac0b3a560f316ed#4f602dddc0542cd34da031c90ac0b3a560f316ed" dependencies = [ "anyhow", "bon", @@ -4548,7 +4512,7 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "size-parser" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#01c26b7878890f765b32b2bde5b4a0086e55fb24" +source = "git+https://github.com/Dstack-TEE/dstack.git?rev=4f602dddc0542cd34da031c90ac0b3a560f316ed#4f602dddc0542cd34da031c90ac0b3a560f316ed" dependencies = [ "anyhow", "serde", @@ -4739,7 +4703,7 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tdx-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#01c26b7878890f765b32b2bde5b4a0086e55fb24" +source = "git+https://github.com/Dstack-TEE/dstack.git?rev=4f602dddc0542cd34da031c90ac0b3a560f316ed#4f602dddc0542cd34da031c90ac0b3a560f316ed" dependencies = [ "anyhow", "cc-eventlog", diff --git a/src/lib.rs b/src/lib.rs index b45c682..3cecfd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1938,7 +1938,7 @@ mod tests { let attestation_verifier = AttestationVerifier { measurement_policy, pccs_url: None, - log_dcap_quote: false, + dump_dcap_quotes: false, override_azure_outdated_tcb: false, }; diff --git a/src/main.rs b/src/main.rs index a80a54b..dcb23d5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -214,7 +214,7 @@ async fn main() -> anyhow::Result<()> { let attestation_verifier = AttestationVerifier { measurement_policy, pccs_url: cli.pccs_url, - log_dcap_quote: cli.log_dcap_quote, + dump_dcap_quotes: cli.log_dcap_quote, override_azure_outdated_tcb: cli.override_azure_outdated_tcb, }; From 7c262f7c87cd946f3034e42a9a3eb58b0989d6e3 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 22 Apr 2026 09:13:43 +0200 Subject: [PATCH 08/10] Update branch of attested-tls following merging paired PR --- Cargo.lock | 72 +++++++++++++++++++++++++++++++++++++++++++++++------ Cargo.toml | 9 ++++--- src/lib.rs | 11 +++++--- src/main.rs | 4 ++- 4 files changed, 80 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index efca042..041e588 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -561,19 +561,20 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "attestation" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#a96ec2d9096f491e652624c53d3df2b1526ef9f2" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" dependencies = [ "anyhow", "az-tdx-vtpm", "base64 0.22.1", "configfs-tsm", - "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1)", + "dcap-qvl 0.3.12 (git+https://github.com/Phala-Network/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1)", "hex", "http", "num-bigint", "once_cell", "openssl", "parity-scale-codec", + "pccs", "pem-rfc7468", "rand_core 0.6.4", "reqwest", @@ -593,7 +594,7 @@ dependencies = [ [[package]] name = "attestation" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#f977f11b3a9275fb29430a2d84a8672ded4bc73b" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#a96ec2d9096f491e652624c53d3df2b1526ef9f2" dependencies = [ "anyhow", "az-tdx-vtpm", @@ -670,10 +671,10 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#f977f11b3a9275fb29430a2d84a8672ded4bc73b" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" dependencies = [ "anyhow", - "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)", "ra-tls", "rcgen 0.14.7", "rustls", @@ -691,8 +692,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", - "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)", "axum", "bytes", "clap", @@ -703,6 +704,7 @@ dependencies = [ "jsonrpsee", "nested-tls", "p256", + "pccs", "pem-rfc7468", "pin-project-lite", "pkcs1", @@ -1504,6 +1506,42 @@ dependencies = [ "x509-cert", ] +[[package]] +name = "dcap-qvl" +version = "0.3.12" +source = "git+https://github.com/Phala-Network/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1#f1dcc65371e941a7b83e3234833d23a1fb232ab1" +dependencies = [ + "anyhow", + "asn1_der", + "base64 0.22.1", + "borsh", + "byteorder", + "chrono", + "const-oid", + "dcap-qvl-webpki", + "der", + "derive_more 2.1.1", + "futures", + "hex", + "log", + "p256", + "parity-scale-codec", + "pem", + "reqwest", + "ring", + "rustls-pki-types", + "scale-info", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "signature", + "tracing", + "urlencoding", + "wasm-bindgen-futures", + "x509-cert", +] + [[package]] name = "dcap-qvl-webpki" version = "0.103.4+dcap.1" @@ -3008,7 +3046,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#f977f11b3a9275fb29430a2d84a8672ded4bc73b" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" dependencies = [ "rustls", "tokio", @@ -3365,6 +3403,24 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pccs" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=main#9fb3002d82918b85f780f37b5545c4e112e0e772" +dependencies = [ + "anyhow", + "dcap-qvl 0.3.12 (git+https://github.com/Phala-Network/dcap-qvl.git?rev=f1dcc65371e941a7b83e3234833d23a1fb232ab1)", + "hex", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.17", + "time", + "tokio", + "tracing", + "x509-parser 0.18.1", +] + [[package]] name = "pem" version = "3.0.6" diff --git a/Cargo.toml b/Cargo.toml index 3f7e0db..5c4ca20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,10 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } -nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main" } +pccs = { git = "https://github.com/flashbots/attested-tls", branch = "main" } tokio = { version = "1.50.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false } x509-parser = { version = "0.18.0", features = ["verify"] } @@ -47,7 +48,7 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main", features = ["mock"] } tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } diff --git a/src/lib.rs b/src/lib.rs index 3cecfd2..19d6e8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1109,10 +1109,14 @@ async fn build_attested_cert_resolver( attestation_generator: AttestationGenerator, certificate_name: String, ) -> Result { - Ok( - AttestedCertificateResolver::new(attestation_generator, None, certificate_name, vec![]) - .await?, + Ok(AttestedCertificateResolver::new( + attestation_generator, + None, + certificate_name, + vec![], + Duration::from_secs(30 * 60), ) + .await?) } async fn build_inner_server_config( @@ -1940,6 +1944,7 @@ mod tests { pccs_url: None, dump_dcap_quotes: false, override_azure_outdated_tcb: false, + internal_pccs: None, }; let proxy_client_result = ProxyClient::new_with_tls_config( diff --git a/src/main.rs b/src/main.rs index dcb23d5..de9d6ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, ensure}; use attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}; use clap::{Parser, Subcommand}; +use pccs::Pccs; use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio::io::AsyncWriteExt; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; @@ -28,7 +29,7 @@ struct Cli { /// If no measurements file is specified, a single attestion type to allow #[arg(long, global = true)] allowed_remote_attestation_type: Option, - /// The URL of a PCCS to use when verifying DCAP attestations. Defaults to Intel PCS. + /// The URL of a PCCS to use when verifying DCAP attestations. Defaults to an internal PCCS. #[arg(long, global = true)] pccs_url: Option, /// Log debug messages @@ -216,6 +217,7 @@ async fn main() -> anyhow::Result<()> { pccs_url: cli.pccs_url, dump_dcap_quotes: cli.log_dcap_quote, override_azure_outdated_tcb: cli.override_azure_outdated_tcb, + internal_pccs: Some(Pccs::new(None)), }; match cli.command { From f51266546a4d80616c30e99fad2992385e240b79 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 22 Apr 2026 09:45:52 +0200 Subject: [PATCH 09/10] Use aws lc as default crypto provider --- Cargo.lock | 51 +++++++++++++++++++++++++ Cargo.toml | 2 +- attestation-provider-server/src/main.rs | 1 + src/lib.rs | 2 +- src/main.rs | 2 + 5 files changed, 56 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 041e588..df7488f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -744,6 +744,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec6fb3fe69024a75fa7e1bfb48aa6cf59706a101658ea01bfd33b2b248a038f" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f50037ee5e1e41e7b8f9d161680a725bd1626cb6f8c7e901f91f942850852fe7" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.8.6" @@ -1100,6 +1122,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37521ac7aabe3d13122dc382493e20c9416f299d2ccd5b3a5340a2570cdeb0f3" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -1182,6 +1206,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "codicon" version = "3.0.0" @@ -2079,6 +2112,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -2705,6 +2744,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.85" @@ -4189,6 +4238,7 @@ version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "aws-lc-rs", "brotli", "brotli-decompressor", "once_cell", @@ -4224,6 +4274,7 @@ version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", diff --git a/Cargo.toml b/Cargo.toml index 5c4ca20..e8cdf9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main" } pccs = { git = "https://github.com/flashbots/attested-tls", branch = "main" } tokio = { version = "1.50.0", features = ["full"] } -tokio-rustls = { version = "0.26.4", default-features = false } +tokio-rustls = { version = "0.26.4", default-features = false, features = ["aws_lc_rs"] } x509-parser = { version = "0.18.0", features = ["verify"] } thiserror = "2.0.17" clap = { version = "4.5.51", features = ["derive", "env"] } diff --git a/attestation-provider-server/src/main.rs b/attestation-provider-server/src/main.rs index 4a91487..661ca2f 100644 --- a/attestation-provider-server/src/main.rs +++ b/attestation-provider-server/src/main.rs @@ -99,6 +99,7 @@ async fn main() -> anyhow::Result<()> { pccs_url: None, dump_dcap_quotes: cli.log_dcap_quote, override_azure_outdated_tcb: false, + internal_pccs: None, }; let attestation_message = diff --git a/src/lib.rs b/src/lib.rs index 19d6e8f..dd9c17f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1363,7 +1363,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn inner_only_listener_negotiates_http2_by_default() { - let _ = rustls::crypto::ring::default_provider().install_default(); + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let target_addr = example_http_service().await; let proxy_server = ProxyServer::new( diff --git a/src/main.rs b/src/main.rs index de9d6ba..f4be61c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -160,6 +160,8 @@ enum CliCommand { #[tokio::main] async fn main() -> anyhow::Result<()> { + let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default(); + let cli = Cli::parse(); ensure!( From 735e368c1e350b87c8d6d3d340cbf2de3f10f964 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 22 Apr 2026 10:01:29 +0200 Subject: [PATCH 10/10] Use aws lc as default crypto provider --- attested-tls/Cargo.toml | 2 +- attested-tls/src/test_helpers.rs | 8 ++++++++ src/test_helpers.rs | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/attested-tls/Cargo.toml b/attested-tls/Cargo.toml index d40e7b6..60f74c0 100644 --- a/attested-tls/Cargo.toml +++ b/attested-tls/Cargo.toml @@ -9,7 +9,7 @@ keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] tokio = { version = "1.48.0", features = ["full"] } -tokio-rustls = { version = "0.26.4", default-features = false } +tokio-rustls = { version = "0.26.4", default-features = false, features = ["aws_lc_rs"] } sha2 = "0.10.9" x509-parser = "0.18.0" thiserror = "2.0.17" diff --git a/attested-tls/src/test_helpers.rs b/attested-tls/src/test_helpers.rs index 9b218a3..2ba933d 100644 --- a/attested-tls/src/test_helpers.rs +++ b/attested-tls/src/test_helpers.rs @@ -10,6 +10,10 @@ use crate::SUPPORTED_ALPN_PROTOCOL_VERSIONS; pub use attestation::measurements::mock_dcap_measurements; +fn install_crypto_provider() { + let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default(); +} + /// Helper to generate a self-signed certificate for testing pub fn generate_certificate_chain( ip: IpAddr, @@ -36,6 +40,8 @@ pub fn generate_tls_config( certificate_chain: Vec>, key: PrivateKeyDer<'static>, ) -> (ServerConfig, ClientConfig) { + install_crypto_provider(); + let supported_protocols: Vec<_> = SUPPORTED_ALPN_PROTOCOL_VERSIONS .into_iter() .map(|p| p.to_vec()) @@ -67,6 +73,8 @@ pub fn generate_tls_config_with_client_auth( bob_certificate_chain: Vec>, bob_key: PrivateKeyDer<'static>, ) -> ((ServerConfig, ClientConfig), (ServerConfig, ClientConfig)) { + install_crypto_provider(); + let supported_protocols: Vec<_> = SUPPORTED_ALPN_PROTOCOL_VERSIONS .into_iter() .map(|p| p.to_vec()) diff --git a/src/test_helpers.rs b/src/test_helpers.rs index b8509c0..9448f7f 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -16,10 +16,20 @@ use crate::MEASUREMENT_HEADER; static INIT: Once = Once::new(); +pub fn install_crypto_provider() { + static CRYPTO_PROVIDER_INIT: Once = Once::new(); + + CRYPTO_PROVIDER_INIT.call_once(|| { + let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default(); + }); +} + /// Helper to generate a self-signed certificate for testing with a DNS subject name pub fn generate_certificate_chain_for_host( host: &str, ) -> (Vec>, PrivateKeyDer<'static>) { + install_crypto_provider(); + let mut params = rcgen::CertificateParams::new(vec![host.to_string()]).unwrap(); params .subject_alt_names @@ -44,6 +54,8 @@ pub fn generate_tls_config( certificate_chain: Vec>, key: PrivateKeyDer<'static>, ) -> (ServerConfig, ClientConfig) { + install_crypto_provider(); + let server_config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(certificate_chain.clone(), key) @@ -66,6 +78,8 @@ pub fn generate_tls_config_with_client_auth( bob_certificate_chain: Vec>, bob_key: PrivateKeyDer<'static>, ) -> ((ServerConfig, ClientConfig), (ServerConfig, ClientConfig)) { + install_crypto_provider(); + let (alice_client_verifier, alice_root_store) = client_verifier_from_remote_cert(bob_certificate_chain[0].clone()); @@ -117,6 +131,8 @@ fn client_verifier_from_remote_cert( /// Simple http server used in tests which returns in the response the measurement header from the /// request pub async fn example_http_service() -> SocketAddr { + install_crypto_provider(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -138,6 +154,8 @@ async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { } pub fn init_tracing() { + install_crypto_provider(); + INIT.call_once(|| { let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));