From ff61ce1ef20faaae20ad3b620154af879b6b267e Mon Sep 17 00:00:00 2001 From: Josiah Bull Date: Tue, 19 May 2026 18:42:37 +1200 Subject: [PATCH] feat: give more context to middleware. --- crates/partly-proxy-lib/src/context.rs | 44 +++++ crates/partly-proxy-lib/src/lib.rs | 2 +- crates/partly-proxy-lib/src/listener.rs | 14 +- crates/partly-proxy-lib/tests/replay.rs | 204 +++++++++++++++++++++++- 4 files changed, 258 insertions(+), 6 deletions(-) diff --git a/crates/partly-proxy-lib/src/context.rs b/crates/partly-proxy-lib/src/context.rs index aaba947..2e173e5 100644 --- a/crates/partly-proxy-lib/src/context.rs +++ b/crates/partly-proxy-lib/src/context.rs @@ -8,6 +8,29 @@ use std::time::Instant; use http::Extensions; use uuid::Uuid; +/// Which terminal stage produced the response for this request. +/// +/// Stamped by `LiveTerminal` after its routing decision and read back by +/// middleware (post-`next.run`) and by the response-emit path. +/// +/// Absent from the context when a middleware short-circuited the chain and +/// the terminal never ran — middleware that need to distinguish their own +/// short-circuit from a terminal outcome should treat `None` as "produced +/// by middleware". +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseSource { + /// A registered stub matched and fired. + Stub, + /// A recorded exchange satisfied the replay lookup. + Snapshot, + /// `Mode::Replay` with no stub and no snapshot match — synthetic 503. + ReplayMiss, + /// `Mode::Record` — the request was forwarded to the real upstream. + /// Stamped before the forward is awaited, so the marker is present + /// even if the upstream call ultimately errors. + Upstream, +} + /// Per-request bag of state passed mutably through each middleware. #[derive(Debug)] pub struct RequestContext { @@ -58,6 +81,12 @@ impl RequestContext { pub fn extensions_mut(&mut self) -> &mut Extensions { &mut self.extensions } + + /// The terminal stage that produced the response, if the terminal ran. + /// Returns `None` when a middleware short-circuited the chain. + pub fn response_source(&self) -> Option { + self.get::().copied() + } } #[cfg(test)] @@ -71,6 +100,21 @@ mod tests { assert_ne!(a.id, b.id); } + #[test] + fn response_source_round_trips_each_variant() { + for source in [ + ResponseSource::Stub, + ResponseSource::Snapshot, + ResponseSource::ReplayMiss, + ResponseSource::Upstream, + ] { + let mut ctx = RequestContext::new(); + assert_eq!(ctx.response_source(), None); + ctx.insert(source); + assert_eq!(ctx.response_source(), Some(source)); + } + } + #[test] fn extension_round_trips_by_type() { #[derive(Clone, Debug, PartialEq)] diff --git a/crates/partly-proxy-lib/src/lib.rs b/crates/partly-proxy-lib/src/lib.rs index 06c8ed9..5ab6f2f 100644 --- a/crates/partly-proxy-lib/src/lib.rs +++ b/crates/partly-proxy-lib/src/lib.rs @@ -40,7 +40,7 @@ pub use command::{Command, CommandResponse, CommandSender}; pub use config::{ InboundTlsConfig, Mode, ProxyConfig, RecordingConfig, UpstreamTarget, UpstreamTlsConfig, }; -pub use context::RequestContext; +pub use context::{RequestContext, ResponseSource}; pub use middleware::{Next, ProxyMiddleware, SharedMiddleware, Terminal, TerminalFuture}; /// Re-export of the JSON-Lines snapshot backend, available when the /// `storage-jsonl` feature is on (which it is by default). diff --git a/crates/partly-proxy-lib/src/listener.rs b/crates/partly-proxy-lib/src/listener.rs index e4466ea..c837cf0 100644 --- a/crates/partly-proxy-lib/src/listener.rs +++ b/crates/partly-proxy-lib/src/listener.rs @@ -31,7 +31,7 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::{ builder::UpstreamSpec, config::Mode, - context::RequestContext, + context::{RequestContext, ResponseSource}, forwarder::Forwarder, middleware::{self, SharedMiddleware, Terminal, TerminalFuture}, proxy_io::{ProxyRequest, ProxyResponse}, @@ -438,27 +438,35 @@ struct LiveTerminal<'a> { } impl Terminal for LiveTerminal<'_> { - fn invoke<'b>(&'b self, req: ProxyRequest, _ctx: &'b mut RequestContext) -> TerminalFuture<'b> { + fn invoke<'b>(&'b self, req: ProxyRequest, ctx: &'b mut RequestContext) -> TerminalFuture<'b> { Box::pin(async move { // Lifecycle stage 6: stub scan. The first matching stub wins. if let Some((response, delay)) = self.runtime.stubs.take_match(&req).await { if let Some(d) = delay { tokio::time::sleep(d).await; } + ctx.insert(ResponseSource::Stub); return Ok(response.into_proxy()); } // Lifecycle stage 7: replay lookup. The lookup applies // `redact_request_for_snapshot` to a working copy before hashing. if let Some(source) = &self.runtime.replay { if let Some(resp) = source.lookup(&req, &self.runtime.middleware) { + ctx.insert(ResponseSource::Snapshot); return Ok(resp); } } // Lifecycle stage 8: terminal miss. In Replay mode we never // touch the upstream — see SPECIFICATION.md §8.3. match self.runtime.mode { - Mode::Replay => Ok(replay_miss_response()), + Mode::Replay => { + ctx.insert(ResponseSource::ReplayMiss); + Ok(replay_miss_response()) + } Mode::Record => { + // Stamp before awaiting so the marker is present even + // if the forward errors. + ctx.insert(ResponseSource::Upstream); self.runtime .forwarder .forward(req, &self.runtime.name) diff --git a/crates/partly-proxy-lib/tests/replay.rs b/crates/partly-proxy-lib/tests/replay.rs index 613a114..b3f5abf 100644 --- a/crates/partly-proxy-lib/tests/replay.rs +++ b/crates/partly-proxy-lib/tests/replay.rs @@ -1,6 +1,10 @@ //! Replay layered with middleware, stubs and the live forwarder. -use std::{net::SocketAddr, sync::Arc, time::Duration}; +use std::{ + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; use async_trait::async_trait; use bytes::Bytes; @@ -10,7 +14,7 @@ use partly_proxy_lib::{ Command, ExchangeOutcome, MatchStrategy, Mode, Next, ProxyClusterBuilder, ProxyConfig, ProxyMiddleware, ProxyRequest, ProxyResponse, RecordedExchange, RecordedRequest, RecordedResponse, RecordingConfig, ReplaySource, RequestContext, RequestMatcher, - Result as ProxyResult, SharedMiddleware, StubbedResponse, UpstreamTarget, + ResponseSource, Result as ProxyResult, SharedMiddleware, StubbedResponse, UpstreamTarget, }; use tokio::task::JoinHandle; @@ -386,3 +390,199 @@ async fn replay_records_served_exchanges_when_recording_enabled() { cluster.shutdown().await.unwrap(); } + +/// Captures `ctx.response_source()` after `next.run` returns. Used by the +/// `ResponseSource` tests to assert which terminal branch produced the response. +struct CaptureSource(Arc>>); + +#[async_trait] +impl ProxyMiddleware for CaptureSource { + async fn handle( + &self, + req: ProxyRequest, + ctx: &mut RequestContext, + next: Next<'_>, + ) -> ProxyResult { + let resp = next.run(req, ctx).await; + *self.0.lock().unwrap() = ctx.response_source(); + resp + } +} + +/// Short-circuits without ever calling `next.run`. +struct ShortCircuit; + +#[async_trait] +impl ProxyMiddleware for ShortCircuit { + async fn handle( + &self, + _req: ProxyRequest, + _ctx: &mut RequestContext, + _next: Next<'_>, + ) -> ProxyResult { + Ok(ProxyResponse::new(StatusCode::IM_A_TEAPOT).with_body(Bytes::from_static(b"short"))) + } +} + +fn unreachable_addr() -> SocketAddr { + let l = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let a = l.local_addr().unwrap(); + drop(l); + a +} + +#[tokio::test] +async fn response_source_stub_marks_ctx() { + let captured = Arc::new(Mutex::new(None)); + let cluster = ProxyClusterBuilder::new() + .add_upstream_with( + "api", + cfg(format!("http://{}", unreachable_addr())), + vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware], + None, + ) + .run() + .await + .unwrap(); + let proxy = cluster.addr("api").unwrap(); + + cluster + .command_sender() + .send(Command::Stub { + upstream: None, + matcher: RequestMatcher::new().method(Method::GET).path("/x"), + response: StubbedResponse::new(StatusCode::OK).body(Bytes::from_static(b"ok")), + times: Some(1), + }) + .await + .unwrap(); + + let r = http_client() + .get(format!("http://{proxy}/x")) + .send() + .await + .unwrap(); + assert_eq!(r.status(), 200); + assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::Stub)); + + cluster.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn response_source_snapshot_marks_ctx() { + let captured = Arc::new(Mutex::new(None)); + let replay = ReplaySource::new( + vec![make_recorded(Method::GET, "/x", b"", 200, b"replayed")], + MatchStrategy::MethodUriAndBodyHash, + ); + let cluster = ProxyClusterBuilder::new() + .add_upstream_with_mode( + "api", + cfg(format!("http://{}", unreachable_addr())), + vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware], + Some(replay), + Mode::Replay, + ) + .run() + .await + .unwrap(); + let proxy = cluster.addr("api").unwrap(); + + let r = http_client() + .get(format!("http://{proxy}/x")) + .send() + .await + .unwrap(); + assert_eq!(r.status(), 200); + assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::Snapshot)); + + cluster.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn response_source_replay_miss_marks_ctx() { + let captured = Arc::new(Mutex::new(None)); + let replay = ReplaySource::new( + vec![make_recorded(Method::GET, "/x", b"", 200, b"replayed")], + MatchStrategy::MethodUriAndBodyHash, + ); + let cluster = ProxyClusterBuilder::new() + .add_upstream_with_mode( + "api", + cfg(format!("http://{}", unreachable_addr())), + vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware], + Some(replay), + Mode::Replay, + ) + .run() + .await + .unwrap(); + let proxy = cluster.addr("api").unwrap(); + + let r = http_client() + .get(format!("http://{proxy}/miss")) + .send() + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::SERVICE_UNAVAILABLE); + assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::ReplayMiss)); + + cluster.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn response_source_upstream_marks_ctx() { + let (echo_addr, _t) = spawn_echo().await; + let captured = Arc::new(Mutex::new(None)); + let cluster = ProxyClusterBuilder::new() + .add_upstream_with_mode( + "api", + cfg(format!("http://{echo_addr}")), + vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware], + None, + Mode::Record, + ) + .run() + .await + .unwrap(); + let proxy = cluster.addr("api").unwrap(); + + let r = http_client() + .get(format!("http://{proxy}/anything")) + .send() + .await + .unwrap(); + assert!(r.status().is_success()); + assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::Upstream)); + + cluster.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn response_source_absent_when_middleware_short_circuits() { + let captured = Arc::new(Mutex::new(None)); + let cluster = ProxyClusterBuilder::new() + .add_upstream_with( + "api", + cfg(format!("http://{}", unreachable_addr())), + vec![ + Arc::new(CaptureSource(captured.clone())) as SharedMiddleware, + Arc::new(ShortCircuit) as SharedMiddleware, + ], + None, + ) + .run() + .await + .unwrap(); + let proxy = cluster.addr("api").unwrap(); + + let r = http_client() + .get(format!("http://{proxy}/anything")) + .send() + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::IM_A_TEAPOT); + assert_eq!(*captured.lock().unwrap(), None); + + cluster.shutdown().await.unwrap(); +}