diff --git a/dev/provision.py b/dev/provision.py index b358da6593..5b198ce94b 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -328,6 +328,7 @@ CREATE TABLE {catalog_name}.default.test_table_empty_list_and_map ( col_list array, col_map map, + col_struct struct, col_list_with_struct array> ) USING iceberg @@ -340,8 +341,8 @@ spark.sql( f""" INSERT INTO {catalog_name}.default.test_table_empty_list_and_map - VALUES (null, null, null), - (array(), map(), array(struct(1))) + VALUES (null, null, null, null), + (array(), map(), struct(1), array(struct(1))) """ ) diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index 7f6cfe9987..da25a42116 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1199,6 +1199,7 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]): ... 1: Accessor(position=1, inner=None), ... 5: Accessor(position=2, inner=Accessor(position=0, inner=None)), ... 6: Accessor(position=2, inner=Accessor(position=1, inner=None)) + ... 3: Accessor(position=2, inner=None), ... } >>> result == expected True @@ -1214,8 +1215,7 @@ def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor if field_results[position]: for inner_field_id, acc in field_results[position].items(): result[inner_field_id] = Accessor(position, inner=acc) - else: - result[field.field_id] = Accessor(position) + result[field.field_id] = Accessor(position) return result diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 4926b70121..12d9ff95a9 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -168,6 +168,23 @@ def test_notnull_bind_required() -> None: assert NotNull(Reference("a")).bind(schema) == AlwaysTrue() +def test_notnull_bind_top_struct() -> None: + schema = Schema( + NestedField( + 3, + "struct_col", + required=False, + field_type=StructType( + NestedField(1, "id", IntegerType(), required=True), + NestedField(2, "cost", DecimalType(38, 18), required=False), + ), + ), + schema_id=1, + ) + bound = BoundNotNull(BoundReference(schema.find_field(3), schema.accessor_for_field(3))) + assert NotNull(Reference("struct_col")).bind(schema) == bound + + def test_isnan_inverse() -> None: assert ~IsNaN(Reference("f")) == NotNaN(Reference("f")) diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 1eb3500a17..fd6a87fb46 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -41,6 +41,7 @@ LessThan, NotEqualTo, NotNaN, + NotNull, ) from pyiceberg.io import PYARROW_USE_LARGE_TYPES_ON_READ from pyiceberg.io.pyarrow import ( @@ -668,6 +669,24 @@ def test_filter_case_insensitive(catalog: Catalog) -> None: assert arrow_table["b"].to_pylist() == ["2"] +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_filters_on_top_level_struct(catalog: Catalog) -> None: + test_empty_struct = catalog.load_table("default.test_table_empty_list_and_map") + + arrow_table = test_empty_struct.scan().to_arrow() + assert None in arrow_table["col_struct"].to_pylist() + + arrow_table = test_empty_struct.scan(row_filter=NotNull("col_struct")).to_arrow() + assert arrow_table["col_struct"].to_pylist() == [{"test": 1}] + + arrow_table = test_empty_struct.scan(row_filter="col_struct is not null", case_sensitive=False).to_arrow() + assert arrow_table["col_struct"].to_pylist() == [{"test": 1}] + + arrow_table = test_empty_struct.scan(row_filter="COL_STRUCT is null", case_sensitive=False).to_arrow() + assert arrow_table["col_struct"].to_pylist() == [None] + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_upgrade_table_version(catalog: Catalog) -> None: diff --git a/tests/test_schema.py b/tests/test_schema.py index daa46dee1f..61c64e71fc 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -398,6 +398,7 @@ def test_build_position_accessors(table_schema_nested: Schema) -> None: 4: Accessor(position=3, inner=None), 6: Accessor(position=4, inner=None), 11: Accessor(position=5, inner=None), + 15: Accessor(position=6, inner=None), 16: Accessor(position=6, inner=Accessor(position=0, inner=None)), 17: Accessor(position=6, inner=Accessor(position=1, inner=None)), } @@ -925,7 +926,7 @@ def primitive_fields() -> List[NestedField]: ] -def test_add_top_level_primitives(primitive_fields: NestedField) -> None: +def test_add_top_level_primitives(primitive_fields: List[NestedField]) -> None: for primitive_field in primitive_fields: new_schema = Schema(primitive_field) applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore