From 5d38ae2b97f0d850172e7c3c9be9885400dcfade Mon Sep 17 00:00:00 2001 From: Abanoub Doss Date: Sun, 24 May 2026 20:38:43 -0500 Subject: [PATCH] feat(python): support arrow reader row limits --- bindings/python/src/scan.rs | 37 ++++++++++++++++--- bindings/python/tests/test_scan.py | 57 ++++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/bindings/python/src/scan.rs b/bindings/python/src/scan.rs index 20eeed8597..61c1e42f40 100644 --- a/bindings/python/src/scan.rs +++ b/bindings/python/src/scan.rs @@ -565,15 +565,38 @@ impl PyFileScanTask { struct BlockingArrowRecordBatchReader { schema: ArrowSchemaRef, stream: ArrowRecordBatchStream, + remaining_rows: Option, } impl Iterator for BlockingArrowRecordBatchReader { type Item = ArrowResult; fn next(&mut self) -> Option { - runtime() - .block_on(self.stream.next()) - .map(|result| result.map_err(|err| ArrowError::ExternalError(Box::new(err)))) + if let Some(0) = self.remaining_rows { + return None; + } + + let batch = runtime() + .block_on(self.stream.next())? + .map_err(|err| ArrowError::ExternalError(Box::new(err))); + + match batch { + Ok(batch) => { + if let Some(rem) = self.remaining_rows { + if batch.num_rows() <= rem { + self.remaining_rows = Some(rem - batch.num_rows()); + Some(Ok(batch)) + } else { + let sliced = batch.slice(0, rem); + self.remaining_rows = Some(0); + Some(Ok(sliced)) + } + } else { + Some(Ok(batch)) + } + } + Err(err) => Some(Err(err)), + } } } @@ -623,11 +646,13 @@ impl PyArrowReader { } } + #[pyo3(signature = (output_schema, tasks, *, max_rows = None))] fn read<'py>( &self, py: Python<'py>, output_schema: &PySchema, tasks: &Bound<'_, PyAny>, + max_rows: Option, ) -> PyResult> { let rust_tasks = py_tasks_to_rust(tasks)?; validate_reader_projection(output_schema, &rust_tasks)?; @@ -651,7 +676,11 @@ impl PyArrowReader { .map_err(crate::error::to_py_err)? .stream(); let reader: Box = - Box::new(BlockingArrowRecordBatchReader { schema, stream }); + Box::new(BlockingArrowRecordBatchReader { + schema, + stream, + remaining_rows: max_rows, + }); reader.into_pyarrow(py) } diff --git a/bindings/python/tests/test_scan.py b/bindings/python/tests/test_scan.py index 9ef58c7e2b..760acb0d13 100644 --- a/bindings/python/tests/test_scan.py +++ b/bindings/python/tests/test_scan.py @@ -347,7 +347,7 @@ def test_arrow_reader_metadata_projection_no_longer_fails(): projected_reader = reader.read(id_file_schema(), [task]) assert isinstance(projected_reader, pa.RecordBatchReader) assert projected_reader.schema.names == ["id", "_file"] - + file_field = projected_reader.schema.field("_file") assert "run_end_encoded" in str(file_field.type) or pa.types.is_run_end_encoded(file_field.type) @@ -386,7 +386,7 @@ def test_arrow_reader_partition_projection_no_longer_fails(): projected_reader = reader.read(id_schema(), [task]) assert isinstance(projected_reader, pa.RecordBatchReader) assert projected_reader.schema.names == ["id"] - + id_field = projected_reader.schema.field("id") assert "run_end_encoded" in str(id_field.type) or pa.types.is_run_end_encoded(id_field.type) @@ -423,3 +423,56 @@ def test_arrow_reader_rejects_tasks_with_different_physical_schemas(): with pytest.raises(ValueError, match="same Arrow schema"): reader.read(id_schema(), [constant_task, plain_task]) + + +def test_arrow_reader_max_rows_behavior(): + reader = ArrowReader(FileIO.from_props({})) + + # max_rows=0 + batch_reader = reader.read(id_schema(), [], max_rows=0) + assert isinstance(batch_reader, pa.RecordBatchReader) + assert batch_reader.schema.names == ["id"] + with pytest.raises(StopIteration): + batch_reader.read_next_batch() + + # max_rows=None + batch_reader_none = reader.read(id_schema(), [], max_rows=None) + assert isinstance(batch_reader_none, pa.RecordBatchReader) + assert batch_reader_none.schema.names == ["id"] + with pytest.raises(StopIteration): + batch_reader_none.read_next_batch() + + +def test_arrow_reader_with_real_parquet_and_limits(tmp_path): + import os + + import pyarrow.parquet as pq + + # Write a small parquet file with 5 rows + table = pa.table({"id": [1, 2, 3, 4, 5], "name": ["a", "b", "c", "d", "e"]}) + local_path = str(tmp_path / "data.parquet") + pq.write_table(table, local_path) + file_path = "file://" + local_path + file_size = os.path.getsize(local_path) + + reader = ArrowReader(FileIO.from_props({})) + + # 1. Test reading all + task = FileScanTask(schema(), file_path, file_size, [1, 2]) + batch_reader = reader.read(schema(), [task]) + res_table = batch_reader.read_all() + assert len(res_table) == 5 + assert res_table.column("id").to_pylist() == [1, 2, 3, 4, 5] + + # 2. Test max_rows = 3 + task_limit = FileScanTask(schema(), file_path, file_size, [1, 2]) + batch_reader_limit = reader.read(schema(), [task_limit], max_rows=3) + res_table_limit = batch_reader_limit.read_all() + assert len(res_table_limit) == 3 + assert res_table_limit.column("id").to_pylist() == [1, 2, 3] + + # 3. Test max_rows = 0 + task_limit_0 = FileScanTask(schema(), file_path, file_size, [1, 2]) + batch_reader_limit_0 = reader.read(schema(), [task_limit_0], max_rows=0) + res_table_limit_0 = batch_reader_limit_0.read_all() + assert len(res_table_limit_0) == 0