Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions bindings/python/src/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,38 @@ impl PyFileScanTask {
struct BlockingArrowRecordBatchReader {
schema: ArrowSchemaRef,
stream: ArrowRecordBatchStream,
remaining_rows: Option<usize>,
}

impl Iterator for BlockingArrowRecordBatchReader {
type Item = ArrowResult<RecordBatch>;

fn next(&mut self) -> Option<Self::Item> {
runtime()
.block_on(self.stream.next())
.map(|result| result.map_err(|err| ArrowError::ExternalError(Box::new(err))))
if let Some(0) = self.remaining_rows {
return None;
}

let batch = runtime()
.block_on(self.stream.next())?
.map_err(|err| ArrowError::ExternalError(Box::new(err)));

match batch {
Ok(batch) => {
if let Some(rem) = self.remaining_rows {
if batch.num_rows() <= rem {
self.remaining_rows = Some(rem - batch.num_rows());
Some(Ok(batch))
} else {
let sliced = batch.slice(0, rem);
self.remaining_rows = Some(0);
Some(Ok(sliced))
}
} else {
Some(Ok(batch))
}
}
Err(err) => Some(Err(err)),
}
}
}

Expand Down Expand Up @@ -623,11 +646,13 @@ impl PyArrowReader {
}
}

#[pyo3(signature = (output_schema, tasks, *, max_rows = None))]
fn read<'py>(
&self,
py: Python<'py>,
output_schema: &PySchema,
tasks: &Bound<'_, PyAny>,
max_rows: Option<usize>,
) -> PyResult<Bound<'py, PyAny>> {
let rust_tasks = py_tasks_to_rust(tasks)?;
validate_reader_projection(output_schema, &rust_tasks)?;
Expand All @@ -651,7 +676,11 @@ impl PyArrowReader {
.map_err(crate::error::to_py_err)?
.stream();
let reader: Box<dyn RecordBatchReader + Send> =
Box::new(BlockingArrowRecordBatchReader { schema, stream });
Box::new(BlockingArrowRecordBatchReader {
schema,
stream,
remaining_rows: max_rows,
});

reader.into_pyarrow(py)
}
Expand Down
57 changes: 55 additions & 2 deletions bindings/python/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def test_arrow_reader_metadata_projection_no_longer_fails():
projected_reader = reader.read(id_file_schema(), [task])
assert isinstance(projected_reader, pa.RecordBatchReader)
assert projected_reader.schema.names == ["id", "_file"]

file_field = projected_reader.schema.field("_file")
assert "run_end_encoded" in str(file_field.type) or pa.types.is_run_end_encoded(file_field.type)

Expand Down Expand Up @@ -386,7 +386,7 @@ def test_arrow_reader_partition_projection_no_longer_fails():
projected_reader = reader.read(id_schema(), [task])
assert isinstance(projected_reader, pa.RecordBatchReader)
assert projected_reader.schema.names == ["id"]

id_field = projected_reader.schema.field("id")
assert "run_end_encoded" in str(id_field.type) or pa.types.is_run_end_encoded(id_field.type)

Expand Down Expand Up @@ -423,3 +423,56 @@ def test_arrow_reader_rejects_tasks_with_different_physical_schemas():

with pytest.raises(ValueError, match="same Arrow schema"):
reader.read(id_schema(), [constant_task, plain_task])


def test_arrow_reader_max_rows_behavior():
reader = ArrowReader(FileIO.from_props({}))

# max_rows=0
batch_reader = reader.read(id_schema(), [], max_rows=0)
assert isinstance(batch_reader, pa.RecordBatchReader)
assert batch_reader.schema.names == ["id"]
with pytest.raises(StopIteration):
batch_reader.read_next_batch()

# max_rows=None
batch_reader_none = reader.read(id_schema(), [], max_rows=None)
assert isinstance(batch_reader_none, pa.RecordBatchReader)
assert batch_reader_none.schema.names == ["id"]
with pytest.raises(StopIteration):
batch_reader_none.read_next_batch()


def test_arrow_reader_with_real_parquet_and_limits(tmp_path):
import os

import pyarrow.parquet as pq

# Write a small parquet file with 5 rows
table = pa.table({"id": [1, 2, 3, 4, 5], "name": ["a", "b", "c", "d", "e"]})
local_path = str(tmp_path / "data.parquet")
pq.write_table(table, local_path)
file_path = "file://" + local_path
file_size = os.path.getsize(local_path)

reader = ArrowReader(FileIO.from_props({}))

# 1. Test reading all
task = FileScanTask(schema(), file_path, file_size, [1, 2])
batch_reader = reader.read(schema(), [task])
res_table = batch_reader.read_all()
assert len(res_table) == 5
assert res_table.column("id").to_pylist() == [1, 2, 3, 4, 5]

# 2. Test max_rows = 3
task_limit = FileScanTask(schema(), file_path, file_size, [1, 2])
batch_reader_limit = reader.read(schema(), [task_limit], max_rows=3)
res_table_limit = batch_reader_limit.read_all()
assert len(res_table_limit) == 3
assert res_table_limit.column("id").to_pylist() == [1, 2, 3]

# 3. Test max_rows = 0
task_limit_0 = FileScanTask(schema(), file_path, file_size, [1, 2])
batch_reader_limit_0 = reader.read(schema(), [task_limit_0], max_rows=0)
res_table_limit_0 = batch_reader_limit_0.read_all()
assert len(res_table_limit_0) == 0