Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions bindings/python/src/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@ fn schema_top_level_field_ids(schema: &PySchema) -> Vec<i32> {
.collect()
}

/// Returns the names of the top-level fields of `schema`, in the order the
/// schema's struct lists them.
///
/// Only the fields returned by `as_struct().fields()` are visited; each name
/// is cloned into the returned vector.
fn schema_top_level_field_names(schema: &PySchema) -> Vec<String> {
    let top_level_fields = schema.inner.as_struct().fields();

    let mut names = Vec::new();
    for field in top_level_fields.iter() {
        names.push(field.name.clone());
    }
    names
}

fn validate_reader_projection(output_schema: &PySchema, tasks: &[FileScanTask]) -> PyResult<()> {
let output_field_ids = schema_top_level_field_ids(output_schema);
for task in tasks {
Expand All @@ -304,6 +314,24 @@ fn validate_reader_projection(output_schema: &PySchema, tasks: &[FileScanTask])
Ok(())
}

/// Checks that an explicit column selection lines up with `output_schema`.
///
/// Returns `Ok(())` when:
/// * `selected_fields` is `None` (nothing to compare),
/// * the selection equals the schema's top-level field names, element by
///   element and in the same order, or
/// * the selection names at least one column that is not a top-level field
///   of `output_schema`.
///
/// # Errors
///
/// Returns a `PyValueError` when every selected name is a top-level field of
/// `output_schema` but the two lists are not equal (different order or a
/// different number of columns).
fn validate_selected_fields_match_output_schema(
    output_schema: &PySchema,
    selected_fields: Option<&[String]>,
) -> PyResult<()> {
    let requested = match selected_fields {
        Some(requested) => requested,
        None => return Ok(()),
    };

    let schema_columns = schema_top_level_field_names(output_schema);
    if schema_columns == requested {
        return Ok(());
    }

    // A selection containing a name the output schema does not have is
    // deliberately not rejected here.
    // NOTE(review): presumably so the scan itself reports the unknown column
    // with its own error message -- confirm this leniency is intended.
    let has_unknown_column = requested
        .iter()
        .any(|name| !schema_columns.contains(name));
    if has_unknown_column {
        return Ok(());
    }

    Err(PyValueError::new_err(format!(
        "output_schema columns {:?} must match selected_fields {:?}",
        schema_columns, requested
    )))
}

fn arrow_schema_for_reader(
output_schema: &PySchema,
tasks: &[FileScanTask],
Expand Down Expand Up @@ -811,6 +839,125 @@ impl PyTable {

Ok(py_tasks)
}

#[pyo3(signature = (
output_schema,
*,
selected_fields = None,
predicate = None,
snapshot_id = None,
case_sensitive = true,
max_rows = None,
batch_size = Some(65536),
data_file_concurrency_limit = None,
concurrency_limit = None,
manifest_entry_concurrency_limit = None,
row_group_filtering_enabled = true,
row_selection_enabled = false
))]
fn read_arrow<'py>(
&self,
py: Python<'py>,
output_schema: &PySchema,
selected_fields: Option<Vec<String>>,
predicate: Option<&Bound<'_, PyPredicate>>,
snapshot_id: Option<i64>,
case_sensitive: bool,
max_rows: Option<usize>,
batch_size: Option<usize>,
data_file_concurrency_limit: Option<usize>,
concurrency_limit: Option<usize>,
manifest_entry_concurrency_limit: Option<usize>,
row_group_filtering_enabled: bool,
row_selection_enabled: bool,
) -> PyResult<Bound<'py, PyAny>> {
let mut scan_builder = self.inner.scan();

validate_selected_fields_match_output_schema(output_schema, selected_fields.as_deref())?;

if let Some(fields) = selected_fields {
if fields.is_empty() {
scan_builder = scan_builder.select_empty();
} else {
scan_builder = scan_builder.select(fields);
}
} else {
scan_builder = scan_builder.select_all();
}

if let Some(pred_bound) = predicate {
let pred = pred_bound.extract::<PyRef<'_, PyPredicate>>()?;
scan_builder = scan_builder.with_filter(pred.inner.clone());
}

if let Some(snap_id) = snapshot_id {
scan_builder = scan_builder.snapshot_id(snap_id);
}

scan_builder = scan_builder.with_case_sensitive(case_sensitive);

if let Some(limit) = concurrency_limit {
scan_builder = scan_builder.with_concurrency_limit(limit);
}

if let Some(limit) = manifest_entry_concurrency_limit {
scan_builder = scan_builder.with_manifest_entry_concurrency_limit(limit);
}

let df_concurrency = data_file_concurrency_limit;
if let Some(limit) = df_concurrency {
scan_builder = scan_builder.with_data_file_concurrency_limit(limit);
}

scan_builder = scan_builder.with_row_group_filtering_enabled(row_group_filtering_enabled);
scan_builder = scan_builder.with_row_selection_enabled(row_selection_enabled);
if let Some(bs) = batch_size {
scan_builder = scan_builder.with_batch_size(Some(bs));
}

let scan = scan_builder.build().map_err(crate::error::to_py_err)?;

let rust_tasks = py.detach(|| {
let task_stream = runtime()
.block_on(async { scan.plan_files().await })
.map_err(crate::error::to_py_err)?;

runtime()
.block_on(async { task_stream.try_collect::<Vec<FileScanTask>>().await })
.map_err(crate::error::to_py_err)
})?;

validate_reader_projection(output_schema, &rust_tasks)?;
let schema = arrow_schema_for_reader(output_schema, &rust_tasks)?;

let task_stream_for_reader = Box::pin(stream::iter(rust_tasks.into_iter().map(Ok))) as FileScanTaskStream;

let mut reader_builder = ArrowReaderBuilder::new(self.inner.file_io().clone(), iceberg_runtime())
.with_row_group_filtering_enabled(row_group_filtering_enabled)
.with_row_selection_enabled(row_selection_enabled);

if let Some(bs) = batch_size {
reader_builder = reader_builder.with_batch_size(bs);
}
if let Some(limit) = df_concurrency {
reader_builder = reader_builder.with_data_file_concurrency_limit(limit);
}

let stream = reader_builder
.build()
.read(task_stream_for_reader)
.map_err(crate::error::to_py_err)?
.stream();

let reader: Box<dyn RecordBatchReader + Send> =
Box::new(BlockingArrowRecordBatchReader {
schema,
stream,
remaining_rows: max_rows,
});

reader.into_pyarrow(py)
}
}

pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down
88 changes: 88 additions & 0 deletions bindings/python/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,91 @@ def test_table_empty_table_planning():

tasks_pred = table.plan_files(predicate=Reference("id").eq(5))
assert len(tasks_pred) == 0


# Format-version 2 table metadata, serialized to a JSON string, for a table
# with a two-column schema ("id": required long, "name": optional string),
# an unpartitioned spec, and a single "append" snapshot (snapshot-id 1) that
# is also the current snapshot. Used by test_table_read_arrow to build a
# table that has a current snapshot.
TABLE_METADATA_WITH_SNAPSHOT_JSON = json.dumps(
{
"format-version": 2,
"table-uuid": "fb070e82-2d1f-4ef6-8ab6-c4d12c6ed490",
"location": "s3://bucket/table",
"last-sequence-number": 1,
"last-updated-ms": 1600000000000,
"last-column-id": 2,
"schemas": [
{
"schema-id": 1,
"type": "struct",
"fields": [
{"id": 1, "name": "id", "required": True, "type": "long"},
{"id": 2, "name": "name", "required": False, "type": "string"},
],
}
],
"current-schema-id": 1,
"partition-specs": [{"spec-id": 0, "fields": []}],
"default-spec-id": 0,
"last-partition-id": 1000,
"default-sort-order-id": 0,
"sort-orders": [{"order-id": 0, "fields": []}],
"properties": {},
"current-snapshot-id": 1,
"snapshots": [
{
"snapshot-id": 1,
"timestamp-ms": 1600000000000,
"summary": {"operation": "append"},
"manifest-list": "s3://bucket/table/metadata/snap-1.avro",
"schema-id": 1
}
],
"snapshot-log": [],
"metadata-log": [],
}
)


def test_table_read_arrow():
    """Exercise ``Table.read_arrow`` on tables built from metadata JSON.

    Covers: reading with the full schema, a matching projection, a
    projection that disagrees with the output schema, ``max_rows=0``, an
    unknown selected column, a bindable predicate, and a predicate on a
    column that is not in the schema.
    """
    from pyiceberg_core.scan import Table
    from pyiceberg_core.expression import Reference

    empty_table = Table.from_metadata_json(
        FileIO.from_props({}),
        ["ns", "tbl"],
        TABLE_METADATA_JSON,
    )

    # Full-schema read: a reader with both columns that yields no batches.
    full_reader = empty_table.read_arrow(schema())
    assert isinstance(full_reader, pa.RecordBatchReader)
    assert full_reader.schema.names == ["id", "name"]
    with pytest.raises(StopIteration):
        full_reader.read_next_batch()

    # Projection whose selected fields match the output schema.
    projected_reader = empty_table.read_arrow(id_schema(), selected_fields=["id"])
    assert projected_reader.schema.names == ["id"]

    # Projection narrower than the output schema is rejected.
    with pytest.raises(ValueError, match="output_schema columns"):
        empty_table.read_arrow(schema(), selected_fields=["id"])

    # max_rows=0 still returns a reader, and it is immediately exhausted.
    zero_row_reader = empty_table.read_arrow(schema(), max_rows=0)
    assert isinstance(zero_row_reader, pa.RecordBatchReader)
    with pytest.raises(StopIteration):
        zero_row_reader.read_next_batch()

    # Selecting a column that does not exist, on a table with a snapshot.
    snapshot_table = Table.from_metadata_json(
        FileIO.from_props({}),
        ["ns", "tbl"],
        TABLE_METADATA_WITH_SNAPSHOT_JSON,
    )
    with pytest.raises(ValueError, match="Column missing not found in table"):
        snapshot_table.read_arrow(schema(), selected_fields=["missing"])

    # A predicate on an existing column binds and yields a reader.
    filtered_reader = empty_table.read_arrow(schema(), predicate=Reference("id").eq(5))
    assert isinstance(filtered_reader, pa.RecordBatchReader)

    # A predicate on an unknown column fails to bind; this step goes
    # through plan_files rather than read_arrow.
    with pytest.raises(ValueError, match="Field missing not found in schema"):
        snapshot_table.plan_files(predicate=Reference("missing").eq(5))