Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
393 changes: 192 additions & 201 deletions pyiceberg/expressions/__init__.py

Large diffs are not rendered by default.

276 changes: 138 additions & 138 deletions pyiceberg/expressions/visitors.py

Large diffs are not rendered by default.

78 changes: 39 additions & 39 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
def __init__(self, schema: Schema | None = None):
self._schema = schema

def _get_field_name(self, term: BoundTerm[Any]) -> str | tuple[str, ...]:
def _get_field_name(self, term: BoundTerm) -> str | tuple[str, ...]:
"""Get the field name or nested field path for a bound term.

For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")).
Expand All @@ -837,50 +837,50 @@ def _get_field_name(self, term: BoundTerm[Any]) -> str | tuple[str, ...]:
# Fallback to just the field name if schema is not available
return term.ref().field.name

def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> pc.Expression:
def visit_in(self, term: BoundTerm, literals: set[Any]) -> pc.Expression:
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
return pc.field(self._get_field_name(term)).isin(pyarrow_literals)

def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> pc.Expression:
def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> pc.Expression:
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals)

def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression:
def visit_is_nan(self, term: BoundTerm) -> pc.Expression:
ref = pc.field(self._get_field_name(term))
return pc.is_nan(ref)

def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression:
def visit_not_nan(self, term: BoundTerm) -> pc.Expression:
ref = pc.field(self._get_field_name(term))
return ~pc.is_nan(ref)

def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression:
def visit_is_null(self, term: BoundTerm) -> pc.Expression:
return pc.field(self._get_field_name(term)).is_null(nan_is_null=False)

def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
def visit_not_null(self, term: BoundTerm) -> pc.Expression:
return pc.field(self._get_field_name(term)).is_valid()

def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type)

def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type)

def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type)

def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type)

def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type)

def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type)

def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return pc.starts_with(pc.field(self._get_field_name(term)), literal.value)

def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression:
return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value)

def visit_true(self) -> pc.Expression:
Expand All @@ -901,13 +901,13 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p

class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]):
# BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
is_null_or_not_bound_terms: set[BoundTerm[Any]]
is_null_or_not_bound_terms: set[BoundTerm]
# The remaining BoundTerms appearing in the boolean expr.
null_unmentioned_bound_terms: set[BoundTerm[Any]]
null_unmentioned_bound_terms: set[BoundTerm]
# BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr.
is_nan_or_not_bound_terms: set[BoundTerm[Any]]
is_nan_or_not_bound_terms: set[BoundTerm]
# The remaining BoundTerms appearing in the boolean expr.
nan_unmentioned_bound_terms: set[BoundTerm[Any]]
nan_unmentioned_bound_terms: set[BoundTerm]

def __init__(self) -> None:
super().__init__()
Expand All @@ -916,81 +916,81 @@ def __init__(self) -> None:
self.is_nan_or_not_bound_terms = set()
self.nan_unmentioned_bound_terms = set()

def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
def _handle_explicit_is_null_or_not(self, term: BoundTerm) -> None:
"""Handle the predicate case where either is_null or is_not_null is included."""
if term in self.null_unmentioned_bound_terms:
self.null_unmentioned_bound_terms.remove(term)
self.is_null_or_not_bound_terms.add(term)

def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None:
def _handle_null_unmentioned(self, term: BoundTerm) -> None:
"""Handle the predicate case where neither is_null nor is_not_null is included."""
if term not in self.is_null_or_not_bound_terms:
self.null_unmentioned_bound_terms.add(term)

def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None:
def _handle_explicit_is_nan_or_not(self, term: BoundTerm) -> None:
"""Handle the predicate case where either is_nan or is_not_nan is included."""
if term in self.nan_unmentioned_bound_terms:
self.nan_unmentioned_bound_terms.remove(term)
self.is_nan_or_not_bound_terms.add(term)

def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None:
def _handle_nan_unmentioned(self, term: BoundTerm) -> None:
"""Handle the predicate case where neither is_nan nor is_not_nan is included."""
if term not in self.is_nan_or_not_bound_terms:
self.nan_unmentioned_bound_terms.add(term)

def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> None:
def visit_in(self, term: BoundTerm, literals: set[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> None:
def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_is_nan(self, term: BoundTerm[Any]) -> None:
def visit_is_nan(self, term: BoundTerm) -> None:
self._handle_null_unmentioned(term)
self._handle_explicit_is_nan_or_not(term)

def visit_not_nan(self, term: BoundTerm[Any]) -> None:
def visit_not_nan(self, term: BoundTerm) -> None:
self._handle_null_unmentioned(term)
self._handle_explicit_is_nan_or_not(term)

def visit_is_null(self, term: BoundTerm[Any]) -> None:
def visit_is_null(self, term: BoundTerm) -> None:
self._handle_explicit_is_null_or_not(term)
self._handle_nan_unmentioned(term)

def visit_not_null(self, term: BoundTerm[Any]) -> None:
def visit_not_null(self, term: BoundTerm) -> None:
self._handle_explicit_is_null_or_not(term)
self._handle_nan_unmentioned(term)

def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> None:
self._handle_null_unmentioned(term)
self._handle_nan_unmentioned(term)

Expand Down Expand Up @@ -1040,10 +1040,10 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Schema
collector.collect(expr)

# Convert the set of terms to a sorted list so that the layout of the expression being built is deterministic.
null_unmentioned_bound_terms: list[BoundTerm[Any]] = sorted(
null_unmentioned_bound_terms: list[BoundTerm] = sorted(
collector.null_unmentioned_bound_terms, key=lambda term: term.ref().field.name
)
nan_unmentioned_bound_terms: list[BoundTerm[Any]] = sorted(
nan_unmentioned_bound_terms: list[BoundTerm] = sorted(
collector.nan_unmentioned_bound_terms, key=lambda term: term.ref().field.name
)

Expand Down
Loading