|
20 | 20 | from abc import ABC, abstractmethod |
21 | 21 | from collections.abc import Callable, Iterable, Sequence |
22 | 22 | from functools import cached_property |
23 | | -from typing import Annotated, Any, TypeAlias |
| 23 | +from typing import Any, TypeAlias |
24 | 24 | from typing import Literal as TypingLiteral |
25 | 25 |
|
26 | | -from pydantic import ConfigDict, Discriminator, Field, Tag, model_validator |
| 26 | +from pydantic import ConfigDict, Field, SerializeAsAny, model_validator |
27 | 27 | from pydantic_core.core_schema import ValidatorFunctionWrapHandler |
28 | 28 |
|
29 | 29 | from pyiceberg.expressions.literals import AboveMax, BelowMin, Literal, literal |
@@ -73,6 +73,10 @@ def __or__(self, other: BooleanExpression) -> BooleanExpression: |
73 | 73 | @classmethod |
74 | 74 | def handle_primitive_type(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> BooleanExpression: |
75 | 75 | """Apply custom deserialization logic before validation.""" |
| 76 | + # Already a BooleanExpression? Return it as-is so we keep the concrete subclass. |
| 77 | + if isinstance(v, BooleanExpression): |
| 78 | + return v |
| 79 | + |
76 | 80 | # Handle different input formats |
77 | 81 | if isinstance(v, bool): |
78 | 82 | return AlwaysTrue() if v else AlwaysFalse() |
@@ -123,6 +127,9 @@ def handle_primitive_type(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> |
123 | 127 | return handler(v) |
124 | 128 |
|
125 | 129 |
|
| 130 | +SerializableBooleanExpression: TypeAlias = SerializeAsAny["BooleanExpression"] |
| 131 | + |
| 132 | + |
126 | 133 | def _build_balanced_tree( |
127 | 134 | operator_: Callable[[BooleanExpression, BooleanExpression], BooleanExpression], items: Sequence[BooleanExpression] |
128 | 135 | ) -> BooleanExpression: |
@@ -290,36 +297,14 @@ def as_bound(self) -> type[BoundReference]: |
290 | 297 | return BoundReference |
291 | 298 |
|
292 | 299 |
|
293 | | -Predicates: TypeAlias = Annotated[ |
294 | | - Annotated["IsNull", Tag("is-null")] |
295 | | - | Annotated["NotNull", Tag("not-null")] |
296 | | - | Annotated["IsNaN", Tag("is-nan")] |
297 | | - | Annotated["NotNaN", Tag("not-nan")] |
298 | | - | Annotated["EqualTo", Tag("eq")] |
299 | | - | Annotated["NotEqualTo", Tag("not-eq")] |
300 | | - | Annotated["LessThan", Tag("lt")] |
301 | | - | Annotated["LessThanOrEqual", Tag("lt-eq")] |
302 | | - | Annotated["GreaterThan", Tag("gt")] |
303 | | - | Annotated["GreaterThanOrEqual", Tag("gt-eq")] |
304 | | - | Annotated["StartsWith", Tag("starts-with")] |
305 | | - | Annotated["NotStartsWith", Tag("not-starts-with")] |
306 | | - | Annotated["In", Tag("in")] |
307 | | - | Annotated["NotIn", Tag("not-in")] |
308 | | - | Annotated["And", Tag("and")] |
309 | | - | Annotated["Or", Tag("or")] |
310 | | - | Annotated["Not", Tag("not")], |
311 | | - Discriminator("type"), |
312 | | -] |
313 | | - |
314 | | - |
315 | 300 | class And(BooleanExpression): |
316 | 301 | """AND operation expression - logical conjunction.""" |
317 | 302 |
|
318 | 303 | model_config = ConfigDict(arbitrary_types_allowed=True) |
319 | 304 |
|
320 | 305 | type: TypingLiteral["and"] = Field(default="and", alias="type") |
321 | | - left: Predicates |
322 | | - right: Predicates |
| 306 | + left: SerializableBooleanExpression = Field() |
| 307 | + right: SerializableBooleanExpression = Field() |
323 | 308 |
|
324 | 309 | def __init__(self, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression, **_: Any) -> None: |
325 | 310 | if isinstance(self, And) and not hasattr(self, "left") and not hasattr(self, "right"): |
@@ -365,8 +350,8 @@ class Or(BooleanExpression): |
365 | 350 | model_config = ConfigDict(arbitrary_types_allowed=True) |
366 | 351 |
|
367 | 352 | type: TypingLiteral["or"] = Field(default="or", alias="type") |
368 | | - left: Predicates |
369 | | - right: Predicates |
| 353 | + left: SerializableBooleanExpression = Field() |
| 354 | + right: SerializableBooleanExpression = Field() |
370 | 355 |
|
371 | 356 | def __init__(self, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> None: |
372 | 357 | if isinstance(self, Or) and not hasattr(self, "left") and not hasattr(self, "right"): |
@@ -412,7 +397,7 @@ class Not(BooleanExpression): |
412 | 397 | model_config = ConfigDict(arbitrary_types_allowed=True) |
413 | 398 |
|
414 | 399 | type: TypingLiteral["not"] = Field(default="not") |
415 | | - child: Predicates = Field() |
| 400 | + child: SerializableBooleanExpression = Field() |
416 | 401 |
|
417 | 402 | def __init__(self, child: BooleanExpression, **_: Any) -> None: |
418 | 403 | super().__init__(child=child) |
|
0 commit comments