From a65ea6905a4511654d2a2c67117433f9685021ca Mon Sep 17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 13:59:16 -0700 Subject: [PATCH 1/7] fix: update aws-lc-sys and rustls-webpki to resolve security vulnerabilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - aws-lc-rs 1.15.4 → 1.16.2 (pulls aws-lc-sys 0.37.0 → 0.39.0) Fixes RUSTSEC-2026-0044 through 0048 (X.509 bypass, PKCS7 bypasses, AES-CCM timing, CRL scope check) - rustls-webpki 0.103.9 → 0.103.10 Fixes RUSTSEC-2026-0049 (CRL Distribution Point matching bypass) --- Cargo.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 12140b6..d828946 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -825,9 +825,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.15.4" +version = "1.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" dependencies = [ "aws-lc-sys", "zeroize", @@ -835,9 +835,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.37.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c34dda4df7017c8db52132f0f8a2e0f8161649d15723ed63fc00c82d0f2081a" +checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a" dependencies = [ "cc", "cmake", @@ -3937,9 +3937,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "aws-lc-rs", "ring", From 9d067c3019bb87190f6cc7fd49fb304379432055 Mon Sep 
17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 14:38:52 -0700 Subject: [PATCH 2/7] perf: parallelize attestation and add e2e benchmarks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Attestation generation optimization: - Parallelize TDX quote + GPU evidence collection with tokio::try_join! (previously sequential — these are independent: TDX talks to dstack socket, GPU evidence spawns a Python subprocess) - Cache dstack info() with OnceCell (static data, never changes during process lifetime — was re-fetched on every attestation) New benchmark suite (benches/e2e.rs): - End-to-end JSON completion flow (1/5/20 messages) - End-to-end streaming completion flow (5/20/50 chunks) - Attestation cache operations (hit/miss/set/semaphore) - Attestation response serialization - Request body processing pipeline (SHA256, JSON parse, reserialize) - Response signing full pipeline (parse → hash → sign → cache) - Streaming SSE parse+hash pipeline - Auth token constant-time comparison - JSON body round-trip (parse, modify, reserialize) --- Cargo.lock | 2 + Cargo.toml | 6 +- benches/e2e.rs | 656 +++++++++++++++++++++++++++++++++++++++++++++ src/attestation.rs | 113 +++++--- 4 files changed, 743 insertions(+), 34 deletions(-) create mode 100644 benches/e2e.rs diff --git a/Cargo.lock b/Cargo.lock index d828946..e6b5540 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1322,6 +1322,7 @@ dependencies = [ "ciborium", "clap", "criterion-plot", + "futures", "is-terminal", "itertools 0.10.5", "num-traits", @@ -1334,6 +1335,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio", "walkdir", ] diff --git a/Cargo.toml b/Cargo.toml index 48a97f8..d44af38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,8 +57,12 @@ metrics-exporter-prometheus = "0.18" [dev-dependencies] wiremock = "0.6" tower = { version = "0.5", features = ["util"] } -criterion = { version = "0.5", features = ["html_reports"] } +criterion = { version = 
"0.5", features = ["html_reports", "async_tokio"] } [[bench]] name = "hot_path" harness = false + +[[bench]] +name = "e2e" +harness = false diff --git a/benches/e2e.rs b/benches/e2e.rs new file mode 100644 index 0000000..69dbf5f --- /dev/null +++ b/benches/e2e.rs @@ -0,0 +1,656 @@ +//! End-to-end benchmarks for attestation and completion flows. +//! +//! These benchmarks measure the proxy overhead — the time spent in our code +//! between receiving a request and forwarding/returning a response. Backend +//! latency is simulated with wiremock returning instantly. + +use std::sync::Arc; + +use axum::body::Body; +use axum::http::{Request, StatusCode}; +use axum::middleware; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use sha2::{Digest, Sha256}; +use tower::ServiceExt; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use vllm_proxy_rs::*; + +// ── Test keys (same as integration tests) ── + +const ECDSA_KEY: [u8; 32] = [ + 0xac, 0x09, 0x74, 0xbe, 0xc3, 0x9a, 0x17, 0xe3, 0x6b, 0xa4, 0xa6, 0xb4, 0xd2, 0x38, 0xff, + 0x94, 0x4b, 0xac, 0xb3, 0x5e, 0x5d, 0xc4, 0xaf, 0x0f, 0x33, 0x47, 0xe5, 0x87, 0x31, 0x79, + 0x67, 0x0f, +]; + +const ED25519_KEY: [u8; 32] = [ + 0x9d, 0x61, 0xb1, 0x9d, 0xef, 0xfd, 0x5a, 0x60, 0xba, 0x84, 0x4a, 0xf4, 0x92, 0xec, 0x2c, + 0xc4, 0x44, 0x49, 0xc5, 0x69, 0x7b, 0x32, 0x69, 0x19, 0x70, 0x3b, 0xac, 0x03, 0x1c, 0xae, + 0x7f, 0x60, +]; + +fn build_test_app(mock_url: &str) -> axum::Router { + let base = mock_url.trim_end_matches('/'); + + let config = config::Config { + model_name: "bench-model".to_string(), + token: "bench-token".to_string(), + vllm_base_url: mock_url.to_string(), + chat_completions_url: format!("{base}/v1/chat/completions"), + completions_url: format!("{base}/v1/completions"), + tokenize_url: format!("{base}/tokenize"), + metrics_url: format!("{base}/metrics"), + models_url: format!("{base}/v1/models"), + images_url: 
format!("{base}/v1/images/generations"), + images_edits_url: format!("{base}/v1/images/edits"), + transcriptions_url: format!("{base}/v1/audio/transcriptions"), + embeddings_url: format!("{base}/v1/embeddings"), + rerank_url: format!("{base}/v1/rerank"), + score_url: format!("{base}/v1/score"), + max_keepalive: 100, + max_request_size: 1024 * 1024, + max_image_request_size: 5 * 1024 * 1024, + max_audio_request_size: 10 * 1024 * 1024, + chat_cache_expiration_secs: 1200, + attestation_cache_ttl_secs: 300, + dev_mode: true, + gpu_no_hw_mode: true, + git_rev: "bench".to_string(), + rate_limit_per_second: 10000, + rate_limit_burst_size: 20000, + rate_limit_trust_proxy_headers: false, + cloud_api_url: None, + tls_cert_path: None, + timeout_secs: 30, + timeout_tokenize_secs: 5, + openai_chat_compatibility_check_enabled: false, + startup_check_retries: 0, + startup_check_retry_delay_secs: 0, + startup_check_timeout_secs: 1, + }; + + let ecdsa = signing::EcdsaContext::from_key_bytes(&ECDSA_KEY).unwrap(); + let ed25519 = signing::Ed25519Context::from_key_bytes(&ED25519_KEY).unwrap(); + let signing_pair = signing::SigningPair { ecdsa, ed25519 }; + + let chat_cache = cache::ChatCache::new("bench-model", 1200); + let http_client = reqwest::Client::new(); + + let metrics_handle = metrics_exporter_prometheus::PrometheusBuilder::new() + .build_recorder() + .handle(); + + let state = AppState { + config: Arc::new(config), + signing: Arc::new(signing_pair), + cache: Arc::new(chat_cache), + attestation_cache: Arc::new(attestation::AttestationCache::new(300)), + http_client, + metrics_handle, + tls_cert_fingerprint: None, + }; + + let rate_limiter = rate_limit::build_rate_limiter(10000, 20000); + let rate_limit_state = rate_limit::RateLimitState { + limiter: rate_limiter, + trust_proxy_headers: false, + }; + + routes::build_router() + .layer(middleware::from_fn(rate_limit::rate_limit_middleware)) + .layer(axum::Extension(rate_limit_state)) + 
.layer(middleware::from_fn(request_id_middleware)) + .with_state(state) +} + +// ── Helpers ── + +fn make_chat_request(messages: usize) -> String { + let msgs: Vec = (0..messages) + .map(|i| { + serde_json::json!({ + "role": if i % 2 == 0 { "user" } else { "assistant" }, + "content": format!("Message number {i} with some typical content that a user might send in a conversation.") + }) + }) + .collect(); + serde_json::json!({ + "model": "bench-model", + "messages": msgs, + "stream": false + }) + .to_string() +} + +fn make_chat_response(id: &str) -> String { + serde_json::json!({ + "id": id, + "object": "chat.completion", + "created": 1234567890, + "model": "bench-model", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "This is a typical assistant response with some content that would be returned from the model." + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 50, + "completion_tokens": 30, + "total_tokens": 80 + } + }) + .to_string() +} + +fn make_streaming_response(id: &str, num_chunks: usize) -> String { + let mut body = String::new(); + for i in 0..num_chunks { + let chunk = serde_json::json!({ + "id": id, + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "bench-model", + "choices": [{ + "index": 0, + "delta": { "content": format!("word{i} ") }, + "finish_reason": null + }] + }); + body.push_str(&format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap())); + } + // Final chunk with usage + let final_chunk = serde_json::json!({ + "id": id, + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "bench-model", + "choices": [], + "usage": { + "prompt_tokens": 50, + "completion_tokens": num_chunks as i64, + "total_tokens": 50 + num_chunks as i64 + } + }); + body.push_str(&format!( + "data: {}\n\n", + serde_json::to_string(&final_chunk).unwrap() + )); + body.push_str("data: [DONE]\n\n"); + body +} + +// ── Attestation benchmarks ── + +fn 
bench_attestation_cache_operations(c: &mut Criterion) { + let mut group = c.benchmark_group("attestation_cache"); + let rt = tokio::runtime::Runtime::new().unwrap(); + + let cache = attestation::AttestationCache::new(300); + + // Pre-populate cache + let report = types::AttestationReport { + model_name: "bench-model".to_string(), + signing_address: "0xabcdef1234567890".to_string(), + signing_algo: "ecdsa".to_string(), + signing_public_key: "04abcdef".to_string(), + request_nonce: hex::encode([0xAA; 32]), + intel_quote: "base64quote".repeat(100), + nvidia_payload: serde_json::json!({ + "nonce": hex::encode([0xAA; 32]), + "evidence_list": [{"gpu": "H100"}], + "arch": "HOPPER" + }) + .to_string(), + event_log: serde_json::json!({"entries": []}), + info: serde_json::json!({"version": "1.0"}), + tls_cert_fingerprint: None, + }; + rt.block_on(cache.set("ecdsa", false, report.clone())); + + group.bench_function("cache_hit", |b| { + b.to_async(&rt).iter(|| async { + black_box(cache.get("ecdsa", false).await) + }) + }); + + group.bench_function("cache_miss", |b| { + b.to_async(&rt).iter(|| async { + black_box(cache.get("ed25519", true).await) + }) + }); + + group.bench_function("cache_set", |b| { + b.to_async(&rt).iter(|| async { + cache.set("ecdsa", false, report.clone()).await + }) + }); + + group.bench_function("semaphore_acquire_uncontended", |b| { + b.to_async(&rt).iter(|| async { + let permit = cache.acquire_gpu_permit().await; + drop(black_box(permit)); + }) + }); + + group.finish(); +} + +fn bench_attestation_report_serialization(c: &mut Criterion) { + let report = types::AttestationReport { + model_name: "bench-model".to_string(), + signing_address: "0xabcdef1234567890abcdef1234567890abcdef12".to_string(), + signing_algo: "ecdsa".to_string(), + signing_public_key: hex::encode([0xAB; 65]), + request_nonce: hex::encode([0xCD; 32]), + intel_quote: "base64encodedquotedata".repeat(50), + nvidia_payload: serde_json::json!({ + "nonce": hex::encode([0xCD; 32]), + 
"evidence_list": [ + {"gpu": "H100", "evidence": "base64data".repeat(20)}, + {"gpu": "H100", "evidence": "base64data".repeat(20)}, + ], + "arch": "HOPPER" + }) + .to_string(), + event_log: serde_json::json!({"entries": [ + {"type": "init", "data": "some event"}, + {"type": "measure", "data": "another event"}, + ]}), + info: serde_json::json!({"version": "1.0", "tcb": "123"}), + tls_cert_fingerprint: None, + }; + + let response = types::AttestationResponse { + report: report.clone(), + all_attestations: vec![report], + }; + + c.bench_function("attestation_response_serialize", |b| { + b.iter(|| serde_json::to_value(black_box(&response)).unwrap()) + }); +} + +// ── Completion flow benchmarks (end-to-end through the proxy) ── + +fn bench_json_completion_e2e(c: &mut Criterion) { + let mut group = c.benchmark_group("json_completion_e2e"); + let rt = tokio::runtime::Runtime::new().unwrap(); + + for msg_count in [1, 5, 20] { + let request_body = make_chat_request(msg_count); + let response_body = make_chat_response(&format!("chatcmpl-bench-{msg_count}")); + + // Set up mock server once per parameter + let mock_server = rt.block_on(async { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(&response_body) + .insert_header("content-type", "application/json"), + ) + .mount(&server) + .await; + server + }); + + let mock_uri = mock_server.uri(); + + let app = build_test_app(&mock_uri); + + group.bench_with_input( + BenchmarkId::new("messages", msg_count), + &request_body, + |b, req_body| { + b.to_async(&rt).iter(|| { + let req_body = req_body.clone(); + let app = app.clone(); + async move { + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("authorization", "Bearer bench-token") + .body(Body::from(req_body)) + .unwrap(); + + let response = 
app.oneshot(request).await.unwrap(); + let status = response.status(); + let body = axum::body::to_bytes(response.into_body(), 1024 * 1024) + .await + .unwrap(); + assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); + black_box(body) + } + }) + }, + ); + } + + group.finish(); +} + +fn bench_streaming_completion_e2e(c: &mut Criterion) { + let mut group = c.benchmark_group("streaming_completion_e2e"); + // Give streaming benchmarks more time since they involve spawned tasks + group.sample_size(50); + let rt = tokio::runtime::Runtime::new().unwrap(); + + for chunk_count in [5, 20, 50] { + let request_body = { + let mut v: serde_json::Value = serde_json::from_str(&make_chat_request(3)).unwrap(); + v["stream"] = serde_json::Value::Bool(true); + serde_json::to_string(&v).unwrap() + }; + let response_body = make_streaming_response("chatcmpl-stream-bench", chunk_count); + + let mock_server = rt.block_on(async { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(&response_body) + .insert_header("content-type", "text/event-stream"), + ) + .mount(&server) + .await; + server + }); + + let mock_uri = mock_server.uri(); + let app = build_test_app(&mock_uri); + + group.bench_with_input( + BenchmarkId::new("chunks", chunk_count), + &request_body, + |b, req_body| { + b.to_async(&rt).iter(|| { + let req_body = req_body.clone(); + let app = app.clone(); + async move { + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("authorization", "Bearer bench-token") + .body(Body::from(req_body)) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let body = axum::body::to_bytes(response.into_body(), 10 * 1024 * 1024) + .await + .unwrap(); + black_box(body) + } + }) + }, + ); + } + + 
group.finish(); +} + +// ── Proxy overhead component benchmarks ── + +fn bench_request_body_processing(c: &mut Criterion) { + let mut group = c.benchmark_group("request_body_processing"); + + for msg_count in [1, 5, 20, 50] { + let body = make_chat_request(msg_count); + let body_bytes = body.as_bytes(); + + group.bench_with_input( + BenchmarkId::new("sha256_hash", msg_count), + body_bytes, + |b, data| { + b.iter(|| hex::encode(Sha256::digest(black_box(data)))) + }, + ); + + group.bench_with_input( + BenchmarkId::new("json_parse", msg_count), + &body, + |b, data| { + b.iter(|| { + let v: serde_json::Value = serde_json::from_str(black_box(data)).unwrap(); + black_box(v) + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("json_parse_and_reserialize", msg_count), + &body, + |b, data| { + b.iter(|| { + let v: serde_json::Value = serde_json::from_str(black_box(data)).unwrap(); + let out = serde_json::to_vec(&v).unwrap(); + black_box(out) + }) + }, + ); + } + + group.finish(); +} + +fn bench_response_signing_full(c: &mut Criterion) { + let mut group = c.benchmark_group("response_signing_full"); + + let ecdsa = signing::EcdsaContext::from_key_bytes(&ECDSA_KEY).unwrap(); + let ed25519 = signing::Ed25519Context::from_key_bytes(&ED25519_KEY).unwrap(); + let pair = Arc::new(signing::SigningPair { ecdsa, ed25519 }); + let cache = Arc::new(cache::ChatCache::new("bench-model", 1200)); + + // Simulate the full response processing pipeline: + // parse JSON -> extract ID -> hash -> sign -> serialize signature -> cache + for response_size in ["small", "medium", "large"] { + let response_body = match response_size { + "small" => serde_json::json!({ + "id": "chatcmpl-bench", + "choices": [{"message": {"content": "Hi"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6} + }), + "medium" => serde_json::json!({ + "id": "chatcmpl-bench", + "choices": [{"message": {"content": "x".repeat(1000)}, "finish_reason": "stop"}], + 
"usage": {"prompt_tokens": 50, "completion_tokens": 200, "total_tokens": 250} + }), + "large" => serde_json::json!({ + "id": "chatcmpl-bench", + "choices": [{"message": {"content": "x".repeat(10000)}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 500, "completion_tokens": 2000, "total_tokens": 2500} + }), + _ => unreachable!(), + }; + let response_bytes = serde_json::to_vec(&response_body).unwrap(); + let request_sha256 = hex::encode(Sha256::digest(b"test request")); + + group.bench_with_input( + BenchmarkId::new("pipeline", response_size), + &(response_bytes, request_sha256), + |b, (resp_bytes, req_hash)| { + let pair = pair.clone(); + let cache = cache.clone(); + b.iter(|| { + // 1. Parse response JSON + let response_data: serde_json::Value = + serde_json::from_slice(black_box(resp_bytes)).unwrap(); + + // 2. Extract ID + let chat_id = response_data["id"].as_str().unwrap(); + + // 3. Serialize to final form + let final_body = serde_json::to_string(&response_data).unwrap(); + + // 4. Hash response + let response_sha256 = hex::encode(Sha256::digest(final_body.as_bytes())); + + // 5. Sign + let text = format!("bench-model:{req_hash}:{response_sha256}"); + let signed = pair.sign_chat(&text).unwrap(); + + // 6. 
Serialize signature and cache + let signed_json = serde_json::to_string(&signed).unwrap(); + cache.set_chat(chat_id, &signed_json); + + black_box(final_body) + }) + }, + ); + } + + group.finish(); +} + +fn bench_streaming_sse_processing(c: &mut Criterion) { + let mut group = c.benchmark_group("streaming_sse_processing"); + + // Benchmark the SSE parser + hasher pipeline (what runs per-chunk in streaming) + for chunk_count in [10, 50, 200] { + let chunks: Vec> = (0..chunk_count) + .map(|i| { + let chunk = serde_json::json!({ + "id": "chatcmpl-bench", + "object": "chat.completion.chunk", + "choices": [{"delta": {"content": format!("word{i} ")}}] + }); + format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap()).into_bytes() + }) + .collect(); + + group.bench_with_input( + BenchmarkId::new("parse_and_hash", chunk_count), + &chunks, + |b, chunks| { + b.iter(|| { + let mut parser = proxy::SseParser::new(); + let mut hasher = Sha256::new(); + for chunk in chunks { + parser.process_chunk(black_box(chunk)); + hasher.update(chunk); + } + // Final sign text construction + let response_sha256 = hex::encode(hasher.finalize()); + black_box(response_sha256); + black_box(&parser.chat_id); + }) + }, + ); + } + + group.finish(); +} + +fn bench_auth_token_comparison(c: &mut Criterion) { + use subtle::ConstantTimeEq; + + let token = "rr9w3S91rog35JM6Sgr2YqwbMvKrbnLA95hQoiwip+4="; + let matching = "rr9w3S91rog35JM6Sgr2YqwbMvKrbnLA95hQoiwip+4="; + let non_matching = "xx9w3S91rog35JM6Sgr2YqwbMvKrbnLA95hQoiwip+4="; + + let mut group = c.benchmark_group("auth_token"); + + group.bench_function("constant_time_match", |b| { + b.iter(|| { + let result: bool = token + .as_bytes() + .ct_eq(black_box(matching.as_bytes())) + .into(); + black_box(result) + }) + }); + + group.bench_function("constant_time_mismatch", |b| { + b.iter(|| { + let result: bool = token + .as_bytes() + .ct_eq(black_box(non_matching.as_bytes())) + .into(); + black_box(result) + }) + }); + + group.finish(); +} + +fn 
bench_json_body_round_trip(c: &mut Criterion) { + let mut group = c.benchmark_group("json_round_trip"); + + // This measures the overhead of parsing request body, modifying it + // (strip_empty_tool_calls, force stream_options), and re-serializing. + let body_with_tools = serde_json::json!({ + "model": "bench-model", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi", "tool_calls": []}, + {"role": "user", "content": "How are you?"} + ], + "stream": true, + "stream_options": {"include_usage": false} + }); + let body_str = serde_json::to_string(&body_with_tools).unwrap(); + + group.bench_function("parse_modify_reserialize", |b| { + b.iter(|| { + let mut v: serde_json::Value = + serde_json::from_str(black_box(&body_str)).unwrap(); + + // strip_empty_tool_calls equivalent + if let Some(messages) = v.get_mut("messages").and_then(|m| m.as_array_mut()) { + for message in messages.iter_mut() { + if let Some(obj) = message.as_object_mut() { + if let Some(tool_calls) = obj.get("tool_calls") { + if tool_calls.as_array().map(|a| a.is_empty()).unwrap_or(false) { + obj.remove("tool_calls"); + } + } + } + } + } + + // Force stream_options + if let Some(stream_opts) = v.get_mut("stream_options").and_then(|v| v.as_object_mut()) { + stream_opts.insert("include_usage".into(), true.into()); + } + + let out = serde_json::to_vec(&v).unwrap(); + black_box(out) + }) + }); + + group.finish(); +} + +// ── Criterion configuration ── + +criterion_group!( + attestation, + bench_attestation_cache_operations, + bench_attestation_report_serialization, +); + +criterion_group!( + completion, + bench_json_completion_e2e, + bench_streaming_completion_e2e, +); + +criterion_group!( + proxy_overhead, + bench_request_body_processing, + bench_response_signing_full, + bench_streaming_sse_processing, + bench_auth_token_comparison, + bench_json_body_round_trip, +); + +criterion_main!(attestation, completion, proxy_overhead); diff --git a/src/attestation.rs 
b/src/attestation.rs index 473bab9..5266c45 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::Instant; use sha2::{Digest, Sha256}; -use tokio::sync::{RwLock, Semaphore}; +use tokio::sync::{OnceCell, RwLock, Semaphore}; use tracing::{error, info, warn}; use crate::types::AttestationReport; @@ -35,6 +35,8 @@ pub struct AttestationCache { gpu_semaphore: Semaphore, /// Cache TTL in seconds. ttl_secs: u64, + /// Cached dstack info (static for the lifetime of the process). + dstack_info: OnceCell, } impl AttestationCache { @@ -43,9 +45,22 @@ impl AttestationCache { reports: RwLock::new(HashMap::new()), gpu_semaphore: Semaphore::new(1), ttl_secs, + dstack_info: OnceCell::new(), } } + /// Get cached dstack info, fetching it once on first call. + async fn get_dstack_info(&self) -> anyhow::Result { + self.dstack_info + .get_or_try_init(|| async { + let client = dstack_sdk::dstack_client::DstackClient::new(None); + let info = client.info().await?; + serde_json::to_value(&info).map_err(anyhow::Error::from) + }) + .await + .cloned() + } + /// Get a cached report if it exists and is fresh. pub async fn get( &self, @@ -128,16 +143,19 @@ pub fn spawn_cache_refresh_task( // Refresh without TLS fingerprint (most common). 
let _permit = cache.acquire_gpu_permit().await; - match generate_attestation_inner(AttestationParams { - model_name: &model_name, - signing_address: &signing_address, - signing_algo: algo, - signing_public_key: &signing_public_key, - signing_address_bytes: &signing_address_bytes, - nonce: None, - gpu_no_hw_mode, - tls_cert_fingerprint: None, - }) + match generate_attestation_inner( + AttestationParams { + model_name: &model_name, + signing_address: &signing_address, + signing_algo: algo, + signing_public_key: &signing_public_key, + signing_address_bytes: &signing_address_bytes, + nonce: None, + gpu_no_hw_mode, + tls_cert_fingerprint: None, + }, + Some(&cache), + ) .await { Ok(report) => { @@ -153,16 +171,19 @@ pub fn spawn_cache_refresh_task( // Also refresh with TLS fingerprint if configured. if let Some(ref fp) = tls_cert_fingerprint { let _permit = cache.acquire_gpu_permit().await; - match generate_attestation_inner(AttestationParams { - model_name: &model_name, - signing_address: &signing_address, - signing_algo: algo, - signing_public_key: &signing_public_key, - signing_address_bytes: &signing_address_bytes, - nonce: None, - gpu_no_hw_mode, - tls_cert_fingerprint: Some(fp.as_str()), - }) + match generate_attestation_inner( + AttestationParams { + model_name: &model_name, + signing_address: &signing_address, + signing_algo: algo, + signing_public_key: &signing_public_key, + signing_address_bytes: &signing_address_bytes, + nonce: None, + gpu_no_hw_mode, + tls_cert_fingerprint: Some(fp.as_str()), + }, + Some(&cache), + ) .await { Ok(report) => { @@ -469,8 +490,15 @@ pub struct AttestationParams<'a> { } /// Generate a complete attestation report (core logic, no caching). +/// +/// Parallelizes the two slow operations: +/// - TDX quote generation (dstack Unix socket RPC) +/// - GPU evidence collection (Python subprocess with NVML) +/// +/// dstack info is cached for the process lifetime (it never changes). 
async fn generate_attestation_inner( params: AttestationParams<'_>, + cache: Option<&AttestationCache>, ) -> Result { let nonce_bytes = parse_nonce(params.nonce)?; let nonce_hex = hex::encode(nonce_bytes); @@ -489,19 +517,38 @@ async fn generate_attestation_inner( fp_bytes.as_deref(), ); - // Get TDX quote from dstack - let client = dstack_sdk::dstack_client::DstackClient::new(None); - let quote_result = client.get_quote(report_data).await?; - let event_log: serde_json::Value = - serde_json::from_str("e_result.event_log).map_err(anyhow::Error::from)?; + // Run TDX quote and GPU evidence collection in parallel. + // These are independent: TDX quote talks to dstack via Unix socket, + // GPU evidence spawns a Python subprocess calling NVML. + let gpu_no_hw_mode = params.gpu_no_hw_mode; + let nonce_hex_clone = nonce_hex.clone(); + let (quote_result, gpu_evidence) = tokio::try_join!( + async { + let client = dstack_sdk::dstack_client::DstackClient::new(None); + client + .get_quote(report_data) + .await + .map_err(AttestationError::Internal) + }, + async { + collect_gpu_evidence(&nonce_hex_clone, gpu_no_hw_mode) + .await + .map_err(AttestationError::Internal) + }, + )?; - // Collect GPU evidence - let gpu_evidence = collect_gpu_evidence(&nonce_hex, params.gpu_no_hw_mode).await?; + let event_log: serde_json::Value = serde_json::from_str("e_result.event_log) + .map_err(|e| AttestationError::Internal(anyhow::Error::from(e)))?; let nvidia_payload = build_nvidia_payload(&nonce_hex, &gpu_evidence); - // Get system info - let info = client.info().await?; - let info_value = serde_json::to_value(&info).map_err(anyhow::Error::from)?; + // dstack info is static — use cached value if available. + let info_value = if let Some(cache) = cache { + cache.get_dstack_info().await.map_err(AttestationError::Internal)? 
+ } else { + let client = dstack_sdk::dstack_client::DstackClient::new(None); + let info = client.info().await.map_err(AttestationError::Internal)?; + serde_json::to_value(&info).map_err(|e| AttestationError::Internal(anyhow::Error::from(e)))? + }; Ok(AttestationReport { model_name: params.model_name.to_string(), @@ -549,13 +596,13 @@ pub async fn generate_attestation( return Ok(report); } } - let report = generate_attestation_inner(params).await?; + let report = generate_attestation_inner(params, Some(cache)).await?; if is_nonceless { cache.set(&signing_algo, include_tls, report.clone()).await; } report } else { - generate_attestation_inner(params).await? + generate_attestation_inner(params, None).await? }; Ok(report) From 382845ce163132abb0d87413a79a848ad9b8fa89 Mon Sep 17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 14:45:21 -0700 Subject: [PATCH 3/7] perf: persistent Python worker for GPU evidence collection Replace subprocess-per-call with a long-running Python process that keeps the interpreter, verifier module imports, and NVML driver initialized across requests. Communication via JSON lines over stdin/stdout pipes. Before: each attestation spawns python3, imports verifier/cc_admin, calls nvmlInit(), collects evidence, exits. ~0.5-2s overhead per call just from Python startup + module loading + NVML initialization. After: worker spawns once, stays alive, processes nonce requests via pipe. Only the actual GPU evidence collection time remains (~1-5s depending on GPU load). Python startup + import + nvmlInit amortized to zero after first call. 
Design: - GpuEvidenceWorker struct manages the child process lifecycle - Worker sends {"ready": true} on stdout after initialization - Requests: {"nonce": "", "no_gpu_mode": bool} - Responses: {"ok": true, "evidence": [...]} or {"ok": false, "error": "..."} - Auto-restart on worker death with one retry - Falls back to subprocess-per-call if worker can't spawn - All access serialized by existing gpu_semaphore (NVML constraint) --- Dockerfile | 5 +- gpu_evidence_worker.py | 77 +++++++++++ src/attestation.rs | 281 +++++++++++++++++++++++++++++++++++------ 3 files changed, 324 insertions(+), 39 deletions(-) create mode 100644 gpu_evidence_worker.py diff --git a/Dockerfile b/Dockerfile index 1d5a8f6..bf02bf3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ ENV SOURCE_DATE_EPOCH=${SOURCE_DATE_EPOCH} # Cache dependencies: copy manifests first, then do a dummy build COPY Cargo.toml Cargo.lock ./ RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "" > src/lib.rs \ - && mkdir -p benches && echo "fn main() {}" > benches/hot_path.rs \ + && mkdir -p benches && echo "fn main() {}" > benches/hot_path.rs && echo "fn main() {}" > benches/e2e.rs \ && cargo build --release --locked 2>/dev/null || true \ && rm -rf src benches \ && rm -f target/release/deps/*vllm_proxy_rs* \ @@ -37,8 +37,9 @@ RUN pip install --no-cache-dir nv-attestation-sdk nv-ppcie-verifier WORKDIR /app -# Copy compiled binary from builder +# Copy compiled binary and GPU evidence worker from builder COPY --from=builder /build/target/release/vllm-proxy-rs /app/vllm-proxy-rs +COPY gpu_evidence_worker.py /app/gpu_evidence_worker.py # Bake in git revision for version tracking COPY --chmod=664 .GIT_REV /etc/ diff --git a/gpu_evidence_worker.py b/gpu_evidence_worker.py new file mode 100644 index 0000000..752ea7a --- /dev/null +++ b/gpu_evidence_worker.py @@ -0,0 +1,77 @@ +"""Long-running GPU evidence worker. 
+ +Reads JSON requests from stdin (one per line), collects GPU evidence, +writes JSON responses to stdout (one per line). + +Protocol: + Request: {"nonce": "", "no_gpu_mode": false} + Response: {"ok": true, "evidence": [...]} + Error: {"ok": false, "error": "message"} + +Keeps the Python interpreter, verifier module, and NVML driver initialized +across requests, avoiding ~0.5-2s startup overhead per call. +""" + +import json +import sys +import traceback + +# Import verifier once at startup — this is the expensive part +# (loads shared libraries, may trigger nvmlInit on import). +try: + from verifier import cc_admin + IMPORT_OK = True + IMPORT_ERROR = None +except Exception as e: + IMPORT_OK = False + IMPORT_ERROR = str(e) + + +def collect(nonce_hex: str, no_gpu_mode: bool): + """Collect GPU evidence for the given nonce.""" + if not IMPORT_OK: + return {"ok": False, "error": f"verifier import failed: {IMPORT_ERROR}"} + try: + if no_gpu_mode: + evidence = cc_admin.collect_gpu_evidence_remote( + nonce_hex, no_gpu_mode=True + ) + else: + evidence = cc_admin.collect_gpu_evidence_remote( + nonce_hex, ppcie_mode=False + ) + return {"ok": True, "evidence": evidence} + except Exception as e: + return {"ok": False, "error": f"{type(e).__name__}: {e}"} + + +def main(): + # Signal readiness to the Rust parent. 
+ ready_msg = {"ready": True, "import_ok": IMPORT_OK} + if not IMPORT_OK: + ready_msg["import_error"] = IMPORT_ERROR + sys.stdout.write(json.dumps(ready_msg) + "\n") + sys.stdout.flush() + + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + request = json.loads(line) + nonce_hex = request["nonce"] + no_gpu_mode = request.get("no_gpu_mode", False) + result = collect(nonce_hex, no_gpu_mode) + except json.JSONDecodeError as e: + result = {"ok": False, "error": f"invalid JSON: {e}"} + except KeyError as e: + result = {"ok": False, "error": f"missing field: {e}"} + except Exception as e: + result = {"ok": False, "error": f"{type(e).__name__}: {e}\n{traceback.format_exc()}"} + + sys.stdout.write(json.dumps(result) + "\n") + sys.stdout.flush() + + +if __name__ == "__main__": + main() diff --git a/src/attestation.rs b/src/attestation.rs index 5266c45..4d6c749 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use std::time::Instant; use sha2::{Digest, Sha256}; -use tokio::sync::{OnceCell, RwLock, Semaphore}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::{Mutex, OnceCell, RwLock, Semaphore}; use tracing::{error, info, warn}; use crate::types::AttestationReport; @@ -20,23 +21,188 @@ struct CachedReport { created_at: Instant, } +/// Persistent Python worker process for GPU evidence collection. +/// +/// Keeps the Python interpreter, verifier module imports, and NVML driver +/// initialized across requests, avoiding ~0.5-2s startup overhead per call. +/// Communication is via JSON lines over stdin/stdout pipes. +/// +/// The worker is automatically restarted if it dies. All access is serialized +/// by the gpu_semaphore in AttestationCache (only one evidence collection at a time). +struct GpuEvidenceWorker { + stdin: tokio::process::ChildStdin, + stdout: BufReader, + child: tokio::process::Child, +} + +/// Path to the worker script, resolved relative to the binary. 
+fn worker_script_path() -> String { + // In Docker: /app/gpu_evidence_worker.py (next to /app/vllm-proxy-rs) + // In dev: ./gpu_evidence_worker.py + let exe_dir = std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.to_path_buf())); + if let Some(dir) = exe_dir { + let candidate = dir.join("gpu_evidence_worker.py"); + if candidate.exists() { + return candidate.to_string_lossy().to_string(); + } + } + // Fallback: current directory or CARGO_MANIFEST_DIR for dev + if let Ok(manifest) = std::env::var("CARGO_MANIFEST_DIR") { + let candidate = std::path::Path::new(&manifest).join("gpu_evidence_worker.py"); + if candidate.exists() { + return candidate.to_string_lossy().to_string(); + } + } + "gpu_evidence_worker.py".to_string() +} + +impl GpuEvidenceWorker { + /// Spawn a new persistent Python worker process. + async fn spawn() -> anyhow::Result { + let script_path = worker_script_path(); + info!(script = %script_path, "Spawning GPU evidence worker"); + + let mut child = tokio::process::Command::new("python3") + .arg(&script_path) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true) + .spawn() + .map_err(|e| anyhow::anyhow!("Failed to spawn GPU evidence worker: {e}"))?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| anyhow::anyhow!("Failed to capture worker stdin"))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| anyhow::anyhow!("Failed to capture worker stdout"))?; + let mut stdout = BufReader::new(stdout); + + // Wait for the ready signal (first line of output). + let mut ready_line = String::new(); + tokio::time::timeout( + std::time::Duration::from_secs(30), + stdout.read_line(&mut ready_line), + ) + .await + .map_err(|_| anyhow::anyhow!("GPU evidence worker did not send ready signal within 30s"))? 
+ .map_err(|e| anyhow::anyhow!("Failed to read worker ready signal: {e}"))?; + + let ready: serde_json::Value = serde_json::from_str(ready_line.trim()) + .map_err(|e| anyhow::anyhow!("Worker ready signal is not valid JSON: {e}"))?; + + if ready.get("ready") != Some(&serde_json::Value::Bool(true)) { + anyhow::bail!( + "Worker sent unexpected ready signal: {}", + ready_line.trim() + ); + } + + let import_ok = ready + .get("import_ok") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + if !import_ok { + let err = ready + .get("import_error") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + warn!(error = %err, "GPU evidence worker started but verifier import failed"); + } else { + info!("GPU evidence worker ready"); + } + + Ok(Self { + stdin, + stdout, + child, + }) + } + + /// Send a nonce to the worker and read back GPU evidence. + async fn collect( + &mut self, + nonce_hex: &str, + no_gpu_mode: bool, + ) -> anyhow::Result { + let request = serde_json::json!({ + "nonce": nonce_hex, + "no_gpu_mode": no_gpu_mode, + }); + let mut request_line = serde_json::to_string(&request)?; + request_line.push('\n'); + + // Write request + self.stdin + .write_all(request_line.as_bytes()) + .await + .map_err(|e| anyhow::anyhow!("Failed to write to GPU evidence worker: {e}"))?; + self.stdin.flush().await.map_err(|e| { + anyhow::anyhow!("Failed to flush GPU evidence worker stdin: {e}") + })?; + + // Read response (with timeout) + let mut response_line = String::new(); + tokio::time::timeout( + std::time::Duration::from_secs(60), + self.stdout.read_line(&mut response_line), + ) + .await + .map_err(|_| anyhow::anyhow!("GPU evidence worker timed out after 60s"))? 
+ .map_err(|e| anyhow::anyhow!("Failed to read from GPU evidence worker: {e}"))?; + + if response_line.is_empty() { + anyhow::bail!("GPU evidence worker closed stdout (process may have died)"); + } + + let response: serde_json::Value = serde_json::from_str(response_line.trim()) + .map_err(|e| anyhow::anyhow!("Worker response is not valid JSON: {e}"))?; + + if response.get("ok") == Some(&serde_json::Value::Bool(true)) { + response + .get("evidence") + .cloned() + .ok_or_else(|| anyhow::anyhow!("Worker response missing 'evidence' field")) + } else { + let err = response + .get("error") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + anyhow::bail!("GPU evidence worker error: {err}") + } + } + + /// Check if the worker process is still alive. + fn is_alive(&mut self) -> bool { + matches!(self.child.try_wait(), Ok(None)) + } +} + /// Caches nonce-less attestation reports and serializes GPU evidence collection. /// -/// GPU evidence collection spawns a Python subprocess that calls `nvmlInit()`. -/// Under heavy GPU load, `nvmlInit` can intermittently time out (5s timeout in -/// the NVIDIA verifier library). This cache: +/// GPU evidence collection uses a persistent Python worker process that keeps +/// the verifier module and NVML driver initialized. This cache: /// 1. Serves pre-generated reports for requests without a nonce (the common case). -/// 2. Serializes subprocess calls so only one `nvmlInit` runs at a time. -/// 3. Retries once on GPU evidence failure. +/// 2. Serializes evidence calls so only one `nvmlInit`-using request runs at a time. +/// 3. Retries once on GPU evidence failure (restarting the worker if needed). pub struct AttestationCache { /// Cached reports keyed by (signing_algo, include_tls_fingerprint). reports: RwLock>, - /// Serializes GPU evidence subprocess calls (only 1 at a time). + /// Serializes GPU evidence calls (only 1 at a time). gpu_semaphore: Semaphore, /// Cache TTL in seconds. 
ttl_secs: u64, /// Cached dstack info (static for the lifetime of the process). dstack_info: OnceCell, + /// Persistent GPU evidence worker process. Protected by Mutex because + /// send/receive must be atomic (one request at a time). The outer Option + /// is None until first use; the worker is lazily spawned. + gpu_worker: Mutex>, } impl AttestationCache { @@ -46,6 +212,7 @@ impl AttestationCache { gpu_semaphore: Semaphore::new(1), ttl_secs, dstack_info: OnceCell::new(), + gpu_worker: Mutex::new(None), } } @@ -61,6 +228,59 @@ impl AttestationCache { .cloned() } + /// Collect GPU evidence using the persistent worker, with auto-restart. + /// + /// Caller must hold the gpu_semaphore permit. + async fn collect_gpu_evidence( + &self, + nonce_hex: &str, + no_gpu_mode: bool, + ) -> anyhow::Result { + let mut worker_guard = self.gpu_worker.lock().await; + + // Ensure we have a live worker + let needs_spawn = match worker_guard.as_mut() { + Some(w) => !w.is_alive(), + None => true, + }; + if needs_spawn { + match GpuEvidenceWorker::spawn().await { + Ok(w) => { + *worker_guard = Some(w); + } + Err(e) => { + warn!(error = %e, "Failed to spawn GPU evidence worker, falling back to subprocess"); + *worker_guard = None; + // Fall back to one-shot subprocess + return collect_gpu_evidence_subprocess(nonce_hex, no_gpu_mode).await; + } + } + } + + let worker = worker_guard.as_mut().unwrap(); + match worker.collect(nonce_hex, no_gpu_mode).await { + Ok(evidence) => Ok(evidence), + Err(first_err) => { + warn!(error = %first_err, "GPU evidence worker failed, restarting and retrying"); + metrics::counter!("gpu_evidence_retries_total").increment(1); + + // Kill old worker, spawn fresh one + *worker_guard = None; + match GpuEvidenceWorker::spawn().await { + Ok(mut new_worker) => { + let result = new_worker.collect(nonce_hex, no_gpu_mode).await; + *worker_guard = Some(new_worker); + result + } + Err(spawn_err) => { + warn!(error = %spawn_err, "Worker restart failed, falling back to 
subprocess"); + collect_gpu_evidence_subprocess(nonce_hex, no_gpu_mode).await + } + } + } + } + } + /// Get a cached report if it exists and is fresh. pub async fn get( &self, @@ -103,7 +323,7 @@ impl AttestationCache { ); } - /// Acquire the GPU evidence semaphore (serializes subprocess calls). + /// Acquire the GPU evidence semaphore (serializes GPU evidence calls). pub async fn acquire_gpu_permit(&self) -> tokio::sync::SemaphorePermit<'_> { self.gpu_semaphore .acquire() @@ -273,8 +493,11 @@ fn parse_nonce(nonce: Option<&str>) -> Result<[u8; 32], AttestationError> { } } -/// Collect GPU evidence via Python subprocess (single attempt). -async fn collect_gpu_evidence_once( +/// Fallback: collect GPU evidence via one-shot Python subprocess. +/// +/// Used when the persistent worker cannot be spawned (e.g., script not found, +/// Python not installed). Slower due to Python startup + module import overhead. +async fn collect_gpu_evidence_subprocess( nonce_hex: &str, no_gpu_mode: bool, ) -> anyhow::Result { @@ -282,7 +505,6 @@ async fn collect_gpu_evidence_once( info!("GPU evidence no-GPU mode enabled; using canned evidence"); } - // Build a small Python script that collects GPU evidence. // ppcie_mode=False is required on PPCIE systems (the default True triggers a // "standalone mode not supported" error). Safe on non-PPCIE systems too. let script = if no_gpu_mode { @@ -336,28 +558,6 @@ print(json.dumps(evidence)) Ok(evidence) } -/// Collect GPU evidence with one retry on failure. -/// -/// nvmlInit can intermittently time out under heavy GPU load. A single retry -/// after a short delay often succeeds once the driver lock is released. 
-async fn collect_gpu_evidence( - nonce_hex: &str, - no_gpu_mode: bool, -) -> anyhow::Result { - match collect_gpu_evidence_once(nonce_hex, no_gpu_mode).await { - Ok(evidence) => Ok(evidence), - Err(first_err) => { - warn!( - error = %first_err, - "GPU evidence collection failed, retrying after 2s" - ); - metrics::counter!("gpu_evidence_retries_total").increment(1); - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - collect_gpu_evidence_once(nonce_hex, no_gpu_mode).await - } - } -} - /// Build NVIDIA payload JSON. fn build_nvidia_payload(nonce_hex: &str, evidences: &serde_json::Value) -> String { serde_json::json!({ @@ -519,7 +719,7 @@ async fn generate_attestation_inner( // Run TDX quote and GPU evidence collection in parallel. // These are independent: TDX quote talks to dstack via Unix socket, - // GPU evidence spawns a Python subprocess calling NVML. + // GPU evidence uses the persistent Python worker (or subprocess fallback). let gpu_no_hw_mode = params.gpu_no_hw_mode; let nonce_hex_clone = nonce_hex.clone(); let (quote_result, gpu_evidence) = tokio::try_join!( @@ -531,9 +731,16 @@ async fn generate_attestation_inner( .map_err(AttestationError::Internal) }, async { - collect_gpu_evidence(&nonce_hex_clone, gpu_no_hw_mode) - .await - .map_err(AttestationError::Internal) + if let Some(cache) = cache { + cache + .collect_gpu_evidence(&nonce_hex_clone, gpu_no_hw_mode) + .await + .map_err(AttestationError::Internal) + } else { + collect_gpu_evidence_subprocess(&nonce_hex_clone, gpu_no_hw_mode) + .await + .map_err(AttestationError::Internal) + } }, )?; From 059114d6322e6b2c0f6aa747b6da3fd91cdba06d Mon Sep 17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 14:48:31 -0700 Subject: [PATCH 4/7] style: cargo fmt --- .gitignore | 1 + benches/e2e.rs | 37 ++++++++++++++++--------------------- src/attestation.rs | 20 +++++++++++--------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index ac65fd2..bd0c306 
100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .GIT_REV oci.tar +.DS_Store diff --git a/benches/e2e.rs b/benches/e2e.rs index 69dbf5f..cf37ac2 100644 --- a/benches/e2e.rs +++ b/benches/e2e.rs @@ -20,15 +20,13 @@ use vllm_proxy_rs::*; // ── Test keys (same as integration tests) ── const ECDSA_KEY: [u8; 32] = [ - 0xac, 0x09, 0x74, 0xbe, 0xc3, 0x9a, 0x17, 0xe3, 0x6b, 0xa4, 0xa6, 0xb4, 0xd2, 0x38, 0xff, - 0x94, 0x4b, 0xac, 0xb3, 0x5e, 0x5d, 0xc4, 0xaf, 0x0f, 0x33, 0x47, 0xe5, 0x87, 0x31, 0x79, - 0x67, 0x0f, + 0xac, 0x09, 0x74, 0xbe, 0xc3, 0x9a, 0x17, 0xe3, 0x6b, 0xa4, 0xa6, 0xb4, 0xd2, 0x38, 0xff, 0x94, + 0x4b, 0xac, 0xb3, 0x5e, 0x5d, 0xc4, 0xaf, 0x0f, 0x33, 0x47, 0xe5, 0x87, 0x31, 0x79, 0x67, 0x0f, ]; const ED25519_KEY: [u8; 32] = [ - 0x9d, 0x61, 0xb1, 0x9d, 0xef, 0xfd, 0x5a, 0x60, 0xba, 0x84, 0x4a, 0xf4, 0x92, 0xec, 0x2c, - 0xc4, 0x44, 0x49, 0xc5, 0x69, 0x7b, 0x32, 0x69, 0x19, 0x70, 0x3b, 0xac, 0x03, 0x1c, 0xae, - 0x7f, 0x60, + 0x9d, 0x61, 0xb1, 0x9d, 0xef, 0xfd, 0x5a, 0x60, 0xba, 0x84, 0x4a, 0xf4, 0x92, 0xec, 0x2c, 0xc4, + 0x44, 0x49, 0xc5, 0x69, 0x7b, 0x32, 0x69, 0x19, 0x70, 0x3b, 0xac, 0x03, 0x1c, 0xae, 0x7f, 0x60, ]; fn build_test_app(mock_url: &str) -> axum::Router { @@ -161,7 +159,10 @@ fn make_streaming_response(id: &str, num_chunks: usize) -> String { "finish_reason": null }] }); - body.push_str(&format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap())); + body.push_str(&format!( + "data: {}\n\n", + serde_json::to_string(&chunk).unwrap() + )); } // Final chunk with usage let final_chunk = serde_json::json!({ @@ -213,21 +214,18 @@ fn bench_attestation_cache_operations(c: &mut Criterion) { rt.block_on(cache.set("ecdsa", false, report.clone())); group.bench_function("cache_hit", |b| { - b.to_async(&rt).iter(|| async { - black_box(cache.get("ecdsa", false).await) - }) + b.to_async(&rt) + .iter(|| async { black_box(cache.get("ecdsa", false).await) }) }); group.bench_function("cache_miss", |b| { - b.to_async(&rt).iter(|| async { - 
black_box(cache.get("ed25519", true).await) - }) + b.to_async(&rt) + .iter(|| async { black_box(cache.get("ed25519", true).await) }) }); group.bench_function("cache_set", |b| { - b.to_async(&rt).iter(|| async { - cache.set("ecdsa", false, report.clone()).await - }) + b.to_async(&rt) + .iter(|| async { cache.set("ecdsa", false, report.clone()).await }) }); group.bench_function("semaphore_acquire_uncontended", |b| { @@ -410,9 +408,7 @@ fn bench_request_body_processing(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("sha256_hash", msg_count), body_bytes, - |b, data| { - b.iter(|| hex::encode(Sha256::digest(black_box(data)))) - }, + |b, data| b.iter(|| hex::encode(Sha256::digest(black_box(data)))), ); group.bench_with_input( @@ -601,8 +597,7 @@ fn bench_json_body_round_trip(c: &mut Criterion) { group.bench_function("parse_modify_reserialize", |b| { b.iter(|| { - let mut v: serde_json::Value = - serde_json::from_str(black_box(&body_str)).unwrap(); + let mut v: serde_json::Value = serde_json::from_str(black_box(&body_str)).unwrap(); // strip_empty_tool_calls equivalent if let Some(messages) = v.get_mut("messages").and_then(|m| m.as_array_mut()) { diff --git a/src/attestation.rs b/src/attestation.rs index 4d6c749..195bd9e 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -97,10 +97,7 @@ impl GpuEvidenceWorker { .map_err(|e| anyhow::anyhow!("Worker ready signal is not valid JSON: {e}"))?; if ready.get("ready") != Some(&serde_json::Value::Bool(true)) { - anyhow::bail!( - "Worker sent unexpected ready signal: {}", - ready_line.trim() - ); + anyhow::bail!("Worker sent unexpected ready signal: {}", ready_line.trim()); } let import_ok = ready @@ -142,9 +139,10 @@ impl GpuEvidenceWorker { .write_all(request_line.as_bytes()) .await .map_err(|e| anyhow::anyhow!("Failed to write to GPU evidence worker: {e}"))?; - self.stdin.flush().await.map_err(|e| { - anyhow::anyhow!("Failed to flush GPU evidence worker stdin: {e}") - })?; + self.stdin + .flush() + .await 
+ .map_err(|e| anyhow::anyhow!("Failed to flush GPU evidence worker stdin: {e}"))?; // Read response (with timeout) let mut response_line = String::new(); @@ -750,11 +748,15 @@ async fn generate_attestation_inner( // dstack info is static — use cached value if available. let info_value = if let Some(cache) = cache { - cache.get_dstack_info().await.map_err(AttestationError::Internal)? + cache + .get_dstack_info() + .await + .map_err(AttestationError::Internal)? } else { let client = dstack_sdk::dstack_client::DstackClient::new(None); let info = client.info().await.map_err(AttestationError::Internal)?; - serde_json::to_value(&info).map_err(|e| AttestationError::Internal(anyhow::Error::from(e)))? + serde_json::to_value(&info) + .map_err(|e| AttestationError::Internal(anyhow::Error::from(e)))? }; Ok(AttestationReport { From d23490372f954a7495e66169f98b1d6b6884387f Mon Sep 17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 14:59:57 -0700 Subject: [PATCH 5/7] add live endpoint benchmark script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Async benchmark for attestation and completion endpoints with configurable concurrency and duration. Tests: - Attestation cached (no nonce) — measures cache hit path - Attestation fresh (with nonce) — forces GPU evidence + TDX quote - Chat completion (non-streaming and streaming) Reports p50/p90/p99 latencies, throughput, error rates. Usage: uv run scripts/bench_live.py -c 20 -d 60 --- scripts/bench_live.py | 253 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100755 scripts/bench_live.py diff --git a/scripts/bench_live.py b/scripts/bench_live.py new file mode 100755 index 0000000..42b055b --- /dev/null +++ b/scripts/bench_live.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["aiohttp"] +# /// +"""Benchmark attestation and completion endpoints on a live inference-proxy. 
+ +Usage: + uv run scripts/bench_live.py https://glm-5-fp8.completions.near.ai + uv run scripts/bench_live.py http://160.72.54.171:8000 --token secret123 + uv run scripts/bench_live.py http://160.72.54.171:8000 --concurrency 20 --duration 60 + +Tests: + 1. Attestation (no nonce) — should hit cache after first call + 2. Attestation (with nonce) — forces fresh generation every time + 3. Chat completion (non-streaming, short) + 4. Chat completion (streaming, short) +""" + +import argparse +import asyncio +import json +import statistics +import time +from dataclasses import dataclass, field + +import aiohttp + + +@dataclass +class Stats: + name: str + latencies: list[float] = field(default_factory=list) + errors: int = 0 + status_codes: dict[int, int] = field(default_factory=dict) + + def record(self, latency: float, status: int): + self.latencies.append(latency) + self.status_codes[status] = self.status_codes.get(status, 0) + 1 + if status >= 400: + self.errors += 1 + + def report(self) -> str: + if not self.latencies: + return f" {self.name}: no completed requests" + n = len(self.latencies) + ok = n - self.errors + s = sorted(self.latencies) + lines = [ + f" {self.name}:", + f" requests: {n} ({ok} ok, {self.errors} errors)", + f" latency: p50={s[n//2]*1000:.0f}ms p90={s[int(n*0.9)]*1000:.0f}ms p99={s[int(n*0.99)]*1000:.0f}ms", + f" min/avg/max: {s[0]*1000:.0f}/{statistics.mean(s)*1000:.0f}/{s[-1]*1000:.0f} ms", + f" throughput: {n / (s[-1] - s[0] + 0.001):.1f} req/s (wall-clock)" if n > 1 else "", + f" status codes: {dict(sorted(self.status_codes.items()))}", + ] + return "\n".join(l for l in lines if l) + + +async def run_bench( + session: aiohttp.ClientSession, + stats: Stats, + make_request, + concurrency: int, + duration: float, +): + """Run a benchmark: spawn `concurrency` workers hitting `make_request` for `duration` seconds.""" + stop = asyncio.Event() + + async def worker(): + while not stop.is_set(): + try: + t0 = time.monotonic() + method, url, kwargs = 
make_request() + async with session.request(method, url, **kwargs) as resp: + # Consume body to measure full latency + await resp.read() + elapsed = time.monotonic() - t0 + stats.record(elapsed, resp.status) + except asyncio.CancelledError: + break + except Exception as e: + elapsed = time.monotonic() - t0 + stats.record(elapsed, 0) + stats.errors += 1 + + tasks = [asyncio.create_task(worker()) for _ in range(concurrency)] + + await asyncio.sleep(duration) + stop.set() + + # Give workers a moment to finish in-flight requests + await asyncio.sleep(0.5) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +async def main(): + parser = argparse.ArgumentParser(description="Benchmark live inference-proxy endpoints") + parser.add_argument("endpoint", help="Base URL (e.g. https://glm-5-fp8.completions.near.ai)") + parser.add_argument("--token", default="secret123", help="Bearer token (default: secret123)") + parser.add_argument("--concurrency", "-c", type=int, default=20, help="Concurrent requests (default: 20)") + parser.add_argument("--duration", "-d", type=float, default=60, help="Duration in seconds (default: 60)") + parser.add_argument("--model", "-m", default=None, help="Model name for completions (auto-detected if not set)") + parser.add_argument("--skip-completions", action="store_true", help="Skip completion benchmarks") + parser.add_argument("--skip-attestation", action="store_true", help="Skip attestation benchmarks") + args = parser.parse_args() + + base = args.endpoint.rstrip("/") + headers = {"Authorization": f"Bearer {args.token}"} + + connector = aiohttp.TCPConnector(limit=args.concurrency + 5, ssl=False) + timeout = aiohttp.ClientTimeout(total=120) + + async with aiohttp.ClientSession(headers=headers, connector=connector, timeout=timeout) as session: + # Detect model name + model = args.model + if not model: + try: + async with session.get(f"{base}/v1/models") as resp: + if resp.status == 200: + data = await 
resp.json() + models = data.get("data", []) + if models: + model = models[0].get("id", "unknown") + print(f"Detected model: {model}") + except Exception: + pass + if not model: + model = "unknown" + + # Warmup: single request to each endpoint + print(f"\nEndpoint: {base}") + print(f"Concurrency: {args.concurrency}") + print(f"Duration: {args.duration}s per test") + print(f"Model: {model}") + print() + + # ── 1. Attestation (no nonce) — should be cached ── + if not args.skip_attestation: + print("Warming up attestation (no nonce)...") + try: + async with session.get(f"{base}/v1/attestation/report") as resp: + body = await resp.read() + print(f" warmup status: {resp.status} ({len(body)} bytes)") + except Exception as e: + print(f" warmup failed: {e}") + print(" (continuing anyway — benchmark will show errors)") + print() + + stats_att_cached = Stats("attestation_cached (no nonce)") + print(f"Running attestation (cached) benchmark ({args.concurrency}x for {args.duration}s)...") + + def make_att_cached(): + return ("GET", f"{base}/v1/attestation/report", {}) + + await run_bench(session, stats_att_cached, make_att_cached, args.concurrency, args.duration) + print(stats_att_cached.report()) + print() + + # ── 2. Attestation (with nonce) — forces fresh generation ── + stats_att_nonce = Stats("attestation_fresh (with nonce)") + print(f"Running attestation (fresh/nonce) benchmark ({args.concurrency}x for {args.duration}s)...") + + nonce_counter = 0 + def make_att_nonce(): + nonlocal nonce_counter + nonce_counter += 1 + # Each request gets a unique nonce → forces fresh GPU evidence + TDX quote + nonce = f"{nonce_counter:064x}" + return ("GET", f"{base}/v1/attestation/report?nonce={nonce}", {}) + + await run_bench(session, stats_att_nonce, make_att_nonce, args.concurrency, args.duration) + print(stats_att_nonce.report()) + print() + + # ── 3. 
Chat completion (non-streaming) ── + if not args.skip_completions: + stats_chat = Stats("chat_completion (non-streaming)") + print(f"Running chat completion (non-streaming) benchmark ({args.concurrency}x for {args.duration}s)...") + + def make_chat(): + body = { + "model": model, + "messages": [{"role": "user", "content": "Say 'hello' and nothing else."}], + "max_tokens": 5, + "stream": False, + } + return ("POST", f"{base}/v1/chat/completions", {"json": body}) + + await run_bench(session, stats_chat, make_chat, args.concurrency, args.duration) + print(stats_chat.report()) + print() + + # ── 4. Chat completion (streaming) ── + stats_stream = Stats("chat_completion (streaming)") + print(f"Running chat completion (streaming) benchmark ({args.concurrency}x for {args.duration}s)...") + + async def stream_worker_fn(): + """Custom worker that reads SSE stream to completion.""" + body = { + "model": model, + "messages": [{"role": "user", "content": "Say 'hello' and nothing else."}], + "max_tokens": 5, + "stream": True, + } + t0 = time.monotonic() + try: + async with session.post(f"{base}/v1/chat/completions", json=body) as resp: + async for _ in resp.content: + pass + elapsed = time.monotonic() - t0 + stats_stream.record(elapsed, resp.status) + except asyncio.CancelledError: + raise + except Exception: + elapsed = time.monotonic() - t0 + stats_stream.record(elapsed, 0) + stats_stream.errors += 1 + + stop = asyncio.Event() + + async def streaming_loop(): + while not stop.is_set(): + await stream_worker_fn() + + tasks = [asyncio.create_task(streaming_loop()) for _ in range(args.concurrency)] + await asyncio.sleep(args.duration) + stop.set() + await asyncio.sleep(0.5) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + print(stats_stream.report()) + print() + + # ── Summary ── + print("=" * 60) + print(f"SUMMARY — {base}") + print(f" concurrency={args.concurrency} duration={args.duration}s") + print("=" * 60) + if not 
args.skip_attestation: + print(stats_att_cached.report()) + print(stats_att_nonce.report()) + if not args.skip_completions: + print(stats_chat.report()) + print(stats_stream.report()) + + +if __name__ == "__main__": + asyncio.run(main()) From acb3691b006257cc34f7e905a77c970ce7bf6560 Mon Sep 17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 16:40:21 -0700 Subject: [PATCH 6/7] fix: redirect stdout in GPU evidence worker to prevent protocol corruption The NVIDIA verifier library (cc_admin) prints info messages directly to stdout (e.g. "Number of GPUs available : 8"), corrupting the JSON line protocol. Fix by dup'ing the real stdout fd at startup, redirecting sys.stdout to stderr, and using the saved fd for protocol messages. Also fix fallback: when the worker spawns but evidence collection fails on both attempts, now falls back to subprocess instead of returning error. --- gpu_evidence_worker.py | 31 ++++++++++++++++++++++++++----- src/attestation.rs | 18 ++++++++++++------ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/gpu_evidence_worker.py b/gpu_evidence_worker.py index 752ea7a..9fa9b39 100644 --- a/gpu_evidence_worker.py +++ b/gpu_evidence_worker.py @@ -10,12 +10,28 @@ Keeps the Python interpreter, verifier module, and NVML driver initialized across requests, avoiding ~0.5-2s startup overhead per call. + +IMPORTANT: The NVIDIA verifier library (cc_admin) prints info messages +directly to stdout (e.g. "Number of GPUs available : 8"). We redirect +sys.stdout to stderr for the process lifetime and write our JSON protocol +messages to a saved dup of the real stdout fd. """ +import io import json +import os import sys import traceback +# Save the real stdout fd before anything can pollute it. +# We dup the fd so even if sys.stdout is replaced, we can still write.
+_real_stdout_fd = os.dup(sys.stdout.fileno()) +_real_stdout = os.fdopen(_real_stdout_fd, "w", buffering=1) # line-buffered + +# Redirect sys.stdout to stderr so any library prints go to stderr +# (which the Rust parent reads separately / ignores). +sys.stdout = sys.stderr + # Import verifier once at startup — this is the expensive part # (loads shared libraries, may trigger nvmlInit on import). try: @@ -27,6 +43,12 @@ IMPORT_ERROR = str(e) +def _write_response(obj): + """Write a JSON response to the real stdout (not the redirected one).""" + _real_stdout.write(json.dumps(obj) + "\n") + _real_stdout.flush() + + def collect(nonce_hex: str, no_gpu_mode: bool): """Collect GPU evidence for the given nonce.""" if not IMPORT_OK: @@ -46,13 +68,13 @@ def collect(nonce_hex: str, no_gpu_mode: bool): def main(): - # Signal readiness to the Rust parent. + # Signal readiness to the Rust parent on the real stdout. ready_msg = {"ready": True, "import_ok": IMPORT_OK} if not IMPORT_OK: ready_msg["import_error"] = IMPORT_ERROR - sys.stdout.write(json.dumps(ready_msg) + "\n") - sys.stdout.flush() + _write_response(ready_msg) + # Read requests from stdin (which is NOT redirected). 
for line in sys.stdin: line = line.strip() if not line: @@ -69,8 +91,7 @@ def main(): except Exception as e: result = {"ok": False, "error": f"{type(e).__name__}: {e}\n{traceback.format_exc()}"} - sys.stdout.write(json.dumps(result) + "\n") - sys.stdout.flush() + _write_response(result) if __name__ == "__main__": diff --git a/src/attestation.rs b/src/attestation.rs index 195bd9e..a55ecc7 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -262,14 +262,20 @@ impl AttestationCache { warn!(error = %first_err, "GPU evidence worker failed, restarting and retrying"); metrics::counter!("gpu_evidence_retries_total").increment(1); - // Kill old worker, spawn fresh one + // Kill old worker, spawn fresh one and retry *worker_guard = None; match GpuEvidenceWorker::spawn().await { - Ok(mut new_worker) => { - let result = new_worker.collect(nonce_hex, no_gpu_mode).await; - *worker_guard = Some(new_worker); - result - } + Ok(mut new_worker) => match new_worker.collect(nonce_hex, no_gpu_mode).await { + Ok(evidence) => { + *worker_guard = Some(new_worker); + Ok(evidence) + } + Err(retry_err) => { + warn!(error = %retry_err, "Worker retry also failed, falling back to subprocess"); + *worker_guard = None; + collect_gpu_evidence_subprocess(nonce_hex, no_gpu_mode).await + } + }, Err(spawn_err) => { warn!(error = %spawn_err, "Worker restart failed, falling back to subprocess"); collect_gpu_evidence_subprocess(nonce_hex, no_gpu_mode).await From 0ebf876d9b2cee2a41ffbcbf8ce55316a6b0e60e Mon Sep 17 00:00:00 2001 From: Evrard-Nil Daillet Date: Mon, 23 Mar 2026 18:27:46 -0700 Subject: [PATCH 7/7] perf: cache serialized bytes, narrow GPU semaphore, remove double serialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three improvements for attestation latency: 1. 
Cache pre-serialized JSON bytes instead of structs - Cache hit now returns bytes::Bytes directly (zero-copy clone) - Eliminates report.clone() + AttestationResponse construction + serde_json::to_value + Json() serialization on every cached request - 297KB response was being fully re-serialized on every hit 2. Remove wide semaphore — use worker Mutex only - Previously: semaphore wrapped entire generate_attestation_inner() (TDX quote + GPU evidence + dstack info) - Now: only GPU evidence is serialized (via worker Mutex) - Concurrent fresh attestation requests can overlap TDX quotes 3. Return AttestationResult enum from generate_attestation - CachedBytes: pre-serialized bytes sent directly to client - Fresh: report that needs one-time serialization - Route handler branches on variant, avoids redundant work --- benches/e2e.rs | 7 -- src/attestation.rs | 146 ++++++++++++++++++++------------------ src/routes/attestation.rs | 27 +++++-- 3 files changed, 98 insertions(+), 82 deletions(-) diff --git a/benches/e2e.rs b/benches/e2e.rs index cf37ac2..1254385 100644 --- a/benches/e2e.rs +++ b/benches/e2e.rs @@ -228,13 +228,6 @@ fn bench_attestation_cache_operations(c: &mut Criterion) { .iter(|| async { cache.set("ecdsa", false, report.clone()).await }) }); - group.bench_function("semaphore_acquire_uncontended", |b| { - b.to_async(&rt).iter(|| async { - let permit = cache.acquire_gpu_permit().await; - drop(black_box(permit)); - }) - }); - group.finish(); } diff --git a/src/attestation.rs b/src/attestation.rs index a55ecc7..6d0fa04 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -4,7 +4,7 @@ use std::time::Instant; use sha2::{Digest, Sha256}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::sync::{Mutex, OnceCell, RwLock, Semaphore}; +use tokio::sync::{Mutex, OnceCell, RwLock}; use tracing::{error, info, warn}; use crate::types::AttestationReport; @@ -17,6 +17,10 @@ struct AttestationCacheKey { } struct CachedReport { + /// Pre-serialized JSON 
bytes of the full AttestationResponse. + /// Avoids re-serializing 297KB on every cache hit. + response_bytes: bytes::Bytes, + /// The report struct, needed for background refresh to build new responses. report: AttestationReport, created_at: Instant, } @@ -191,15 +195,13 @@ impl GpuEvidenceWorker { pub struct AttestationCache { /// Cached reports keyed by (signing_algo, include_tls_fingerprint). reports: RwLock>, - /// Serializes GPU evidence calls (only 1 at a time). - gpu_semaphore: Semaphore, /// Cache TTL in seconds. ttl_secs: u64, /// Cached dstack info (static for the lifetime of the process). dstack_info: OnceCell, - /// Persistent GPU evidence worker process. Protected by Mutex because - /// send/receive must be atomic (one request at a time). The outer Option - /// is None until first use; the worker is lazily spawned. + /// Persistent GPU evidence worker process. Protected by Mutex which also + /// serializes GPU evidence calls (only one NVML call at a time). + /// The outer Option is None until first use; the worker is lazily spawned. gpu_worker: Mutex>, } @@ -207,7 +209,6 @@ impl AttestationCache { pub fn new(ttl_secs: u64) -> Self { Self { reports: RwLock::new(HashMap::new()), - gpu_semaphore: Semaphore::new(1), ttl_secs, dstack_info: OnceCell::new(), gpu_worker: Mutex::new(None), @@ -285,7 +286,29 @@ impl AttestationCache { } } - /// Get a cached report if it exists and is fresh. + /// Get pre-serialized JSON bytes for a cached report, if fresh. 
+ pub async fn get_bytes( + &self, + signing_algo: &str, + include_tls_fingerprint: bool, + ) -> Option { + let key = AttestationCacheKey { + signing_algo: signing_algo.to_string(), + include_tls_fingerprint, + }; + let reports = self.reports.read().await; + if let Some(cached) = reports.get(&key) { + if cached.created_at.elapsed().as_secs() < self.ttl_secs { + metrics::counter!("attestation_cache_hits_total").increment(1); + return Some(cached.response_bytes.clone()); + } + } + metrics::counter!("attestation_cache_misses_total").increment(1); + None + } + + /// Get a cached report struct if it exists and is fresh. + /// Used by background refresh to check if a refresh is needed. pub async fn get( &self, signing_algo: &str, @@ -298,21 +321,30 @@ impl AttestationCache { let reports = self.reports.read().await; if let Some(cached) = reports.get(&key) { if cached.created_at.elapsed().as_secs() < self.ttl_secs { - metrics::counter!("attestation_cache_hits_total").increment(1); return Some(cached.report.clone()); } } - metrics::counter!("attestation_cache_misses_total").increment(1); None } - /// Store a report in the cache. + /// Store a report in the cache, pre-serializing to JSON bytes. pub async fn set( &self, signing_algo: &str, include_tls_fingerprint: bool, report: AttestationReport, ) { + let response = crate::types::AttestationResponse { + report: report.clone(), + all_attestations: vec![report.clone()], + }; + let response_bytes = match serde_json::to_vec(&response) { + Ok(bytes) => bytes::Bytes::from(bytes), + Err(e) => { + error!(error = %e, "Failed to serialize attestation response for cache"); + return; + } + }; let key = AttestationCacheKey { signing_algo: signing_algo.to_string(), include_tls_fingerprint, @@ -321,19 +353,12 @@ impl AttestationCache { reports.insert( key, CachedReport { + response_bytes, report, created_at: Instant::now(), }, ); } - - /// Acquire the GPU evidence semaphore (serializes GPU evidence calls). 
- pub async fn acquire_gpu_permit(&self) -> tokio::sync::SemaphorePermit<'_> { - self.gpu_semaphore - .acquire() - .await - .expect("semaphore closed") - } } /// Spawn a background task that periodically refreshes cached attestation reports. @@ -366,7 +391,7 @@ pub fn spawn_cache_refresh_task( }; // Refresh without TLS fingerprint (most common). - let _permit = cache.acquire_gpu_permit().await; + // GPU evidence serialization is handled by the worker Mutex. match generate_attestation_inner( AttestationParams { model_name: &model_name, @@ -390,11 +415,9 @@ pub fn spawn_cache_refresh_task( warn!(algo, error = %e, "Background attestation cache refresh failed"); } } - drop(_permit); // Also refresh with TLS fingerprint if configured. if let Some(ref fp) = tls_cert_fingerprint { - let _permit = cache.acquire_gpu_permit().await; match generate_attestation_inner( AttestationParams { model_name: &model_name, @@ -417,7 +440,6 @@ pub fn spawn_cache_refresh_task( warn!(algo, error = %e, "Background attestation cache refresh (with TLS) failed"); } } - drop(_permit); } } @@ -779,48 +801,50 @@ async fn generate_attestation_inner( }) } +/// Result of attestation generation — either pre-serialized cached bytes or a fresh report. +pub enum AttestationResult { + /// Cache hit: pre-serialized JSON bytes ready to send. + CachedBytes(bytes::Bytes), + /// Fresh report that needs serialization. + Fresh(Box), +} + /// Generate an attestation report, using the cache for nonce-less requests. /// /// When a caller provides a nonce, the GPU evidence and TDX quote are /// cryptographically bound to that nonce, so we must generate fresh. /// When no nonce is provided, we serve a cached report (which contains its /// own randomly-generated nonce) — the caller accepts whatever nonce we return. +/// +/// GPU evidence collection is serialized by the worker Mutex (NVML constraint), +/// but TDX quotes and dstack info calls run concurrently with other requests. 
pub async fn generate_attestation( params: AttestationParams<'_>, cache: Option<&AttestationCache>, -) -> Result { +) -> Result { let is_nonceless = params.nonce.is_none(); let include_tls = params.tls_cert_fingerprint.is_some(); let signing_algo = params.signing_algo.to_string(); - // For nonce-less requests, try the cache first. + // For nonce-less requests, try the cache first (returns pre-serialized bytes). if is_nonceless { if let Some(cache) = cache { - if let Some(report) = cache.get(&signing_algo, include_tls).await { - return Ok(report); + if let Some(bytes) = cache.get_bytes(&signing_algo, include_tls).await { + return Ok(AttestationResult::CachedBytes(bytes)); } } } - // Generate fresh report. Acquire semaphore to serialize GPU evidence calls. - let report = if let Some(cache) = cache { - let _permit = cache.acquire_gpu_permit().await; - // Double-check cache after acquiring permit (another request may have filled it). - if is_nonceless { - if let Some(report) = cache.get(&signing_algo, include_tls).await { - return Ok(report); - } - } - let report = generate_attestation_inner(params, Some(cache)).await?; - if is_nonceless { + // Generate fresh report. GPU evidence is serialized by the worker Mutex, + // but TDX quote runs concurrently. + let report = generate_attestation_inner(params, cache).await?; + if is_nonceless { + if let Some(cache) = cache { cache.set(&signing_algo, include_tls, report.clone()).await; } - report - } else { - generate_attestation_inner(params, None).await? - }; + } - Ok(report) + Ok(AttestationResult::Fresh(Box::new(report))) } #[cfg(test)] @@ -1007,30 +1031,16 @@ mod tests { } #[tokio::test] - async fn test_gpu_semaphore_serializes() { - let cache = Arc::new(AttestationCache::new(300)); - - // Hold the first permit. - let permit1 = cache.acquire_gpu_permit().await; - - // Spawn a task that tries to acquire the semaphore while we hold it. 
- let cache2 = cache.clone(); - let mut handle = tokio::spawn(async move { - let _permit = cache2.acquire_gpu_permit().await; - }); - - // The second acquire should block (not complete within 50ms). - let result = tokio::time::timeout(std::time::Duration::from_millis(50), &mut handle).await; - assert!( - result.is_err(), - "second acquire should block while first permit is held" - ); - - // Release the first permit — the second task should now complete. - drop(permit1); - tokio::time::timeout(std::time::Duration::from_millis(50), handle) - .await - .expect("second acquire should complete after first permit is dropped") - .expect("task should not panic"); + async fn test_cache_get_bytes_returns_preserialized() { + let cache = AttestationCache::new(300); + let report = make_test_report("ecdsa", "aabb"); + cache.set("ecdsa", false, report).await; + + let bytes = cache.get_bytes("ecdsa", false).await; + assert!(bytes.is_some()); + let parsed: serde_json::Value = + serde_json::from_slice(&bytes.unwrap()).expect("cached bytes should be valid JSON"); + assert_eq!(parsed["request_nonce"], "aabb"); + assert!(parsed["all_attestations"].is_array()); } } diff --git a/src/routes/attestation.rs b/src/routes/attestation.rs index df812fc..a989237 100644 --- a/src/routes/attestation.rs +++ b/src/routes/attestation.rs @@ -1,8 +1,10 @@ use axum::extract::{Query, State}; +use axum::http::StatusCode; use axum::response::IntoResponse; use axum::Json; use serde::Deserialize; +use crate::attestation::AttestationResult; use crate::error::AppError; use crate::types::AttestationResponse; use crate::AppState; @@ -53,7 +55,7 @@ pub async fn attestation_report( } } - let report = crate::attestation::generate_attestation( + let result = crate::attestation::generate_attestation( crate::attestation::AttestationParams { model_name: &state.config.model_name, signing_address: &signing_address, @@ -76,10 +78,21 @@ pub async fn attestation_report( crate::attestation::AttestationError::Internal(e) => 
AppError::Internal(e), })?; - let response = AttestationResponse { - report: report.clone(), - all_attestations: vec![report], - }; - - Ok(Json(serde_json::to_value(response).unwrap())) + match result { + // Cache hit: return pre-serialized bytes directly (no clone, no serialization). + AttestationResult::CachedBytes(bytes) => Ok(( + StatusCode::OK, + [("content-type", "application/json")], + bytes, + ) + .into_response()), + // Fresh report: serialize once. + AttestationResult::Fresh(report) => { + let response = AttestationResponse { + report: report.as_ref().clone(), + all_attestations: vec![*report], + }; + Ok(Json(serde_json::to_value(response).unwrap()).into_response()) + } + } }