diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index d2bd48bc99..ef0e72fd48 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -65,7 +65,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # When the target table is empty, there is nothing to update :) return source_table.schema.empty_table() - diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols]) + diff_expr = functools.reduce( + operator.or_, + [ + pc.or_kleene( + pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs")), + pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))), + ) + for col in non_key_cols + ], + ) return ( source_table diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 19bfbc01de..5de4a61187 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -509,3 +509,46 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None: ValueError, match="Join columns could not be found, please set identifier-field-ids or pass in explicitly." ): tbl.upsert(df) + + +def test_upsert_with_nulls(catalog: Catalog) -> None: + identifier = "default.test_upsert_with_nulls" + _drop_table(catalog, identifier) + + schema = pa.schema( + [ + ("foo", pa.string()), + ("bar", pa.int32()), + ("baz", pa.bool_()), + ] + ) + + # create table with null value + table = catalog.create_table(identifier, schema) + data_with_null = pa.Table.from_pylist( + [ + {"foo": "apple", "bar": None, "baz": False}, + {"foo": "banana", "bar": None, "baz": False}, + ], + schema=schema, + ) + table.append(data_with_null) + assert table.scan().to_arrow()["bar"].is_null() + + # upsert table with non-null value + data_without_null = pa.Table.from_pylist( + [ + {"foo": "apple", "bar": 7, "baz": False}, + ], + schema=schema, + ) + upd = table.upsert(data_without_null, join_cols=["foo"]) + assert upd.rows_updated == 1 + assert upd.rows_inserted == 0 + assert table.scan().to_arrow() == pa.Table.from_pylist( + [ + {"foo": "apple", "bar": 7, "baz": False}, + {"foo": "banana", "bar": None, "baz": False}, + ], + schema=schema, + )