Skip to content

Commit 4ad2fd8

Browse files
committed
refactor: Remove Mutex and utilize __aiter__ with _async_poll(timeout_ms) instead
1 parent 195ec7c commit 4ad2fd8

2 files changed

Lines changed: 412 additions & 109 deletions

File tree

bindings/python/src/table.rs

Lines changed: 71 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1864,13 +1864,6 @@ enum ScannerKind {
18641864
Batch(fcore::client::RecordBatchLogScanner),
18651865
}
18661866

1867-
/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing
1868-
struct ScannerState {
1869-
kind: ScannerKind,
1870-
/// A buffer to hold records polled from the network before yielding them one-by-one to Python
1871-
pending_records: std::collections::VecDeque<Py<ScanRecord>>,
1872-
}
1873-
18741867
impl ScannerKind {
18751868
fn as_record(&self) -> PyResult<&fcore::client::LogScanner> {
18761869
match self {
@@ -1895,7 +1888,7 @@ impl ScannerKind {
18951888
/// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface.
18961889
macro_rules! with_scanner {
18971890
($scanner:expr, $method:ident($($arg:expr),*)) => {
1898-
match $scanner {
1891+
match $scanner.as_ref() {
18991892
ScannerKind::Record(s) => s.$method($($arg),*).await,
19001893
ScannerKind::Batch(s) => s.$method($($arg),*).await,
19011894
}
@@ -1909,7 +1902,7 @@ macro_rules! with_scanner {
19091902
/// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches
19101903
#[pyclass]
19111904
pub struct LogScanner {
1912-
state: Arc<tokio::sync::Mutex<ScannerState>>,
1905+
kind: Arc<ScannerKind>,
19131906
admin: fcore::client::FlussAdmin,
19141907
table_info: fcore::metadata::TableInfo,
19151908
/// The projected Arrow schema to use for empty table creation
@@ -1930,8 +1923,7 @@ impl LogScanner {
19301923
fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> {
19311924
py.detach(|| {
19321925
TOKIO_RUNTIME.block_on(async {
1933-
let state = self.state.lock().await;
1934-
with_scanner!(&state.kind, subscribe(bucket_id, start_offset))
1926+
with_scanner!(&self.kind, subscribe(bucket_id, start_offset))
19351927
.map_err(|e| FlussError::from_core_error(&e))
19361928
})
19371929
})
@@ -1944,8 +1936,7 @@ impl LogScanner {
19441936
fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> {
19451937
py.detach(|| {
19461938
TOKIO_RUNTIME.block_on(async {
1947-
let state = self.state.lock().await;
1948-
with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets))
1939+
with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets))
19491940
.map_err(|e| FlussError::from_core_error(&e))
19501941
})
19511942
})
@@ -1966,9 +1957,8 @@ impl LogScanner {
19661957
) -> PyResult<()> {
19671958
py.detach(|| {
19681959
TOKIO_RUNTIME.block_on(async {
1969-
let state = self.state.lock().await;
19701960
with_scanner!(
1971-
&state.kind,
1961+
&self.kind,
19721962
subscribe_partition(partition_id, bucket_id, start_offset)
19731963
)
19741964
.map_err(|e| FlussError::from_core_error(&e))
@@ -1987,9 +1977,8 @@ impl LogScanner {
19871977
) -> PyResult<()> {
19881978
py.detach(|| {
19891979
TOKIO_RUNTIME.block_on(async {
1990-
let state = self.state.lock().await;
19911980
with_scanner!(
1992-
&state.kind,
1981+
&self.kind,
19931982
subscribe_partition_buckets(&partition_bucket_offsets)
19941983
)
19951984
.map_err(|e| FlussError::from_core_error(&e))
@@ -2004,8 +1993,7 @@ impl LogScanner {
20041993
fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> {
20051994
py.detach(|| {
20061995
TOKIO_RUNTIME.block_on(async {
2007-
let state = self.state.lock().await;
2008-
with_scanner!(&state.kind, unsubscribe(bucket_id))
1996+
with_scanner!(&self.kind, unsubscribe(bucket_id))
20091997
.map_err(|e| FlussError::from_core_error(&e))
20101998
})
20111999
})
@@ -2019,8 +2007,7 @@ impl LogScanner {
20192007
fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> {
20202008
py.detach(|| {
20212009
TOKIO_RUNTIME.block_on(async {
2022-
let state = self.state.lock().await;
2023-
with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id))
2010+
with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id))
20242011
.map_err(|e| FlussError::from_core_error(&e))
20252012
})
20262013
})
@@ -2041,10 +2028,7 @@ impl LogScanner {
20412028
/// - Returns an empty ScanRecords if no records are available
20422029
/// - When timeout expires, returns an empty ScanRecords (NOT an error)
20432030
fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> {
2044-
let scanner_ref =
2045-
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2046-
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2047-
let scanner = lock.kind.as_record()?;
2031+
let scanner = self.kind.as_record()?;
20482032

20492033
if timeout_ms < 0 {
20502034
return Err(FlussError::new_err(format!(
@@ -2093,10 +2077,7 @@ impl LogScanner {
20932077
/// - Returns an empty list if no batches are available
20942078
/// - When timeout expires, returns an empty list (NOT an error)
20952079
fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> {
2096-
let scanner_ref =
2097-
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2098-
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2099-
let scanner = lock.kind.as_batch()?;
2080+
let scanner = self.kind.as_batch()?;
21002081

21012082
if timeout_ms < 0 {
21022083
return Err(FlussError::new_err(format!(
@@ -2131,10 +2112,7 @@ impl LogScanner {
21312112
/// - Returns an empty table (with correct schema) if no records are available
21322113
/// - When timeout expires, returns an empty table (NOT an error)
21332114
fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> {
2134-
let scanner_ref =
2135-
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2136-
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2137-
let scanner = lock.kind.as_batch()?;
2115+
let scanner = self.kind.as_batch()?;
21382116

21392117
if timeout_ms < 0 {
21402118
return Err(FlussError::new_err(format!(
@@ -2188,11 +2166,7 @@ impl LogScanner {
21882166
/// PyArrow Table containing all data from subscribed buckets
21892167
fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> {
21902168
let subscribed = {
2191-
let scanner_ref = unsafe {
2192-
&*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>)
2193-
};
2194-
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2195-
let scanner = lock.kind.as_batch()?;
2169+
let scanner = self.kind.as_batch()?;
21962170
let subs = scanner.get_subscribed_buckets();
21972171
if subs.is_empty() {
21982172
return Err(FlussError::new_err(
@@ -2227,87 +2201,84 @@ impl LogScanner {
22272201
}
22282202

22292203
fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
2204+
static ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
22302205
let py = slf.py();
2231-
let code = pyo3::ffi::c_str!(
2232-
r#"
2233-
async def _adapter(obj):
2206+
let gen_fn = ASYNC_GEN_FN.get_or_init(py, || {
2207+
let code = pyo3::ffi::c_str!(
2208+
r#"
2209+
async def _async_scan(scanner, timeout_ms=1000):
22342210
while True:
2235-
try:
2236-
yield await obj.__anext__()
2237-
except StopAsyncIteration:
2238-
break
2211+
batch = await scanner._async_poll(timeout_ms)
2212+
if batch:
2213+
for record in batch:
2214+
yield record
22392215
"#
2240-
);
2241-
let globals = pyo3::types::PyDict::new(py);
2242-
py.run(code, Some(&globals), None)?;
2243-
let adapter = globals.get_item("_adapter")?.unwrap();
2244-
// Return adapt(self)
2245-
adapter.call1((slf.into_bound_py_any(py)?,))
2216+
);
2217+
let globals = pyo3::types::PyDict::new(py);
2218+
py.run(code, Some(&globals), None).unwrap();
2219+
globals.get_item("_async_scan").unwrap().unwrap().unbind()
2220+
});
2221+
gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,))
22462222
}
22472223

2248-
fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult<Option<Bound<'py, PyAny>>> {
2249-
let state_arc = slf.state.clone();
2250-
let projected_row_type = slf.projected_row_type.clone();
2251-
let py = slf.py();
2252-
2253-
let future = future_into_py(py, async move {
2254-
let mut state = state_arc.lock().await;
2224+
/// Perform a single bounded poll and return a list of ScanRecord objects.
2225+
///
2226+
/// This is the async building block used by `__aiter__` to implement
2227+
/// `async for`. Each call does exactly one network poll (bounded by
2228+
/// `timeout_ms`), converts any results to Python objects, and returns
2229+
/// them as a list. An empty list signals a timeout (no data yet), not
2230+
/// end-of-stream.
2231+
///
2232+
/// Args:
2233+
/// timeout_ms: Timeout in milliseconds for the network poll (default: 1000)
2234+
///
2235+
/// Returns:
2236+
/// Awaitable that resolves to a list of ScanRecord objects
2237+
fn _async_poll<'py>(
2238+
&self,
2239+
py: Python<'py>,
2240+
timeout_ms: Option<i64>,
2241+
) -> PyResult<Bound<'py, PyAny>> {
2242+
let timeout_ms = timeout_ms.unwrap_or(1000);
2243+
if timeout_ms < 0 {
2244+
return Err(FlussError::new_err(format!(
2245+
"timeout_ms must be non-negative, got: {timeout_ms}"
2246+
)));
2247+
}
22552248

2256-
// 1. If we already have buffered records, pop and return immediately
2257-
if let Some(record) = state.pending_records.pop_front() {
2258-
return Ok(record.into_any());
2259-
}
2249+
let scanner = Arc::clone(&self.kind);
2250+
let projected_row_type = self.projected_row_type.clone();
2251+
let timeout = Duration::from_millis(timeout_ms as u64);
22602252

2261-
// 2. Buffer is empty, we must poll the network for the next batch
2262-
// The underlying kind must be a Record-based scanner.
2263-
let scanner = match state.kind.as_record() {
2264-
Ok(s) => s,
2265-
Err(_) => {
2266-
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
2267-
"Stream Ended",
2253+
future_into_py(py, async move {
2254+
let core_scanner = match scanner.as_ref() {
2255+
ScannerKind::Record(s) => s,
2256+
ScannerKind::Batch(_) => {
2257+
return Err(PyTypeError::new_err(
2258+
"Async iteration is only supported for record scanners; \
2259+
use create_log_scanner() instead.",
22682260
));
22692261
}
22702262
};
22712263

2272-
// Poll with a reasonable internal timeout before unblocking the event loop
2273-
let timeout = core::time::Duration::from_millis(5000);
2274-
2275-
let mut current_records = scanner
2264+
let scan_records = core_scanner
22762265
.poll(timeout)
22772266
.await
22782267
.map_err(|e| FlussError::from_core_error(&e))?;
22792268

2280-
// If it's a real timeout with zero records, loop or throw StopAsyncIteration?
2281-
// Since it's a streaming log, we can yield None or block. Blocking requires a loop in the future.
2282-
while current_records.is_empty() {
2283-
current_records = scanner
2284-
.poll(timeout)
2285-
.await
2286-
.map_err(|e| FlussError::from_core_error(&e))?;
2287-
}
2288-
2289-
// Now we have records.
2269+
// Convert to Python list
22902270
Python::attach(|py| {
2291-
for (_, records) in current_records.into_records_by_buckets() {
2271+
let mut result: Vec<Py<ScanRecord>> = Vec::new();
2272+
for (_, records) in scan_records.into_records_by_buckets() {
22922273
for core_record in records {
22932274
let scan_record =
22942275
ScanRecord::from_core(py, &core_record, &projected_row_type)?;
2295-
state.pending_records.push_back(Py::new(py, scan_record)?);
2276+
result.push(Py::new(py, scan_record)?);
22962277
}
22972278
}
2298-
2299-
// Pop the very first one to return right now
2300-
if let Some(record) = state.pending_records.pop_front() {
2301-
Ok(record.into_any())
2302-
} else {
2303-
Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
2304-
"Stream Ended",
2305-
))
2306-
}
2279+
Ok(result)
23072280
})
2308-
})?;
2309-
2310-
Ok(Some(future))
2281+
})
23112282
}
23122283

23132284
fn __repr__(&self) -> String {
@@ -2324,10 +2295,7 @@ impl LogScanner {
23242295
projected_row_type: fcore::metadata::RowType,
23252296
) -> Self {
23262297
Self {
2327-
state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState {
2328-
kind: scanner,
2329-
pending_records: std::collections::VecDeque::new(),
2330-
})),
2298+
kind: Arc::new(scanner),
23312299
admin,
23322300
table_info,
23332301
projected_schema,
@@ -2378,10 +2346,7 @@ impl LogScanner {
23782346
py: Python,
23792347
subscribed: &[(fcore::metadata::TableBucket, i64)],
23802348
) -> PyResult<HashMap<fcore::metadata::TableBucket, i64>> {
2381-
let scanner_ref =
2382-
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2383-
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2384-
let scanner = lock.kind.as_batch()?;
2349+
let scanner = self.kind.as_batch()?;
23852350
let is_partitioned = scanner.is_partitioned();
23862351
let table_path = &self.table_info.table_path;
23872352

@@ -2484,10 +2449,7 @@ impl LogScanner {
24842449
py: Python,
24852450
mut stopping_offsets: HashMap<fcore::metadata::TableBucket, i64>,
24862451
) -> PyResult<Py<PyAny>> {
2487-
let scanner_ref =
2488-
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2489-
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2490-
let scanner = lock.kind.as_batch()?;
2452+
let scanner = self.kind.as_batch()?;
24912453
let mut all_batches = Vec::new();
24922454

24932455
while !stopping_offsets.is_empty() {

0 commit comments

Comments
 (0)