Skip to content

Commit 6be3d25

Browse files
test_utils: new module cancel_safe
Differential Revision: D80626978
1 parent a9cd5a2 commit 6be3d25

File tree

2 files changed

+318
-0
lines changed

2 files changed

+318
-0
lines changed

hyperactor/src/test_utils.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
/// Utilities to verify cancellation safety.
10+
pub mod cancel_safe;
911
/// PingPongActor test util.
1012
pub mod pingpong;
1113
/// ProcSupervisionCoordinator test util.
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//! Utilities for testing **cancel safety** of futures.
10+
//!
11+
//! # What does "cancel-safe" mean?
12+
//!
13+
//! A future is *cancel-safe* if, at **any** `Poll::Pending` boundary:
14+
//!
15+
//! 1. **State remains valid** – dropping the future there does not
16+
//! violate external invariants or leave shared state corrupted.
17+
//! 2. **Restartability holds** – from that state, constructing a
18+
//! fresh future for the same logical operation can still run to
19+
//! completion and produce the expected result.
20+
//! 3. **No partial side effects** – cancellation never leaves behind
21+
//! a visible "half-done" action; effects are either not started,
22+
//! or fully completed in an idempotent way.
23+
//!
24+
//! # Why cancel-safety matters
25+
//!
26+
//! Executors are free to drop futures after any `Poll::Pending`. This
27+
//! means that cancellation is not an exceptional path – it is *part
28+
//! of the normal contract*. A cancel-unsafe future can leak
29+
//! resources, corrupt protocol state, or leave behind truncated I/O.
30+
//!
31+
//! # What this module offers
32+
//!
33+
//! This module provides helpers (`assert_cancel_safe`,
34+
//! `assert_cancel_safe_async`) that:
35+
//!
36+
//! - drive a future to completion once, counting its yield points,
37+
//! - then for every possible cancellation boundary `k`, poll a fresh
38+
//! future `k` times, drop it, and finally ensure a **new run**
39+
//! still produces the expected result.
40+
//!
41+
//! # Examples
42+
//!
43+
//! - ✓ Pure/logical futures: simple state machines with no I/O (e.g.
44+
//! yields twice, then return 42).
45+
//! - ✓ Framed writers that stage bytes internally and only commit
46+
//! once the frame is fully written.
47+
//! - ✗ Writers that flush a partial frame before returning `Pending`.
48+
//! - ✗ Futures that consume from a shared queue before `Pending` and
49+
//! drop without rollback.
50+
51+
use std::fmt::Debug;
52+
use std::future::Future;
53+
use std::pin::Pin;
54+
use std::task::Context;
55+
use std::task::Poll;
56+
use std::task::RawWaker;
57+
use std::task::RawWakerVTable;
58+
use std::task::Waker;
59+
60+
/// A minimal no-op waker for manual polling.
61+
fn noop_waker() -> Waker {
62+
fn clone(_: *const ()) -> RawWaker {
63+
RawWaker::new(std::ptr::null(), &VTABLE)
64+
}
65+
fn wake(_: *const ()) {}
66+
fn wake_by_ref(_: *const ()) {}
67+
fn drop(_: *const ()) {}
68+
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
69+
// SAFETY: The vtable doesn't use the data pointer.
70+
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
71+
}
72+
73+
/// Poll a future once.
74+
fn poll_once<F: Future + Unpin>(fut: &mut F, cx: &mut Context<'_>) -> Poll<F::Output> {
75+
Pin::new(fut).poll(cx)
76+
}
77+
78+
/// Drive a fresh future to completion, returning (`pending_count`,
79+
/// `out`). `pending_count` is the number of times the future returned
80+
/// `Poll::Pending` before it finally resolved to `Poll::Ready`.
81+
fn run_to_completion_count_pending<F, T>(mut mk: impl FnMut() -> F) -> (usize, T)
82+
where
83+
F: Future<Output = T>,
84+
{
85+
let waker = noop_waker();
86+
let mut cx = Context::from_waker(&waker);
87+
88+
let mut fut = Box::pin(mk());
89+
let mut pending_count = 0usize;
90+
91+
loop {
92+
match poll_once(&mut fut, &mut cx) {
93+
Poll::Ready(out) => return (pending_count, out),
94+
Poll::Pending => {
95+
pending_count += 1;
96+
// Nothing else to do: we are just counting yield
97+
// points.
98+
}
99+
}
100+
}
101+
}
102+
103+
/// Runtime-independent version: on each `Poll::Pending`, we just poll
104+
/// again. Suitable for pure/logical futures that don’t rely on
105+
/// timers, IO, or other external progress driven by an async runtime.
106+
pub fn assert_cancel_safe<F, T>(mut mk: impl FnMut() -> F, expected: &T)
107+
where
108+
F: Future<Output = T>,
109+
T: Debug + PartialEq,
110+
{
111+
// 1) Establish ground truth and number of yield points.
112+
let (pending_total, out) = run_to_completion_count_pending(&mut mk);
113+
assert_eq!(&out, expected, "baseline run output mismatch");
114+
115+
// 2) Cancel at every poll boundary k, then ensure a fresh run
116+
// still matches.
117+
for k in 0..=pending_total {
118+
let waker = noop_waker();
119+
let mut cx = Context::from_waker(&waker);
120+
121+
// Poll exactly k times (dropping afterwards).
122+
{
123+
let mut fut = Box::pin(mk());
124+
for _ in 0..k {
125+
if poll_once(&mut fut, &mut cx).is_ready() {
126+
// Future completed earlier than k: no
127+
// cancellation point here. Drop and move on to
128+
// next k.
129+
break;
130+
}
131+
}
132+
// Drop here = "cancellation".
133+
drop(fut);
134+
}
135+
136+
// 3) Now ensure we can still complete cleanly and match
137+
// expected. This verifies cancelling at this boundary didn’t
138+
// corrupt global state or violate invariants needed for a
139+
// clean, subsequent run.
140+
let (_, out2) = run_to_completion_count_pending(&mut mk);
141+
assert_eq!(
142+
&out2, expected,
143+
"output mismatch after cancelling at poll #{k}"
144+
);
145+
}
146+
}
147+
148+
/// Cancel-safety check for async futures. On every `Poll::Pending`,
149+
/// runs `on_pending().await` to drive external progress (e.g.
150+
/// advancing a paused clock or IO). Cancels at each yield boundary
151+
/// and ensures a fresh run still produces `expected`.
152+
pub async fn assert_cancel_safe_async<F, T, P, FutStep>(
153+
mut mk: impl FnMut() -> F,
154+
expected: &T,
155+
mut on_pending: P,
156+
) where
157+
F: Future<Output = T>,
158+
T: Debug + PartialEq,
159+
P: FnMut() -> FutStep,
160+
FutStep: Future<Output = ()>,
161+
{
162+
let waker = noop_waker();
163+
let mut cx = Context::from_waker(&waker);
164+
165+
// 1) First, establish expected + number of pendings with the
166+
// ability to drive progress.
167+
let mut pending_total = 0usize;
168+
{
169+
let mut fut = Box::pin(mk());
170+
loop {
171+
match poll_once(&mut fut, &mut cx) {
172+
Poll::Ready(out) => {
173+
assert_eq!(&out, expected, "baseline run output mismatch");
174+
break;
175+
}
176+
Poll::Pending => {
177+
pending_total += 1;
178+
on_pending().await;
179+
}
180+
}
181+
}
182+
}
183+
184+
// 2) Cancel at each poll boundary.
185+
for k in 0..=pending_total {
186+
// Poll exactly k steps, advancing external progress each
187+
// time.
188+
{
189+
let mut fut = Box::pin(mk());
190+
for _ in 0..k {
191+
match poll_once(&mut fut, &mut cx) {
192+
Poll::Ready(_) => break, // Completed earlier than k
193+
Poll::Pending => on_pending().await,
194+
}
195+
}
196+
drop(fut); // cancellation
197+
}
198+
199+
// 3) Then ensure a clean full completion still yields
200+
// expected.
201+
{
202+
let mut fut = Box::pin(mk());
203+
loop {
204+
match poll_once(&mut fut, &mut cx) {
205+
Poll::Ready(out) => {
206+
assert_eq!(
207+
&out, expected,
208+
"output mismatch after cancelling at poll #{k}"
209+
);
210+
break;
211+
}
212+
Poll::Pending => on_pending().await,
213+
}
214+
}
215+
}
216+
}
217+
}
218+
219+
/// Convenience macro for `assert_cancel_safe`.
220+
///
221+
/// Example:
222+
/// ```ignore
223+
/// assert_cancel_safe!(CountToThree { step: 0 }, 42);
224+
/// ```
225+
///
226+
/// - `my_future_expr` is any expression that produces a fresh future
227+
/// when evaluated (e.g. `CountToThree { step: 0 }`).
228+
/// - `expected_value` is the value you expect the future to resolve
229+
/// to. **Pass a plain value, not a reference**. The macro will take a
230+
/// reference internally.
231+
#[macro_export]
232+
macro_rules! assert_cancel_safe {
233+
($make_future:expr, $expected:expr) => {{ $crate::test_utils::cancel_safe::assert_cancel_safe(|| $make_future, &$expected) }};
234+
}
235+
236+
/// Async convenience macro for `assert_cancel_safe_async`.
237+
///
238+
/// Example:
239+
/// ```ignore
240+
/// assert_cancel_safe_async!(
241+
/// two_sleeps(),
242+
/// 7,
243+
/// || async { tokio::time::advance(std::time::Duration::from_millis(1)).await }
244+
/// );
245+
/// ```
246+
///
247+
/// - `my_future_expr` is any expression that produces a fresh future
248+
/// when evaluated (e.g. `two_sleeps()`).
249+
/// - `expected_value` is the value you expect the future to resolve
250+
/// to. **Pass a plain value, not a reference**. The macro will take
251+
/// a reference internally.
252+
/// - `on_pending` is a closure that returns an async block, used to
253+
/// drive external progress each time the future yields
254+
/// `Poll::Pending`.
255+
#[macro_export]
256+
macro_rules! assert_cancel_safe_async {
257+
($make_future:expr, $expected:expr, $on_pending:expr) => {{
258+
$crate::test_utils::cancel_safe::assert_cancel_safe_async(
259+
|| $make_future,
260+
&$expected,
261+
$on_pending,
262+
)
263+
.await
264+
}};
265+
}
266+
267+
#[cfg(test)]
268+
mod tests {
269+
use tokio::time::Duration;
270+
use tokio::time::{self};
271+
272+
use super::*;
273+
274+
// A future that yields twice, then returns a number.
275+
struct CountToThree {
276+
step: u8,
277+
}
278+
279+
impl Future for CountToThree {
280+
type Output = u8;
281+
282+
fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
283+
self.step += 1;
284+
match self.step {
285+
1 | 2 => Poll::Pending, // yield twice...
286+
3 => Poll::Ready(42), // ... 3rd time's a charm
287+
_ => panic!("polled after completion"),
288+
}
289+
}
290+
}
291+
292+
// Smoke test: verify that a simple state-machine future (yields
293+
// twice, then completes) passes the cancel-safety checks.
294+
#[test]
295+
fn test_count_to_three_cancel_safe() {
296+
assert_cancel_safe!(CountToThree { step: 0 }, 42u8);
297+
}
298+
299+
// A future that waits for two sleeps (1ms each), then returns 7.
300+
#[allow(clippy::disallowed_methods)]
301+
async fn two_sleeps() -> u8 {
302+
time::sleep(Duration::from_millis(1)).await;
303+
time::sleep(Duration::from_millis(1)).await;
304+
7
305+
}
306+
307+
// Smoke test: verify that a timer-based async future (with two
308+
// sleeps) passes the async cancel-safety checks under tokio's
309+
// mocked time.
310+
#[tokio::test(start_paused = true)]
311+
async fn test_two_sleeps_cancel_safe_async() {
312+
assert_cancel_safe_async!(two_sleeps(), 7, || async {
313+
time::advance(Duration::from_millis(1)).await
314+
});
315+
}
316+
}

0 commit comments

Comments
 (0)