Skip to content

Commit 768266d

Browse files
committed
feat: add async 'for' loop support to LogScanner (#424)
1 parent bb6933a commit 768266d

3 files changed

Lines changed: 188 additions & 23 deletions

File tree

bindings/python/fluss/__init__.pyi

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ class ScanRecords:
125125
def __getitem__(self, index: slice) -> List[ScanRecord]: ...
126126
@overload
127127
def __getitem__(self, bucket: TableBucket) -> List[ScanRecord]: ...
128-
def __getitem__(self, key: Union[int, slice, TableBucket]) -> Union[ScanRecord, List[ScanRecord]]: ...
128+
def __getitem__(
129+
self, key: Union[int, slice, TableBucket]
130+
) -> Union[ScanRecord, List[ScanRecord]]: ...
129131
def __contains__(self, bucket: TableBucket) -> bool: ...
130132
def __iter__(self) -> Iterator[ScanRecord]: ...
131133
def __str__(self) -> str: ...
@@ -369,7 +371,6 @@ class FlussAdmin:
369371
...
370372
def __repr__(self) -> str: ...
371373

372-
373374
class DatabaseDescriptor:
374375
"""Descriptor for a Fluss database (comment and custom properties)."""
375376

@@ -383,7 +384,6 @@ class DatabaseDescriptor:
383384
def get_custom_properties(self) -> Dict[str, str]: ...
384385
def __repr__(self) -> str: ...
385386

386-
387387
class DatabaseInfo:
388388
"""Information about a Fluss database."""
389389

@@ -604,7 +604,6 @@ class UpsertWriter:
604604
...
605605
def __repr__(self) -> str: ...
606606

607-
608607
class WriteResultHandle:
609608
"""Handle for a pending write (append/upsert/delete). Ignore for fire-and-forget, or await handle.wait() for ack."""
610609

@@ -613,7 +612,6 @@ class WriteResultHandle:
613612
...
614613
def __repr__(self) -> str: ...
615614

616-
617615
class Lookuper:
618616
"""Lookuper for performing primary key lookups on a Fluss table."""
619617

bindings/python/src/table.rs

Lines changed: 136 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ use pyo3::types::{
3030
PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType,
3131
PyTzInfo,
3232
};
33+
use pyo3::{
34+
Bound, IntoPyObjectExt, Py, PyAny, PyClassInitializer, PyErr, PyRef, PyRefMut, PyResult, Python,
35+
};
3336
use pyo3_async_runtimes::tokio::future_into_py;
3437
use std::collections::HashMap;
3538
use std::sync::Arc;
@@ -1863,6 +1866,13 @@ enum ScannerKind {
18631866
Batch(fcore::client::RecordBatchLogScanner),
18641867
}
18651868

1869+
/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing
1870+
struct ScannerState {
1871+
kind: ScannerKind,
1872+
/// A buffer to hold records polled from the network before yielding them one-by-one to Python
1873+
pending_records: std::collections::VecDeque<Py<ScanRecord>>,
1874+
}
1875+
18661876
impl ScannerKind {
18671877
fn as_record(&self) -> PyResult<&fcore::client::LogScanner> {
18681878
match self {
@@ -1901,7 +1911,7 @@ macro_rules! with_scanner {
19011911
/// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches
19021912
#[pyclass]
19031913
pub struct LogScanner {
1904-
scanner: ScannerKind,
1914+
state: Arc<tokio::sync::Mutex<ScannerState>>,
19051915
admin: fcore::client::FlussAdmin,
19061916
table_info: fcore::metadata::TableInfo,
19071917
/// The projected Arrow schema to use for empty table creation
@@ -1922,7 +1932,8 @@ impl LogScanner {
19221932
fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> {
19231933
py.detach(|| {
19241934
TOKIO_RUNTIME.block_on(async {
1925-
with_scanner!(&self.scanner, subscribe(bucket_id, start_offset))
1935+
let state = self.state.lock().await;
1936+
with_scanner!(&state.kind, subscribe(bucket_id, start_offset))
19261937
.map_err(|e| FlussError::from_core_error(&e))
19271938
})
19281939
})
@@ -1935,7 +1946,8 @@ impl LogScanner {
19351946
fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> {
19361947
py.detach(|| {
19371948
TOKIO_RUNTIME.block_on(async {
1938-
with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets))
1949+
let state = self.state.lock().await;
1950+
with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets))
19391951
.map_err(|e| FlussError::from_core_error(&e))
19401952
})
19411953
})
@@ -1956,8 +1968,9 @@ impl LogScanner {
19561968
) -> PyResult<()> {
19571969
py.detach(|| {
19581970
TOKIO_RUNTIME.block_on(async {
1971+
let state = self.state.lock().await;
19591972
with_scanner!(
1960-
&self.scanner,
1973+
&state.kind,
19611974
subscribe_partition(partition_id, bucket_id, start_offset)
19621975
)
19631976
.map_err(|e| FlussError::from_core_error(&e))
@@ -1976,8 +1989,9 @@ impl LogScanner {
19761989
) -> PyResult<()> {
19771990
py.detach(|| {
19781991
TOKIO_RUNTIME.block_on(async {
1992+
let state = self.state.lock().await;
19791993
with_scanner!(
1980-
&self.scanner,
1994+
&state.kind,
19811995
subscribe_partition_buckets(&partition_bucket_offsets)
19821996
)
19831997
.map_err(|e| FlussError::from_core_error(&e))
@@ -1992,7 +2006,8 @@ impl LogScanner {
19922006
fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> {
19932007
py.detach(|| {
19942008
TOKIO_RUNTIME.block_on(async {
1995-
with_scanner!(&self.scanner, unsubscribe(bucket_id))
2009+
let state = self.state.lock().await;
2010+
with_scanner!(&state.kind, unsubscribe(bucket_id))
19962011
.map_err(|e| FlussError::from_core_error(&e))
19972012
})
19982013
})
@@ -2006,11 +2021,9 @@ impl LogScanner {
20062021
fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> {
20072022
py.detach(|| {
20082023
TOKIO_RUNTIME.block_on(async {
2009-
with_scanner!(
2010-
&self.scanner,
2011-
unsubscribe_partition(partition_id, bucket_id)
2012-
)
2013-
.map_err(|e| FlussError::from_core_error(&e))
2024+
let state = self.state.lock().await;
2025+
with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id))
2026+
.map_err(|e| FlussError::from_core_error(&e))
20142027
})
20152028
})
20162029
}
@@ -2030,7 +2043,10 @@ impl LogScanner {
20302043
/// - Returns an empty ScanRecords if no records are available
20312044
/// - When timeout expires, returns an empty ScanRecords (NOT an error)
20322045
fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> {
2033-
let scanner = self.scanner.as_record()?;
2046+
let scanner_ref =
2047+
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2048+
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2049+
let scanner = lock.kind.as_record()?;
20342050

20352051
if timeout_ms < 0 {
20362052
return Err(FlussError::new_err(format!(
@@ -2079,7 +2095,10 @@ impl LogScanner {
20792095
/// - Returns an empty list if no batches are available
20802096
/// - When timeout expires, returns an empty list (NOT an error)
20812097
fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> {
2082-
let scanner = self.scanner.as_batch()?;
2098+
let scanner_ref =
2099+
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2100+
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2101+
let scanner = lock.kind.as_batch()?;
20832102

20842103
if timeout_ms < 0 {
20852104
return Err(FlussError::new_err(format!(
@@ -2114,7 +2133,10 @@ impl LogScanner {
21142133
/// - Returns an empty table (with correct schema) if no records are available
21152134
/// - When timeout expires, returns an empty table (NOT an error)
21162135
fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> {
2117-
let scanner = self.scanner.as_batch()?;
2136+
let scanner_ref =
2137+
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2138+
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2139+
let scanner = lock.kind.as_batch()?;
21182140

21192141
if timeout_ms < 0 {
21202142
return Err(FlussError::new_err(format!(
@@ -2167,7 +2189,10 @@ impl LogScanner {
21672189
/// Returns:
21682190
/// PyArrow Table containing all data from subscribed buckets
21692191
fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> {
2170-
let scanner = self.scanner.as_batch()?;
2192+
let scanner_ref =
2193+
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2194+
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2195+
let scanner = lock.kind.as_batch()?;
21712196
let subscribed = scanner.get_subscribed_buckets();
21722197
if subscribed.is_empty() {
21732198
return Err(FlussError::new_err(
@@ -2199,6 +2224,90 @@ impl LogScanner {
21992224
Ok(df)
22002225
}
22012226

2227+
fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
2228+
let py = slf.py();
2229+
let code = pyo3::ffi::c_str!(
2230+
r#"
2231+
async def _adapter(obj):
2232+
while True:
2233+
try:
2234+
yield await obj.__anext__()
2235+
except StopAsyncIteration:
2236+
break
2237+
"#
2238+
);
2239+
let globals = pyo3::types::PyDict::new(py);
2240+
py.run(code, Some(&globals), None)?;
2241+
let adapter = globals.get_item("_adapter")?.unwrap();
2242+
// Return adapt(self)
2243+
adapter.call1((slf.into_bound_py_any(py)?,))
2244+
}
2245+
2246+
fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult<Option<Bound<'py, PyAny>>> {
2247+
let state_arc = slf.state.clone();
2248+
let projected_row_type = slf.projected_row_type.clone();
2249+
let py = slf.py();
2250+
2251+
let future = future_into_py(py, async move {
2252+
let mut state = state_arc.lock().await;
2253+
2254+
// 1. If we already have buffered records, pop and return immediately
2255+
if let Some(record) = state.pending_records.pop_front() {
2256+
return Ok(record.into_any());
2257+
}
2258+
2259+
// 2. Buffer is empty, we must poll the network for the next batch
2260+
// The underlying kind must be a Record-based scanner.
2261+
let scanner = match state.kind.as_record() {
2262+
Ok(s) => s,
2263+
Err(_) => {
2264+
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
2265+
"Stream Ended",
2266+
));
2267+
}
2268+
};
2269+
2270+
// Poll with a reasonable internal timeout before unblocking the event loop
2271+
let timeout = core::time::Duration::from_millis(5000);
2272+
2273+
let mut current_records = scanner
2274+
.poll(timeout)
2275+
.await
2276+
.map_err(|e| FlussError::from_core_error(&e))?;
2277+
2278+
// If it's a real timeout with zero records, loop or throw StopAsyncIteration?
2279+
// Since it's a streaming log, we can yield None or block. Blocking requires a loop in the future.
2280+
while current_records.is_empty() {
2281+
current_records = scanner
2282+
.poll(timeout)
2283+
.await
2284+
.map_err(|e| FlussError::from_core_error(&e))?;
2285+
}
2286+
2287+
// Now we have records.
2288+
Python::attach(|py| {
2289+
for (_, records) in current_records.into_records_by_buckets() {
2290+
for core_record in records {
2291+
let scan_record =
2292+
ScanRecord::from_core(py, &core_record, &projected_row_type)?;
2293+
state.pending_records.push_back(Py::new(py, scan_record)?);
2294+
}
2295+
}
2296+
2297+
// Pop the very first one to return right now
2298+
if let Some(record) = state.pending_records.pop_front() {
2299+
Ok(record.into_any())
2300+
} else {
2301+
Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
2302+
"Stream Ended",
2303+
))
2304+
}
2305+
})
2306+
})?;
2307+
2308+
Ok(Some(future))
2309+
}
2310+
22022311
fn __repr__(&self) -> String {
22032312
format!("LogScanner(table={})", self.table_info.table_path)
22042313
}
@@ -2213,7 +2322,10 @@ impl LogScanner {
22132322
projected_row_type: fcore::metadata::RowType,
22142323
) -> Self {
22152324
Self {
2216-
scanner,
2325+
state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState {
2326+
kind: scanner,
2327+
pending_records: std::collections::VecDeque::new(),
2328+
})),
22172329
admin,
22182330
table_info,
22192331
projected_schema,
@@ -2264,7 +2376,10 @@ impl LogScanner {
22642376
py: Python,
22652377
subscribed: &[(fcore::metadata::TableBucket, i64)],
22662378
) -> PyResult<HashMap<fcore::metadata::TableBucket, i64>> {
2267-
let scanner = self.scanner.as_batch()?;
2379+
let scanner_ref =
2380+
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2381+
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2382+
let scanner = lock.kind.as_batch()?;
22682383
let is_partitioned = scanner.is_partitioned();
22692384
let table_path = &self.table_info.table_path;
22702385

@@ -2367,7 +2482,10 @@ impl LogScanner {
23672482
py: Python,
23682483
mut stopping_offsets: HashMap<fcore::metadata::TableBucket, i64>,
23692484
) -> PyResult<Py<PyAny>> {
2370-
let scanner = self.scanner.as_batch()?;
2485+
let scanner_ref =
2486+
unsafe { &*(&self.state as *const std::sync::Arc<tokio::sync::Mutex<ScannerState>>) };
2487+
let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await });
2488+
let scanner = lock.kind.as_batch()?;
23712489
let mut all_batches = Vec::new();
23722490

23732491
while !stopping_offsets.is_empty() {

bindings/python/test/test_log_table.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,55 @@ async def test_scan_records_indexing_and_slicing(connection, admin):
729729
await admin.drop_table(table_path, ignore_if_not_exists=False)
730730

731731

732+
async def test_async_iterator(connection, admin):
733+
"""Test the Python asynchronous iterator loop (`async for`) on LogScanner."""
734+
table_path = fluss.TablePath("fluss", "py_test_async_iterator")
735+
await admin.drop_table(table_path, ignore_if_not_exists=True)
736+
737+
schema = fluss.Schema(
738+
pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())])
739+
)
740+
await admin.create_table(table_path, fluss.TableDescriptor(schema))
741+
742+
table = await connection.get_table(table_path)
743+
writer = table.new_append().create_writer()
744+
745+
# Write 5 records
746+
writer.write_arrow_batch(
747+
pa.RecordBatch.from_arrays(
748+
[pa.array(list(range(1, 6)), type=pa.int32()),
749+
pa.array([f"async{i}" for i in range(1, 6)])],
750+
schema=pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]),
751+
)
752+
)
753+
await writer.flush()
754+
755+
scanner = await table.new_scan().create_log_scanner()
756+
num_buckets = (await admin.get_table_info(table_path)).num_buckets
757+
scanner.subscribe_buckets({i: fluss.EARLIEST_OFFSET for i in range(num_buckets)})
758+
759+
collected = []
760+
761+
# Here is the magical Issue #424 async iterator logic at work:
762+
async def consume_scanner():
763+
async for record in scanner:
764+
collected.append(record)
765+
if len(collected) == 5:
766+
break
767+
768+
# We must race the consumption against a timeout so the test doesn't hang if the iterator is broken
769+
await asyncio.wait_for(consume_scanner(), timeout=10.0)
770+
771+
assert len(collected) == 5, f"Expected 5 records, got {len(collected)}"
772+
773+
collected.sort(key=lambda r: r.row["id"])
774+
for i, record in enumerate(collected):
775+
assert record.row["id"] == i + 1
776+
assert record.row["val"] == f"async{i + 1}"
777+
778+
await admin.drop_table(table_path, ignore_if_not_exists=False)
779+
780+
732781
# ---------------------------------------------------------------------------
733782
# Helpers
734783
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)