From 75ccd9283510950b9a954df6faeb39dfb6ef98c5 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Sat, 29 Mar 2025 19:57:27 -0700 Subject: [PATCH 1/3] fix upsert with null values --- pyiceberg/table/upsert_util.py | 11 ++++++++++- tests/table/test_upsert.py | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index d2bd48bc99..1e2f3a17dc 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -65,7 +65,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # When the target table is empty, there is nothing to update :) return source_table.schema.empty_table() - diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols]) + diff_expr = functools.reduce( + operator.or_, + [ + pc.or_kleene( + pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))), + pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs")), + ) + for col in non_key_cols + ], + ) return ( source_table diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 19bfbc01de..de77147ec9 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -509,3 +509,39 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None: ValueError, match="Join columns could not be found, please set identifier-field-ids or pass in explicitly." ): tbl.upsert(df) + + +def test_upsert_with_nulls(catalog: Catalog) -> None: + identifier = "default.test_upsert_with_nulls" + _drop_table(catalog, identifier) + + schema = pa.schema( + [ + ("foo", pa.string()), + ("bar", pa.int32()), + ("baz", pa.bool_()), + ] + ) + + # create table with null value + table = catalog.create_table(identifier, schema) + data_with_null = pa.Table.from_pylist( + [ + {"foo": "apple", "bar": None, "baz": False}, + ], + schema=schema, + ) + table.append(data_with_null) + assert table.scan().to_arrow()["bar"].is_null() + + # upsert table with non-null value + data_without_null = pa.Table.from_pylist( + [ + {"foo": "apple", "bar": 7, "baz": False}, + ], + schema=schema, + ) + upd = table.upsert(data_without_null, join_cols=["foo"]) + assert upd.rows_updated == 1 + assert upd.rows_inserted == 0 + assert table.scan().to_arrow() == data_without_null From 93de319695340c50dcbde3b0ad978cd9112a4d69 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Sat, 29 Mar 2025 20:10:33 -0700 Subject: [PATCH 2/3] order --- pyiceberg/table/upsert_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 1e2f3a17dc..ef0e72fd48 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -69,8 +69,8 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols operator.or_, [ pc.or_kleene( - pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))), pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs")), + pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))), ) for col in non_key_cols ], From aebcb48941ea2a1c8cedc5c3eb4a4584a14cd500 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Mon, 31 Mar 2025 09:06:06 -0700 Subject: [PATCH 3/3] improve test --- tests/table/test_upsert.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index de77147ec9..5de4a61187 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -528,6 +528,7 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: data_with_null = pa.Table.from_pylist( [ {"foo": "apple", "bar": None, "baz": False}, + {"foo": "banana", "bar": None, "baz": False}, ], schema=schema, ) @@ -544,4 +545,10 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: upd = table.upsert(data_without_null, join_cols=["foo"]) assert upd.rows_updated == 1 assert upd.rows_inserted == 0 - assert table.scan().to_arrow() == data_without_null + assert table.scan().to_arrow() == pa.Table.from_pylist( + [ + {"foo": "apple", "bar": 7, "baz": False}, + {"foo": "banana", "bar": None, "baz": False}, + ], + schema=schema, + )