Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions crates/partly-proxy-lib/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,29 @@ use std::time::Instant;
use http::Extensions;
use uuid::Uuid;

/// Which terminal stage produced the response for this request.
///
/// Stamped by `LiveTerminal` after its routing decision and read back by
/// middleware (post-`next.run`) and by the response-emit path.
///
/// Absent from the context when a middleware short-circuited the chain and
/// the terminal never ran — middleware that need to distinguish their own
/// short-circuit from a terminal outcome should treat `None` as "produced
/// by middleware".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseSource {
/// A registered stub matched and fired.
Stub,
/// A recorded exchange satisfied the replay lookup.
Snapshot,
/// `Mode::Replay` with no stub and no snapshot match — synthetic 503.
ReplayMiss,
/// `Mode::Record` — the request was forwarded to the real upstream.
/// Stamped before the forward is awaited, so the marker is present
/// even if the upstream call ultimately errors.
Upstream,
}

/// Per-request bag of state passed mutably through each middleware.
#[derive(Debug)]
pub struct RequestContext {
Expand Down Expand Up @@ -58,6 +81,12 @@ impl RequestContext {
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}

/// The terminal stage that produced the response, if the terminal ran.
/// Returns `None` when a middleware short-circuited the chain.
pub fn response_source(&self) -> Option<ResponseSource> {
self.get::<ResponseSource>().copied()
}
}

#[cfg(test)]
Expand All @@ -71,6 +100,21 @@ mod tests {
assert_ne!(a.id, b.id);
}

#[test]
fn response_source_round_trips_each_variant() {
for source in [
ResponseSource::Stub,
ResponseSource::Snapshot,
ResponseSource::ReplayMiss,
ResponseSource::Upstream,
] {
let mut ctx = RequestContext::new();
assert_eq!(ctx.response_source(), None);
ctx.insert(source);
assert_eq!(ctx.response_source(), Some(source));
}
}

#[test]
fn extension_round_trips_by_type() {
#[derive(Clone, Debug, PartialEq)]
Expand Down
2 changes: 1 addition & 1 deletion crates/partly-proxy-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub use command::{Command, CommandResponse, CommandSender};
pub use config::{
InboundTlsConfig, Mode, ProxyConfig, RecordingConfig, UpstreamTarget, UpstreamTlsConfig,
};
pub use context::RequestContext;
pub use context::{RequestContext, ResponseSource};
pub use middleware::{Next, ProxyMiddleware, SharedMiddleware, Terminal, TerminalFuture};
/// Re-export of the JSON-Lines snapshot backend, available when the
/// `storage-jsonl` feature is on (which it is by default).
Expand Down
14 changes: 11 additions & 3 deletions crates/partly-proxy-lib/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker};
use crate::{
builder::UpstreamSpec,
config::Mode,
context::RequestContext,
context::{RequestContext, ResponseSource},
forwarder::Forwarder,
middleware::{self, SharedMiddleware, Terminal, TerminalFuture},
proxy_io::{ProxyRequest, ProxyResponse},
Expand Down Expand Up @@ -438,27 +438,35 @@ struct LiveTerminal<'a> {
}

impl Terminal for LiveTerminal<'_> {
fn invoke<'b>(&'b self, req: ProxyRequest, _ctx: &'b mut RequestContext) -> TerminalFuture<'b> {
fn invoke<'b>(&'b self, req: ProxyRequest, ctx: &'b mut RequestContext) -> TerminalFuture<'b> {
Box::pin(async move {
// Lifecycle stage 6: stub scan. The first matching stub wins.
if let Some((response, delay)) = self.runtime.stubs.take_match(&req).await {
if let Some(d) = delay {
tokio::time::sleep(d).await;
}
ctx.insert(ResponseSource::Stub);
return Ok(response.into_proxy());
}
// Lifecycle stage 7: replay lookup. The lookup applies
// `redact_request_for_snapshot` to a working copy before hashing.
if let Some(source) = &self.runtime.replay {
if let Some(resp) = source.lookup(&req, &self.runtime.middleware) {
ctx.insert(ResponseSource::Snapshot);
return Ok(resp);
}
}
// Lifecycle stage 8: terminal miss. In Replay mode we never
// touch the upstream — see SPECIFICATION.md §8.3.
match self.runtime.mode {
Mode::Replay => Ok(replay_miss_response()),
Mode::Replay => {
ctx.insert(ResponseSource::ReplayMiss);
Ok(replay_miss_response())
}
Mode::Record => {
// Stamp before awaiting so the marker is present even
// if the forward errors.
ctx.insert(ResponseSource::Upstream);
self.runtime
.forwarder
.forward(req, &self.runtime.name)
Expand Down
204 changes: 202 additions & 2 deletions crates/partly-proxy-lib/tests/replay.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Replay layered with middleware, stubs and the live forwarder.

use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};

use async_trait::async_trait;
use bytes::Bytes;
Expand All @@ -10,7 +14,7 @@ use partly_proxy_lib::{
Command, ExchangeOutcome, MatchStrategy, Mode, Next, ProxyClusterBuilder, ProxyConfig,
ProxyMiddleware, ProxyRequest, ProxyResponse, RecordedExchange, RecordedRequest,
RecordedResponse, RecordingConfig, ReplaySource, RequestContext, RequestMatcher,
Result as ProxyResult, SharedMiddleware, StubbedResponse, UpstreamTarget,
ResponseSource, Result as ProxyResult, SharedMiddleware, StubbedResponse, UpstreamTarget,
};
use tokio::task::JoinHandle;

Expand Down Expand Up @@ -386,3 +390,199 @@ async fn replay_records_served_exchanges_when_recording_enabled() {

cluster.shutdown().await.unwrap();
}

/// Captures `ctx.response_source()` after `next.run` returns. Used by the
/// `ResponseSource` tests to assert which terminal branch produced the response.
struct CaptureSource(Arc<Mutex<Option<ResponseSource>>>);

#[async_trait]
impl ProxyMiddleware for CaptureSource {
async fn handle(
&self,
req: ProxyRequest,
ctx: &mut RequestContext,
next: Next<'_>,
) -> ProxyResult<ProxyResponse> {
let resp = next.run(req, ctx).await;
*self.0.lock().unwrap() = ctx.response_source();
resp
}
}

/// Short-circuits without ever calling `next.run`.
struct ShortCircuit;

#[async_trait]
impl ProxyMiddleware for ShortCircuit {
async fn handle(
&self,
_req: ProxyRequest,
_ctx: &mut RequestContext,
_next: Next<'_>,
) -> ProxyResult<ProxyResponse> {
Ok(ProxyResponse::new(StatusCode::IM_A_TEAPOT).with_body(Bytes::from_static(b"short")))
}
}

fn unreachable_addr() -> SocketAddr {
let l = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let a = l.local_addr().unwrap();
drop(l);
a
}

#[tokio::test]
async fn response_source_stub_marks_ctx() {
let captured = Arc::new(Mutex::new(None));
let cluster = ProxyClusterBuilder::new()
.add_upstream_with(
"api",
cfg(format!("http://{}", unreachable_addr())),
vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware],
None,
)
.run()
.await
.unwrap();
let proxy = cluster.addr("api").unwrap();

cluster
.command_sender()
.send(Command::Stub {
upstream: None,
matcher: RequestMatcher::new().method(Method::GET).path("/x"),
response: StubbedResponse::new(StatusCode::OK).body(Bytes::from_static(b"ok")),
times: Some(1),
})
.await
.unwrap();

let r = http_client()
.get(format!("http://{proxy}/x"))
.send()
.await
.unwrap();
assert_eq!(r.status(), 200);
assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::Stub));

cluster.shutdown().await.unwrap();
}

#[tokio::test]
async fn response_source_snapshot_marks_ctx() {
let captured = Arc::new(Mutex::new(None));
let replay = ReplaySource::new(
vec![make_recorded(Method::GET, "/x", b"", 200, b"replayed")],
MatchStrategy::MethodUriAndBodyHash,
);
let cluster = ProxyClusterBuilder::new()
.add_upstream_with_mode(
"api",
cfg(format!("http://{}", unreachable_addr())),
vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware],
Some(replay),
Mode::Replay,
)
.run()
.await
.unwrap();
let proxy = cluster.addr("api").unwrap();

let r = http_client()
.get(format!("http://{proxy}/x"))
.send()
.await
.unwrap();
assert_eq!(r.status(), 200);
assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::Snapshot));

cluster.shutdown().await.unwrap();
}

#[tokio::test]
async fn response_source_replay_miss_marks_ctx() {
let captured = Arc::new(Mutex::new(None));
let replay = ReplaySource::new(
vec![make_recorded(Method::GET, "/x", b"", 200, b"replayed")],
MatchStrategy::MethodUriAndBodyHash,
);
let cluster = ProxyClusterBuilder::new()
.add_upstream_with_mode(
"api",
cfg(format!("http://{}", unreachable_addr())),
vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware],
Some(replay),
Mode::Replay,
)
.run()
.await
.unwrap();
let proxy = cluster.addr("api").unwrap();

let r = http_client()
.get(format!("http://{proxy}/miss"))
.send()
.await
.unwrap();
assert_eq!(r.status(), StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::ReplayMiss));

cluster.shutdown().await.unwrap();
}

#[tokio::test]
async fn response_source_upstream_marks_ctx() {
let (echo_addr, _t) = spawn_echo().await;
let captured = Arc::new(Mutex::new(None));
let cluster = ProxyClusterBuilder::new()
.add_upstream_with_mode(
"api",
cfg(format!("http://{echo_addr}")),
vec![Arc::new(CaptureSource(captured.clone())) as SharedMiddleware],
None,
Mode::Record,
)
.run()
.await
.unwrap();
let proxy = cluster.addr("api").unwrap();

let r = http_client()
.get(format!("http://{proxy}/anything"))
.send()
.await
.unwrap();
assert!(r.status().is_success());
assert_eq!(*captured.lock().unwrap(), Some(ResponseSource::Upstream));

cluster.shutdown().await.unwrap();
}

#[tokio::test]
async fn response_source_absent_when_middleware_short_circuits() {
let captured = Arc::new(Mutex::new(None));
let cluster = ProxyClusterBuilder::new()
.add_upstream_with(
"api",
cfg(format!("http://{}", unreachable_addr())),
vec![
Arc::new(CaptureSource(captured.clone())) as SharedMiddleware,
Arc::new(ShortCircuit) as SharedMiddleware,
],
None,
)
.run()
.await
.unwrap();
let proxy = cluster.addr("api").unwrap();

let r = http_client()
.get(format!("http://{proxy}/anything"))
.send()
.await
.unwrap();
assert_eq!(r.status(), StatusCode::IM_A_TEAPOT);
assert_eq!(*captured.lock().unwrap(), None);

cluster.shutdown().await.unwrap();
}
Loading