Skip to content

Support Filters on Top-Level Struct Fields #1832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dev/provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@
CREATE TABLE {catalog_name}.default.test_table_empty_list_and_map (
col_list array<int>,
col_map map<int, int>,
col_struct struct<test:int>,
col_list_with_struct array<struct<test:int>>
)
USING iceberg
Expand All @@ -340,8 +341,8 @@
spark.sql(
f"""
INSERT INTO {catalog_name}.default.test_table_empty_list_and_map
VALUES (null, null, null),
(array(), map(), array(struct(1)))
VALUES (null, null, null, null),
(array(), map(), struct(1), array(struct(1)))
"""
)

Expand Down
4 changes: 2 additions & 2 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,7 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
... 1: Accessor(position=1, inner=None),
... 5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
... 6: Accessor(position=2, inner=Accessor(position=1, inner=None)),
... 3: Accessor(position=2, inner=None),
... }
>>> result == expected
True
Expand All @@ -1214,8 +1215,7 @@ def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor
if field_results[position]:
for inner_field_id, acc in field_results[position].items():
result[inner_field_id] = Accessor(position, inner=acc)
else:
result[field.field_id] = Accessor(position)
result[field.field_id] = Accessor(position)

return result

Expand Down
17 changes: 17 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,23 @@ def test_notnull_bind_required() -> None:
assert NotNull(Reference("a")).bind(schema) == AlwaysTrue()


def test_notnull_bind_top_struct() -> None:
    """Bind ``NotNull`` to an optional top-level struct column.

    The schema holds a single optional struct field (id 3) with two nested
    fields; binding must produce a ``BoundNotNull`` over a reference to the
    struct field itself, using the accessor the schema reports for field 3.
    """
    nested_fields = (
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "cost", DecimalType(38, 18), required=False),
    )
    struct_col = NestedField(
        3,
        "struct_col",
        required=False,
        field_type=StructType(*nested_fields),
    )
    schema = Schema(struct_col, schema_id=1)

    # Build the expected bound predicate from the schema's own field/accessor lookups.
    struct_reference = BoundReference(schema.find_field(3), schema.accessor_for_field(3))
    expected = BoundNotNull(struct_reference)

    assert NotNull(Reference("struct_col")).bind(schema) == expected


def test_isnan_inverse() -> None:
assert ~IsNaN(Reference("f")) == NotNaN(Reference("f"))

Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
LessThan,
NotEqualTo,
NotNaN,
NotNull,
)
from pyiceberg.io import PYARROW_USE_LARGE_TYPES_ON_READ
from pyiceberg.io.pyarrow import (
Expand Down Expand Up @@ -668,6 +669,24 @@ def test_filter_case_insensitive(catalog: Catalog) -> None:
assert arrow_table["b"].to_pylist() == ["2"]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_filters_on_top_level_struct(catalog: Catalog) -> None:
    """Check null / not-null row filters against the top-level ``col_struct`` column.

    Scans ``default.test_table_empty_list_and_map`` unfiltered, then with a
    ``NotNull`` expression, then with string filters evaluated
    case-insensitively, asserting on the ``col_struct`` values returned.
    """
    table = catalog.load_table("default.test_table_empty_list_and_map")

    # Unfiltered scan: at least one row has a null struct.
    unfiltered = table.scan().to_arrow()
    assert None in unfiltered["col_struct"].to_pylist()

    # Expression-object filter.
    via_expression = table.scan(row_filter=NotNull("col_struct")).to_arrow()
    assert via_expression["col_struct"].to_pylist() == [{"test": 1}]

    # String filter, case-insensitive binding.
    via_string = table.scan(row_filter="col_struct is not null", case_sensitive=False).to_arrow()
    assert via_string["col_struct"].to_pylist() == [{"test": 1}]

    # Upper-cased column name only resolves because case_sensitive=False.
    null_only = table.scan(row_filter="COL_STRUCT is null", case_sensitive=False).to_arrow()
    assert null_only["col_struct"].to_pylist() == [None]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_upgrade_table_version(catalog: Catalog) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def test_build_position_accessors(table_schema_nested: Schema) -> None:
4: Accessor(position=3, inner=None),
6: Accessor(position=4, inner=None),
11: Accessor(position=5, inner=None),
15: Accessor(position=6, inner=None),
16: Accessor(position=6, inner=Accessor(position=0, inner=None)),
17: Accessor(position=6, inner=Accessor(position=1, inner=None)),
}
Expand Down Expand Up @@ -925,7 +926,7 @@ def primitive_fields() -> List[NestedField]:
]


def test_add_top_level_primitives(primitive_fields: NestedField) -> None:
def test_add_top_level_primitives(primitive_fields: List[NestedField]) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixing a typing error I noticed: the `primitive_fields` fixture returns `List[NestedField]`, but the test parameter was annotated as a single `NestedField`.

for primitive_field in primitive_fields:
new_schema = Schema(primitive_field)
applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
Expand Down
Loading