Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,51 +810,83 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:


class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
"""Convert Iceberg bound expressions to PyArrow expressions.

Args:
schema: Optional Iceberg schema to resolve full field paths for nested fields.
If not provided, only the field name will be used (not dotted path).
"""

_schema: Optional[Schema]

def __init__(self, schema: Optional[Schema] = None):
self._schema = schema

def _get_field_name(self, term: BoundTerm[Any]) -> Union[str, Tuple[str, ...]]:
"""Get the field name or nested field path for a bound term.

For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")).
For top-level fields, returns just the field name as a string.

PyArrow requires nested field references as tuples, not dotted strings.
"""
if self._schema is not None:
# Use the schema to get the full dotted path for nested fields
full_name = self._schema.find_column_name(term.ref().field.field_id)
if full_name is not None:
# If the field name contains dots, it's a nested field
# Convert "parent.child" to ("parent", "child") for PyArrow
if "." in full_name:
return tuple(full_name.split("."))
return full_name
# Fallback to just the field name if schema is not available
return term.ref().field.name

def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
return pc.field(term.ref().field.name).isin(pyarrow_literals)
return pc.field(self._get_field_name(term)).isin(pyarrow_literals)

def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals)

def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression:
ref = pc.field(term.ref().field.name)
ref = pc.field(self._get_field_name(term))
return pc.is_nan(ref)

def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression:
ref = pc.field(term.ref().field.name)
ref = pc.field(self._get_field_name(term))
return ~pc.is_nan(ref)

def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression:
return pc.field(term.ref().field.name).is_null(nan_is_null=False)
return pc.field(self._get_field_name(term)).is_null(nan_is_null=False)

def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
return pc.field(term.ref().field.name).is_valid()
return pc.field(self._get_field_name(term)).is_valid()

def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) == _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type)

def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) != _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type)

def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) >= _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type)

def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) > _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type)

def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) < _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type)

def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) <= _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type)

def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.starts_with(pc.field(term.ref().field.name), literal.value)
return pc.starts_with(pc.field(self._get_field_name(term)), literal.value)

def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return ~pc.starts_with(pc.field(term.ref().field.name), literal.value)
return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value)

def visit_true(self) -> pc.Expression:
return pc.scalar(True)
Expand Down Expand Up @@ -990,11 +1022,21 @@ def collect(
boolean_expression_visit(expr, self)


def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
return boolean_expression_visit(expr, _ConvertToArrowExpression())
def expression_to_pyarrow(expr: BooleanExpression, schema: Optional[Schema] = None) -> pc.Expression:
    """Convert an Iceberg boolean expression to a PyArrow expression.

    Args:
        expr: The Iceberg boolean expression to convert.
        schema: Optional Iceberg schema used to resolve the full path of nested fields.
            If provided, a nested struct field is referenced by its path as a tuple of
            names (e.g., ("parent", "child")); otherwise only the field's own name is used.

    Returns:
        A PyArrow compute expression.
    """
    return boolean_expression_visit(expr, _ConvertToArrowExpression(schema))


def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Optional[Schema] = None) -> pc.Expression:
"""Complementary filter conversion function of expression_to_pyarrow.

Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null.
Expand All @@ -1015,7 +1057,7 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expressi
preserve_expr = Or(preserve_expr, BoundIsNull(term=term))
for term in nan_unmentioned_bound_terms:
preserve_expr = Or(preserve_expr, BoundIsNaN(term=term))
return expression_to_pyarrow(preserve_expr)
return expression_to_pyarrow(preserve_expr, schema)


@lru_cache
Expand Down Expand Up @@ -1550,7 +1592,7 @@ def _task_to_record_batches(
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields
)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def delete(
# Check if there are any files that require an actual rewrite of a data file
if delete_snapshot.rewrites_needed is True:
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter, self.table_metadata.schema())

file_scan = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive)
if branch is not None:
Expand Down
52 changes: 52 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,58 @@ def test_ref_binding_case_insensitive_failure(table_schema_simple: Schema) -> No
ref.bind(table_schema_simple, case_sensitive=False)


def test_ref_binding_nested_struct_field() -> None:
    """Nested struct members are indexed by dotted name and can be bound as references (issue #953)."""
    employment_struct = StructType(
        NestedField(field_id=3, name="status", field_type=StringType(), required=False),
        NestedField(field_id=4, name="company", field_type=StringType(), required=False),
    )
    contact_struct = StructType(
        NestedField(field_id=6, name="email", field_type=StringType(), required=False),
    )
    schema = Schema(
        NestedField(field_id=1, name="age", field_type=IntegerType(), required=True),
        NestedField(field_id=2, name="employment", field_type=employment_struct, required=False),
        NestedField(field_id=5, name="contact", field_type=contact_struct, required=False),
        schema_id=1,
    )

    # Every nested member must be reachable through its dotted name.
    for dotted_name in ("employment.status", "employment.company", "contact.email"):
        assert dotted_name in schema._name_to_id

    # Case-sensitive binding resolves to the leaf field of the struct.
    status_ref = Reference("employment.status")
    bound_status = status_ref.bind(schema, case_sensitive=True)
    assert bound_status.field.field_id == 3
    assert bound_status.field.name == "status"

    # Same check against a member of a different struct.
    email_ref = Reference("contact.email")
    bound_email = email_ref.bind(schema, case_sensitive=True)
    assert bound_email.field.field_id == 6
    assert bound_email.field.name == "email"

    # Case-insensitive binding ignores the casing of the whole dotted path.
    upper_ref = Reference("EMPLOYMENT.STATUS")
    bound_upper = upper_ref.bind(schema, case_sensitive=False)
    assert bound_upper.field.field_id == 3

    # A member that does not exist in the struct must be rejected.
    missing_ref = Reference("employment.department")
    with pytest.raises(ValueError):
        missing_ref.bind(schema, case_sensitive=True)


def test_in_to_eq() -> None:
assert In("x", (34.56,)) == EqualTo("x", 34.56)

Expand Down
3 changes: 3 additions & 0 deletions tests/expressions/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def test_with_function() -> None:
def test_nested_fields() -> None:
assert EqualTo("foo.bar", "data") == parser.parse("foo.bar = 'data'")
assert LessThan("location.x", DecimalLiteral(Decimal(52.00))) == parser.parse("location.x < 52.00")
# Test issue #953 scenario - nested struct field filtering
assert EqualTo("employment.status", "Employed") == parser.parse("employment.status = 'Employed'")
assert EqualTo("contact.email", "[email protected]") == parser.parse("contact.email = '[email protected]'")


def test_quoted_column_with_dots() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
@pytest.fixture(autouse=True)
def clear_global_manifests_cache() -> None:
# Clear the global cache before each test
_manifests.cache_clear() # type: ignore
_manifests.cache_clear()


def _verify_metadata_with_fastavro(avro_file: str, expected_metadata: Dict[str, str]) -> None:
Expand Down