diff --git a/bindings/python/src/scan.rs b/bindings/python/src/scan.rs
index a3998aea84..bb01b84186 100644
--- a/bindings/python/src/scan.rs
+++ b/bindings/python/src/scan.rs
@@ -291,6 +291,16 @@ fn schema_top_level_field_ids(schema: &PySchema) -> Vec {
         .collect()
 }
 
+fn schema_top_level_field_names(schema: &PySchema) -> Vec { // Collects the names of the top-level fields of `schema`.
+    schema // NOTE(review): the return type reads `Vec` with no element type in this copy of the patch; confirm against the original.
+        .inner
+        .as_struct()
+        .fields()
+        .iter() // Result order follows the iteration order of `fields()`.
+        .map(|field| field.name.clone()) // Cloned so the returned collection owns its names.
+        .collect()
+}
+
 fn validate_reader_projection(output_schema: &PySchema, tasks: &[FileScanTask]) -> PyResult<()> {
     let output_field_ids = schema_top_level_field_ids(output_schema);
     for task in tasks {
@@ -304,6 +314,24 @@ fn validate_reader_projection(output_schema: &PySchema, tasks: &[FileScanTask])
     Ok(())
 }
 
+fn validate_selected_fields_match_output_schema( // Compares the top-level names of `output_schema` with an explicit column selection.
+    output_schema: &PySchema, // Schema whose top-level field names are checked.
+    selected_fields: Option<&[String]>, // Explicit column selection; `None` skips the check entirely.
+) -> PyResult<()> { // Ok(()) when the check passes or is skipped, PyValueError on a detected mismatch.
+    let Some(selected) = selected_fields else {
+        return Ok(()); // No explicit selection, so there is nothing to compare against.
+    };
+
+    let output_names = schema_top_level_field_names(output_schema);
+    if output_names != selected && selected.iter().all(|name| output_names.contains(name)) { // Sequence comparison, so order and length both count. The error fires only when every selected name is an output column.
+        return Err(PyValueError::new_err(format!( // NOTE(review): a selection naming a column absent from `output_schema` skips this branch and returns Ok(()); confirm that is intended.
+            "output_schema columns {:?} must match selected_fields {:?}",
+            output_names, selected
+        )));
+    }
+    Ok(())
+}
+
 fn arrow_schema_for_reader(
     output_schema: &PySchema,
     tasks: &[FileScanTask],
@@ -811,6 +839,125 @@ impl PyTable {
 
         Ok(py_tasks)
     }
+
+    #[pyo3(signature = ( // Python-facing signature for `read_arrow`; the defaults below apply when an argument is omitted.
+        output_schema,
+        *, // PyO3 signature marker: every argument after this one is keyword-only.
+        selected_fields = None,
+        predicate = None,
+        snapshot_id = None,
+        case_sensitive = true,
+        max_rows = None,
+        batch_size = Some(65536), // The body forwards the batch size to the builders only when it is `Some`.
+        data_file_concurrency_limit = None,
+        concurrency_limit = None,
+        manifest_entry_concurrency_limit = None,
+        row_group_filtering_enabled = true,
+        row_selection_enabled = false
+    ))]
+    fn read_arrow<'py>( // Builds a scan, collects its planned file tasks, and returns `reader.into_pyarrow(py)` for a reader over those tasks.
+        &self,
+        py: Python<'py>,
+        output_schema: &PySchema, // Checked against `selected_fields` and the planned tasks, and used to derive the reader schema.
+        selected_fields: Option>, // NOTE(review): generic arguments are missing on several parameters in this copy of the patch; confirm the types against the original.
+        predicate: Option<&Bound<'_, PyPredicate>>, // Optional row filter handed to the scan builder.
+        snapshot_id: Option, // When set, passed to `scan_builder.snapshot_id`.
+        case_sensitive: bool,
+        max_rows: Option, // Stored as `remaining_rows` on the reader created at the end of the body.
+        batch_size: Option,
+        data_file_concurrency_limit: Option, // Applied to both the scan builder and the Arrow reader builder.
+        concurrency_limit: Option,
+        manifest_entry_concurrency_limit: Option,
+        row_group_filtering_enabled: bool, // Applied to both the scan builder and the Arrow reader builder.
+        row_selection_enabled: bool, // Applied to both the scan builder and the Arrow reader builder.
+    ) -> PyResult> {
+        let mut scan_builder = self.inner.scan();
+
+        validate_selected_fields_match_output_schema(output_schema, selected_fields.as_deref())?; // Runs before any planning, so a detected name mismatch fails early.
+
+        if let Some(fields) = selected_fields { // Three projection cases: empty selection, explicit selection, no selection.
+            if fields.is_empty() {
+                scan_builder = scan_builder.select_empty();
+            } else {
+                scan_builder = scan_builder.select(fields);
+            }
+        } else {
+            scan_builder = scan_builder.select_all();
+        }
+
+        if let Some(pred_bound) = predicate {
+            let pred = pred_bound.extract::>()?; // NOTE(review): the `extract` turbofish is empty in this copy of the patch.
+            scan_builder = scan_builder.with_filter(pred.inner.clone());
+        }
+
+        if let Some(snap_id) = snapshot_id {
+            scan_builder = scan_builder.snapshot_id(snap_id);
+        }
+
+        scan_builder = scan_builder.with_case_sensitive(case_sensitive);
+
+        if let Some(limit) = concurrency_limit {
+            scan_builder = scan_builder.with_concurrency_limit(limit);
+        }
+
+        if let Some(limit) = manifest_entry_concurrency_limit {
+            scan_builder = scan_builder.with_manifest_entry_concurrency_limit(limit);
+        }
+
+        let df_concurrency = data_file_concurrency_limit; // Local alias; read again further down when configuring the Arrow reader.
+        if let Some(limit) = df_concurrency {
+            scan_builder = scan_builder.with_data_file_concurrency_limit(limit);
+        }
+
+        scan_builder = scan_builder.with_row_group_filtering_enabled(row_group_filtering_enabled);
+        scan_builder = scan_builder.with_row_selection_enabled(row_selection_enabled);
+        if let Some(bs) = batch_size {
+            scan_builder = scan_builder.with_batch_size(Some(bs));
+        }
+
+        let scan = scan_builder.build().map_err(crate::error::to_py_err)?; // Build errors are converted to Python exceptions by `to_py_err`.
+
+        let rust_tasks = py.detach(|| { // Planning blocks on the runtime inside `py.detach`; presumably this runs without holding the GIL (confirm for the PyO3 version in use).
+            let task_stream = runtime()
+                .block_on(async { scan.plan_files().await })
+                .map_err(crate::error::to_py_err)?;
+
+            runtime()
+                .block_on(async { task_stream.try_collect::>().await }) // Drains the whole planned task stream before any data is read.
+                .map_err(crate::error::to_py_err)
+        })?;
+
+        validate_reader_projection(output_schema, &rust_tasks)?; // Checks the planned tasks against `output_schema`.
+        let schema = arrow_schema_for_reader(output_schema, &rust_tasks)?;
+
+        let task_stream_for_reader = Box::pin(stream::iter(rust_tasks.into_iter().map(Ok))) as FileScanTaskStream; // Re-wraps the already collected tasks as a stream for the Arrow reader.
+
+        let mut reader_builder = ArrowReaderBuilder::new(self.inner.file_io().clone(), iceberg_runtime()) // The reader is configured on its own builder, repeating the settings given to the scan builder above.
+            .with_row_group_filtering_enabled(row_group_filtering_enabled)
+            .with_row_selection_enabled(row_selection_enabled);
+
+        if let Some(bs) = batch_size {
+            reader_builder = reader_builder.with_batch_size(bs);
+        }
+        if let Some(limit) = df_concurrency {
+            reader_builder = reader_builder.with_data_file_concurrency_limit(limit);
+        }
+
+        let stream = reader_builder
+            .build()
+            .read(task_stream_for_reader)
+            .map_err(crate::error::to_py_err)?
+            .stream();
+
+        let reader: Box = // NOTE(review): the boxed type's generic argument is missing in this copy of the patch.
+            Box::new(BlockingArrowRecordBatchReader {
+                schema,
+                stream,
+                remaining_rows: max_rows, // Presumably the row cap enforced by the reader, which is defined outside this patch; confirm.
+            });
+
+        reader.into_pyarrow(py) // Hands the reader to Python; the test in this patch expects a `pa.RecordBatchReader`.
+    }
 }
 
 pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
diff --git a/bindings/python/tests/test_scan.py b/bindings/python/tests/test_scan.py
index 3663a2b3b7..2f77d72ff0 100644
--- a/bindings/python/tests/test_scan.py
+++ b/bindings/python/tests/test_scan.py
@@ -550,3 +550,91 @@ def test_table_empty_table_planning():
 
     tasks_pred = table.plan_files(predicate=Reference("id").eq(5))
     assert len(tasks_pred) == 0
+
+
+TABLE_METADATA_WITH_SNAPSHOT_JSON = json.dumps(  # Table metadata fixture with one snapshot, set as the current snapshot.
+    {
+        "format-version": 2,
+        "table-uuid": "fb070e82-2d1f-4ef6-8ab6-c4d12c6ed490",
+        "location": "s3://bucket/table",
+        "last-sequence-number": 1,
+        "last-updated-ms": 1600000000000,
+        "last-column-id": 2,
+        "schemas": [
+            {
+                "schema-id": 1,
+                "type": "struct",
+                "fields": [
+                    {"id": 1, "name": "id", "required": True, "type": "long"},
+                    {"id": 2, "name": "name", "required": False, "type": "string"},
+                ],
+            }
+        ],
+        "current-schema-id": 1,
+        "partition-specs": [{"spec-id": 0, "fields": []}],
+        "default-spec-id": 0,
+        "last-partition-id": 1000,
+        "default-sort-order-id": 0,
+        "sort-orders": [{"order-id": 0, "fields": []}],
+        "properties": {},
+        "current-snapshot-id": 1,  # Refers to the single entry in "snapshots" below.
+        "snapshots": [
+            {
+                "snapshot-id": 1,
+                "timestamp-ms": 1600000000000,
+                "summary": {"operation": "append"},
+                "manifest-list": "s3://bucket/table/metadata/snap-1.avro",  # Fixture path only; nothing in this patch creates the file.
+                "schema-id": 1
+            }
+        ],
+        "snapshot-log": [],
+        "metadata-log": [],
+    }
+)
+
+
+def test_table_read_arrow():  # Covers Table.read_arrow: empty read, projection, name mismatch, max_rows=0, unknown column, filters.
+    from pyiceberg_core.scan import Table
+    from pyiceberg_core.expression import Reference
+
+    table = Table.from_metadata_json(
+        FileIO.from_props({}),
+        ["ns", "tbl"],
+        TABLE_METADATA_JSON,  # Defined elsewhere in this test module, not in this patch.
+    )
+
+    # 1. Test empty table reader
+    reader = table.read_arrow(schema())
+    assert isinstance(reader, pa.RecordBatchReader)
+    assert reader.schema.names == ["id", "name"]
+    with pytest.raises(StopIteration):  # The reader yields no batches for this table.
+        reader.read_next_batch()
+
+    reader_projected = table.read_arrow(id_schema(), selected_fields=["id"])  # Selection and output schema name the same columns.
+    assert reader_projected.schema.names == ["id"]
+
+    with pytest.raises(ValueError, match="output_schema columns"):  # schema() names ["id", "name"] while only "id" is selected, so the name check rejects it.
+        table.read_arrow(schema(), selected_fields=["id"])
+
+    # 2. Test max_rows=0
+    reader_limit_0 = table.read_arrow(schema(), max_rows=0)
+    assert isinstance(reader_limit_0, pa.RecordBatchReader)
+    with pytest.raises(StopIteration):
+        reader_limit_0.read_next_batch()
+
+    # 3. Test invalid selected field
+    table_with_snapshot = Table.from_metadata_json(
+        FileIO.from_props({}),
+        ["ns", "tbl"],
+        TABLE_METADATA_WITH_SNAPSHOT_JSON,
+    )
+    with pytest.raises(ValueError, match="Column missing not found in table"):  # "missing" is not an output column, so the name check lets it through; the asserted message is not produced by code in this patch.
+        table_with_snapshot.read_arrow(schema(), selected_fields=["missing"])
+
+    # 4. Test filter binding
+    reader_filtered = table.read_arrow(schema(), predicate=Reference("id").eq(5))
+    assert isinstance(reader_filtered, pa.RecordBatchReader)  # Only the returned type is checked; the reader is not consumed.
+
+    # 5. Test unbindable filter (column not in schema)
+    with pytest.raises(ValueError, match="Field missing not found in schema"):
+        table_with_snapshot.plan_files(predicate=Reference("missing").eq(5))  # NOTE(review): this calls plan_files, not read_arrow; confirm whether read_arrow was intended.