diff --git a/bindings/python/src/scan.rs b/bindings/python/src/scan.rs index 73d754b504..6a476aa681 100644 --- a/bindings/python/src/scan.rs +++ b/bindings/python/src/scan.rs @@ -15,12 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use iceberg::expr::Bind; use iceberg::scan::{FileScanTask, FileScanTaskDeleteFile}; -use iceberg::spec::{DataContentType, DataFileFormat}; +use iceberg::spec::{ + DataContentType, DataFileFormat, Literal, NameMapping, PartitionSpec, Struct, + UnboundPartitionSpec, +}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyAny, PySequence}; +use pyo3::types::{PyAny, PyBool, PyBytes, PyFloat, PyInt, PySequence, PyString}; +use serde_json::{Number as JsonNumber, Value as JsonValue}; use crate::expression::PyPredicate; use crate::schema::PySchema; @@ -86,6 +92,170 @@ fn py_deletes_to_rust(values: Option<&Bound<'_, PyAny>>) -> PyResult, + schema: &PySchema, +) -> PyResult>> { + value + .map(|value| { + let spec: UnboundPartitionSpec = serde_json::from_str(value).map_err(|e| { + PyValueError::new_err(format!("Failed to parse partition_spec JSON: {e}")) + })?; + spec.bind(schema.inner.clone()) + .map(Arc::new) + .map_err(crate::error::to_py_err) + }) + .transpose() +} + +fn parse_name_mapping(value: Option<&str>) -> PyResult>> { + value + .map(|value| { + serde_json::from_str(value).map(Arc::new).map_err(|e| { + PyValueError::new_err(format!("Failed to parse name_mapping JSON: {e}")) + }) + }) + .transpose() +} + +fn bytes_to_hex(bytes: &[u8]) -> String { + const HEX: &[u8; 16] = b"0123456789abcdef"; + let mut out = String::with_capacity(bytes.len() * 2); + for byte in bytes { + out.push(HEX[(byte >> 4) as usize] as char); + out.push(HEX[(byte & 0x0f) as usize] as char); + } + out +} + +fn py_to_json_value(value: &Bound<'_, PyAny>) -> PyResult { + if value.is_none() { + return Ok(JsonValue::Null); + } + if value.is_instance_of::() { + return Ok(JsonValue::Bool(value.extract()?)); + } + if value.is_instance_of::() { + let v = value.extract::().map_err(|_| { + PyValueError::new_err(format!( + "integer {} exceeds i64 range; partition values require JSON-compatible integers", + value + .str() + .and_then(|s| s.extract::()) + .unwrap_or_else(|_| "".to_string()), + )) + })?; + return Ok(JsonValue::Number(v.into())); + } + if value.is_instance_of::() { + let v = value.extract::()?; + let number = JsonNumber::from_f64(v) + .ok_or_else(|| PyValueError::new_err("partition float values must be finite"))?; + return Ok(JsonValue::Number(number)); + } + if let Ok(v) = value.extract::() { + return Ok(JsonValue::String(v)); + } + if let Ok(bytes) = value.cast::() { + return Ok(JsonValue::String(bytes_to_hex(bytes.as_bytes()))); + } + Err(PyTypeError::new_err(format!( + "Cannot convert partition value to Iceberg JSON value: {}", + value.repr()?.to_str()? + ))) +} + +fn partition_values_to_json_array(values: &Bound<'_, PyAny>) -> PyResult> { + if values.is_instance_of::() { + match serde_json::from_str::(&values.extract::()?) { + Ok(JsonValue::Array(values)) => return Ok(values), + Ok(_) => { + return Err(PyTypeError::new_err( + "partition_data JSON string must contain an array", + )); + } + Err(e) => { + return Err(PyValueError::new_err(format!( + "Failed to parse partition_data JSON: {e}" + ))); + } + } + } + if values.is_instance_of::() { + return Err(PyTypeError::new_err( + "partition_data must be a sequence of values or a JSON array string, not bytes", + )); + } + + let seq = values.cast::().map_err(|_| { + PyTypeError::new_err("partition_data must be a sequence of values or a JSON array string") + })?; + let len = seq.len()?; + let mut out = Vec::with_capacity(len); + for i in 0..len { + out.push(py_to_json_value(&seq.get_item(i)?)?); + } + Ok(out) +} + +fn parse_partition_data( + value: Option<&Bound<'_, PyAny>>, + partition_spec: Option<&Arc>, + schema: &PySchema, +) -> PyResult> { + match (value, partition_spec) { + (None, None) => Ok(None), + (None, Some(_)) => Ok(None), + (Some(_), None) => Err(PyValueError::new_err( + "partition_spec is required when partition_data is provided", + )), + (Some(value), Some(partition_spec)) => { + let values = partition_values_to_json_array(value)?; + let partition_type = partition_spec + .partition_type(schema.inner.as_ref()) + .map_err(crate::error::to_py_err)?; + let fields = partition_type.fields(); + if values.len() != fields.len() { + return Err(PyValueError::new_err(format!( + "partition_data length {} does not match partition_spec field count {}", + values.len(), + fields.len() + ))); + } + + let literals = values + .into_iter() + .zip(fields.iter()) + .map(|(value, field)| { + Literal::try_from_json(value, &field.field_type) + .map_err(crate::error::to_py_err) + }) + .collect::>>()?; + Ok(Some(Struct::from_iter(literals))) + } + } +} + +fn validate_scan_range(file_size_in_bytes: u64, start: u64, length: Option) -> PyResult { + if start > file_size_in_bytes { + return Err(PyValueError::new_err(format!( + "start ({start}) must be less than or equal to file_size_in_bytes ({file_size_in_bytes})" + ))); + } + + let length = length.unwrap_or(file_size_in_bytes - start); + let end = start.checked_add(length).ok_or_else(|| { + PyValueError::new_err(format!("start ({start}) + length ({length}) overflows u64")) + })?; + if end > file_size_in_bytes { + return Err(PyValueError::new_err(format!( + "start ({start}) + length ({length}) must be less than or equal to file_size_in_bytes ({file_size_in_bytes})" + ))); + } + + Ok(length) +} + #[pyclass( name = "DeleteFile", module = "pyiceberg_core.scan", @@ -185,6 +355,9 @@ impl PyFileScanTask { data_file_format = "parquet", predicate = None, deletes = None, + partition_data = None, + partition_spec = None, + name_mapping = None, case_sensitive = true ))] fn new( @@ -198,13 +371,21 @@ impl PyFileScanTask { data_file_format: &str, predicate: Option<&PyPredicate>, deletes: Option<&Bound<'_, PyAny>>, + partition_data: Option<&Bound<'_, PyAny>>, + partition_spec: Option<&str>, + name_mapping: Option<&str>, case_sensitive: bool, ) -> PyResult { + let partition_spec = parse_partition_spec(partition_spec, schema)?; + let partition = parse_partition_data(partition_data, partition_spec.as_ref(), schema)?; + let name_mapping = parse_name_mapping(name_mapping)?; + let length = validate_scan_range(file_size_in_bytes, start, length)?; + Ok(Self { inner: FileScanTask { file_size_in_bytes, start, - length: length.unwrap_or(file_size_in_bytes), + length, record_count, data_file_path, data_file_format: parse_data_file_format(data_file_format)?, @@ -215,9 +396,9 @@ impl PyFileScanTask { .transpose() .map_err(crate::error::to_py_err)?, deletes: py_deletes_to_rust(deletes)?, - partition: None, - partition_spec: None, - name_mapping: None, + partition, + partition_spec, + name_mapping, case_sensitive, }, }) @@ -268,6 +449,21 @@ impl PyFileScanTask { self.inner.predicate.is_some() } + #[getter] + fn has_partition_data(&self) -> bool { + self.inner.partition.is_some() + } + + #[getter] + fn has_partition_spec(&self) -> bool { + self.inner.partition_spec.is_some() + } + + #[getter] + fn has_name_mapping(&self) -> bool { + self.inner.name_mapping.is_some() + } + #[getter] fn case_sensitive(&self) -> bool { self.inner.case_sensitive diff --git a/bindings/python/tests/test_scan.py b/bindings/python/tests/test_scan.py index f97bc29eff..39d5335596 100644 --- a/bindings/python/tests/test_scan.py +++ b/bindings/python/tests/test_scan.py @@ -94,6 +94,19 @@ def test_file_scan_task_properties_without_deletes(): assert task.case_sensitive is True +def test_file_scan_task_default_length_uses_remaining_file_size(): + task = FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [1], + start=128, + ) + + assert task.start == 128 + assert task.length == 896 + + def test_file_scan_task_binds_predicate_and_deletes(): delete = DeleteFile("s3://bucket/delete.parquet", 128, "position-deletes") task = FileScanTask( @@ -112,6 +125,136 @@ def test_file_scan_task_binds_predicate_and_deletes(): assert task.has_predicate is True +def test_file_scan_task_accepts_partition_context_and_name_mapping(): + partition_spec = json.dumps( + { + "spec-id": 1, + "fields": [ + { + "source-id": 1, + "field-id": 1000, + "name": "id", + "transform": "identity", + } + ], + } + ) + name_mapping = json.dumps([{"field-id": 1, "names": ["id", "record_id"]}]) + + task = FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [1], + partition_data=[7], + partition_spec=partition_spec, + name_mapping=name_mapping, + ) + + assert task.has_partition_data is True + assert task.has_partition_spec is True + assert task.has_name_mapping is True + + +def test_file_scan_task_accepts_partition_data_json_array(): + partition_spec = json.dumps( + { + "fields": [ + { + "source-id": 2, + "field-id": 1000, + "name": "name", + "transform": "identity", + } + ] + } + ) + + task = FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [2], + partition_data=json.dumps(["alice"]), + partition_spec=partition_spec, + ) + + assert task.has_partition_data is True + assert task.has_partition_spec is True + assert task.has_name_mapping is False + + +def test_file_scan_task_rejects_partition_data_without_spec(): + with pytest.raises(ValueError, match="partition_spec is required"): + FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [1], + partition_data=[7], + ) + + +def test_file_scan_task_rejects_partition_data_length_mismatch(): + partition_spec = json.dumps( + { + "fields": [ + { + "source-id": 1, + "field-id": 1000, + "name": "id", + "transform": "identity", + } + ] + } + ) + + with pytest.raises(ValueError, match="partition_data length"): + FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [1], + partition_data=[], + partition_spec=partition_spec, + ) + + +def test_file_scan_task_rejects_start_beyond_file_size(): + with pytest.raises(ValueError, match="start \\(1025\\).*file_size_in_bytes \\(1024\\)"): + FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [1], + start=1025, + ) + + +def test_file_scan_task_rejects_length_beyond_file_size(): + with pytest.raises(ValueError, match="start \\(128\\) \\+ length \\(897\\)"): + FileScanTask( + schema(), + "s3://bucket/data.parquet", + 1024, + [1], + start=128, + length=897, + ) + + +def test_file_scan_task_rejects_start_plus_length_overflow(): + with pytest.raises(ValueError, match="overflows u64"): + FileScanTask( + schema(), + "s3://bucket/data.parquet", + 2**64 - 1, + [1], + start=2**64 - 2, + length=2, + ) + + def test_file_scan_task_rejects_unbindable_predicate(): with pytest.raises(ValueError, match="Field missing not found in schema"): FileScanTask(