|
14 | 14 | # KIND, either express or implied. See the License for the |
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
| 17 | +import datetime |
17 | 18 | from pathlib import PosixPath |
| 19 | +from typing import Any |
18 | 20 |
|
19 | 21 | import pyarrow as pa |
20 | 22 | import pytest |
|
26 | 28 | from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference |
27 | 29 | from pyiceberg.expressions.literals import LongLiteral |
28 | 30 | from pyiceberg.io.pyarrow import schema_to_pyarrow |
| 31 | +from pyiceberg.partitioning import PartitionField, PartitionSpec |
29 | 32 | from pyiceberg.schema import Schema |
30 | 33 | from pyiceberg.table import Table, UpsertResult |
31 | 34 | from pyiceberg.table.snapshots import Operation |
32 | 35 | from pyiceberg.table.upsert_util import create_match_filter |
33 | | -from pyiceberg.types import IntegerType, NestedField, StringType, StructType |
| 36 | +from pyiceberg.transforms import IdentityTransform |
| 37 | +from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType |
34 | 38 | from tests.catalog.test_base import InMemoryCatalog |
35 | 39 |
|
36 | 40 |
|
@@ -888,3 +892,176 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None: |
888 | 892 | for snapshot in snapshots[initial_snapshot_count:]: |
889 | 893 | assert snapshot.summary is not None |
890 | 894 | assert snapshot.summary.additional_properties.get("test_prop") == "test_value" |
| 895 | + |
| 896 | + |
class TestUpsertScanProjection:
    """Pin the ``selected_fields`` projection of the upsert destination scan.

    ``Transaction.upsert`` narrows the destination scan's
    ``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``.

    Rationale: the insert-on-no-match branch only reads ``join_cols``
    off each destination batch (to feed ``create_match_filter``); every
    other column is unused. Projection at the scan boundary lets the
    parquet reader prune wide non-key columns at the file level —
    significant for tables whose payload column (e.g. a JSON ``log``)
    dominates file bytes. ``_projected_field_ids`` auto-unions the
    row-filter's column ids back in, so any column referenced by the
    join-key predicate is still readable for filter evaluation without
    needing to list it explicitly.

    Falls back to ``("*",)`` when ``when_matched_update_all=True``
    because ``get_rows_to_update`` reads every non-key column off the
    destination row to detect value drift — narrowing would break the
    no-op-write skip.
    """

    @staticmethod
    def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table:
        """Create a fresh three-column table identity-partitioned on ``order_date``.

        Any pre-existing table with the same identifier is dropped first so
        each test starts from an empty table.
        """
        _drop_table(catalog, identifier)
        schema = Schema(
            NestedField(1, "order_id", IntegerType(), required=True),
            NestedField(2, "order_date", DateType(), required=True),
            NestedField(3, "order_type", StringType(), required=True),
        )
        spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date"))
        return catalog.create_table(identifier, schema=schema, partition_spec=spec)

    @staticmethod
    def _arrow_schema() -> pa.Schema:
        """Return the Arrow schema mirroring the Iceberg schema built above."""
        return pa.schema(
            [
                pa.field("order_id", pa.int32(), nullable=False),
                pa.field("order_date", pa.date32(), nullable=False),
                pa.field("order_type", pa.string(), nullable=False),
            ]
        )

    @classmethod
    def _arrow_table(cls, rows: list[dict[str, Any]]) -> pa.Table:
        """Build an Arrow table from ``rows`` using the shared test schema."""
        return pa.Table.from_pylist(rows, schema=cls._arrow_schema())

    def _seed(self, table: Table) -> None:
        """Append the two baseline rows (``order_id`` 1 and 2, type ``"A"``)."""
        table.append(
            self._arrow_table(
                [
                    {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"},
                    {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"},
                ]
            )
        )

    @pytest.fixture
    def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
        """Spy on ``DataScan.__init__`` and record each call's effective arguments.

        Lets the tests pin which ``selected_fields`` the upsert path
        actually uses — assertions on the surfaced batch schema alone
        would miss the case where the underlying projection contract
        regresses but the test data happens to have only join_cols
        anyway.

        Every call is bound against the real ``__init__`` signature with
        defaults applied, so an argument shows up in the captured dict
        under its parameter name whether it was passed by keyword, passed
        positionally, or left to its default. Recording raw ``kwargs``
        alone would report ``selected_fields`` as missing in the latter
        two cases.

        The spy preserves ``__init__``'s signature via
        :func:`functools.wraps` so ``DataScan.update()``'s reflective
        ``inspect.signature(type(self).__init__).parameters`` lookup
        (used by ``use_ref``) still resolves to the real parameter
        names, not the spy's ``**kwargs``.
        """
        import functools
        import inspect

        from pyiceberg.table import DataScan

        captured: list[dict[str, Any]] = []
        original_init = DataScan.__init__
        # Resolved once: the signature is invariant across calls.
        init_signature = inspect.signature(original_init)
        # Name of the instance parameter (normally "self"); excluded from captures.
        instance_param = next(iter(init_signature.parameters), None)

        @functools.wraps(original_init)
        def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
            try:
                bound = init_signature.bind(self, *args, **kwargs)
            except TypeError:
                # Malformed call: record what was given and let the real
                # ``__init__`` below raise its own error.
                captured.append(dict(kwargs))
            else:
                bound.apply_defaults()
                captured.append({name: value for name, value in bound.arguments.items() if name != instance_param})
            original_init(self, *args, **kwargs)

        monkeypatch.setattr(DataScan, "__init__", _spy)
        return captured

    def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
        """Insert-only upserts must project the destination scan to ``join_cols``.

        The insert-on-no-match branch never reads non-key destination
        columns, so the scan must narrow the projection to ``join_cols``
        — saving the parquet reader from materialising wide payload
        columns just to be discarded.
        """
        table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only")
        self._seed(table)
        upsert_df = self._arrow_table(
            [
                {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"},
                {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
            ]
        )

        # Snapshot only the scans constructed during the upsert (the
        # seed append above may have created its own).
        before = len(captured_scans)
        res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False)
        upsert_scans = captured_scans[before:]
        assert res.rows_inserted == 1
        assert res.rows_updated == 0

        # The upsert constructs one DataScan for the destination match.
        # ``use_ref`` may construct a second DataScan as an inherited
        # copy (via ``self.update``), which carries the same
        # ``selected_fields`` through. Pin both: at least one scan was
        # constructed during the upsert, and every scan that ran
        # carries the narrowed projection.
        assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
        selected = [s.get("selected_fields") for s in upsert_scans]
        assert all(sf == ("order_id",) for sf in selected), (
            f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}"
        )

    def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
        """Update-mode upserts must keep the ``("*",)`` projection.

        The update branch's ``get_rows_to_update`` compares non-key
        columns to detect actual value changes — projecting only
        ``join_cols`` would feed it data with no non-key columns to
        compare and silently turn every match into a write-back.
        """
        table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode")
        self._seed(table)
        upsert_df = self._arrow_table(
            [
                {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"},
                {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
            ]
        )

        # Only look at scans created by the upsert itself, not by seeding.
        before = len(captured_scans)
        res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
        upsert_scans = captured_scans[before:]
        assert res.rows_updated == 1
        assert res.rows_inserted == 1

        assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
        selected = [s.get("selected_fields") for s in upsert_scans]
        assert all(sf == ("*",) for sf in selected), (
            f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}"
        )

    def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None:
        """End-to-end pin: update mode overwrites changed non-key columns.

        With ``when_matched_update_all=True`` the destination scan must
        read non-key columns so ``get_rows_to_update`` can detect
        ``order_type`` changes. A regression that narrows projection
        unconditionally would skip the comparison and silently miss
        updates whose non-key columns differ.
        """
        identifier = "default.test_upsert_update_mode_correctness"
        table = self._build_partitioned_table(catalog, identifier)
        self._seed(table)
        # Source has the same (order_id, order_date) as one destination
        # row but a different ``order_type``. Update path must detect
        # the non-key change and overwrite.
        upsert_df = self._arrow_table(
            [{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}]
        )
        res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
        assert res.rows_updated == 1
        assert res.rows_inserted == 0

        # Read back: the original 'A' must have been overwritten with 'CHANGED'.
        rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()}
        assert rows[2]["order_type"] == "CHANGED"
0 commit comments