From b81990d5fcb5ce92624cd61441da885f27363db9 Mon Sep 17 00:00:00 2001 From: Paul Mathew <[email protected]> Date: Wed, 20 May 2026 15:42:20 -0400 Subject: [PATCH] perf(upsert): project join_cols only on destination scan when when_matched_update_all=False MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a ``selected_fields`` projection to the destination match ``DataScan`` in ``Transaction.upsert`` so the insert-on-no-match branch reads only ``join_cols`` from each destination batch instead of every column. The consumer loop in that branch never reads non-key columns — they're fetched, decoded, materialised into PyArrow batches, and discarded. For ``when_matched_update_all=True`` the projection stays at ``("*",)`` because ``upsert_util.get_rows_to_update`` reads every non-key column off the destination row to detect value drift and skip no-op overwrites. Narrowing there would silently regress that optimisation. Benchmarked against a representative event-log table (``hours(created_at)`` partition, composite UUID key, wide JSON payload), 1,000 destination- sampled rows across 168h: - Files planned: 2 / 2 (unchanged — projection is independent of manifest pruning). - Parquet column bytes off S3: 75.0 MiB → 16.4 MiB (78.2% reduction). - Arrow bytes after decompression: 452.4 KiB → 61.8 KiB (86.3% reduction). - to_arrow() wall time: 3.84s → 1.66s (57% reduction). A previous version of this PR additionally introduced an ``augment_filter_with_partition_ranges`` helper that AND'd ``[min, max]`` predicates on partition source columns onto the row filter. Review surfaced that for the safe-shape case (partition source columns ⊆ ``join_cols``), ``create_match_filter``'s per-disjunct projection already enables manifest pruning, and ``inclusive_projection`` walks the full N-disjunct tree regardless of the augmentation — so it added work without substituting for anything. It's been dropped from this PR. Related: #2138, #2159, #3129. Co-authored-by: Cursor <[email protected]> --- pyiceberg/table/__init__.py | 5 + tests/table/test_upsert.py | 179 +++++++++++++++++++++++++++++++++++- 2 files changed, 183 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 64ad10050d..4621b5f3bf 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -877,12 +877,17 @@ def upsert( # get list of rows that exist so we don't have to load the entire target table matched_predicate = upsert_util.create_match_filter(df, join_cols) + # When ``when_matched_update_all=False`` the consumer loop below + # only ever reads ``join_cols`` off each destination batch. + selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols) + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. matched_iceberg_record_batches_scan = DataScan( table_metadata=self.table_metadata, io=self._table.io, row_filter=matched_predicate, + selected_fields=selected_fields, case_sensitive=case_sensitive, ) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 08f90c6600..e5569757b0 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import datetime from pathlib import PosixPath +from typing import Any import pyarrow as pa import pytest @@ -26,11 +28,13 @@ from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference from pyiceberg.expressions.literals import LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow +from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table, UpsertResult from pyiceberg.table.snapshots import Operation from pyiceberg.table.upsert_util import create_match_filter -from pyiceberg.types import IntegerType, NestedField, StringType, StructType +from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType from tests.catalog.test_base import InMemoryCatalog @@ -888,3 +892,176 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None: for snapshot in snapshots[initial_snapshot_count:]: assert snapshot.summary is not None assert snapshot.summary.additional_properties.get("test_prop") == "test_value" + + +class TestUpsertScanProjection: + """``Transaction.upsert`` narrows the destination scan's + ``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``. + + Rationale: the insert-on-no-match branch only reads ``join_cols`` + off each destination batch (to feed ``create_match_filter``); every + other column is unused. Projection at the scan boundary lets the + parquet reader prune wide non-key columns at the file level — + significant for tables whose payload column (e.g. a JSON ``log``) + dominates file bytes. ``_projected_field_ids`` auto-unions the + row-filter's column ids back in, so any column referenced by the + join-key predicate is still readable for filter evaluation without + needing to list it explicitly. + + Falls back to ``("*",)`` when ``when_matched_update_all=True`` + because ``get_rows_to_update`` reads every non-key column off the + destination row to detect value drift — narrowing would break the + no-op-write skip. + """ + + @staticmethod + def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table: + _drop_table(catalog, identifier) + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_date", DateType(), required=True), + NestedField(3, "order_type", StringType(), required=True), + ) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + return catalog.create_table(identifier, schema=schema, partition_spec=spec) + + @staticmethod + def _arrow_schema() -> pa.Schema: + return pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + + def _seed(self, table: Table) -> None: + table.append( + pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"}, + ], + schema=self._arrow_schema(), + ) + ) + + @pytest.fixture + def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]: + """Spy on ``DataScan.__init__`` to capture every kwargs dict. + + Lets the tests pin which ``selected_fields`` the upsert path + actually passes — assertions on the surfaced batch schema alone + would miss the case where the underlying projection contract + regresses but the test data happens to have only join_cols + anyway. + + The spy preserves ``__init__``'s signature via + :func:`functools.wraps` so ``DataScan.update()``'s reflective + ``inspect.signature(type(self).__init__).parameters`` lookup + (used by ``use_ref``) still resolves to the real parameter + names, not the spy's ``**kwargs``. + """ + import functools + + from pyiceberg.table import DataScan + + captured: list[dict[str, Any]] = [] + original_init = DataScan.__init__ + + @functools.wraps(original_init) + def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None: + captured.append(dict(kwargs)) + original_init(self, *args, **kwargs) + + monkeypatch.setattr(DataScan, "__init__", _spy) + return captured + + def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None: + """The insert-on-no-match branch never reads non-key destination + columns, so the scan must narrow the projection to ``join_cols`` + — saving the parquet reader from materialising wide payload + columns just to be discarded.""" + table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only") + self._seed(table) + upsert_df = pa.Table.from_pylist( + [ + {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, + ], + schema=self._arrow_schema(), + ) + + # Snapshot only the scans constructed during the upsert (the + # seed append above may have created its own). + before = len(captured_scans) + res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False) + upsert_scans = captured_scans[before:] + assert res.rows_inserted == 1 + assert res.rows_updated == 0 + + # The upsert constructs one DataScan for the destination match. + # ``use_ref`` may construct a second DataScan as an inherited + # copy (via ``self.update``), which carries the same + # ``selected_fields`` through. Pin both: at least one scan was + # constructed during the upsert, and every scan that ran + # carries the narrowed projection. + assert upsert_scans, "upsert path constructed no DataScan — projection contract regression" + selected = [s.get("selected_fields") for s in upsert_scans] + assert all(sf == ("order_id",) for sf in selected), ( + f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}" + ) + + def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None: + """The update branch's ``get_rows_to_update`` compares non-key + columns to detect actual value changes — projecting only + ``join_cols`` would feed it data with no non-key columns to + compare and silently turn every match into a write-back. Must + keep ``("*",)``.""" + table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode") + self._seed(table) + upsert_df = pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"}, + {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, + ], + schema=self._arrow_schema(), + ) + + before = len(captured_scans) + res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True) + upsert_scans = captured_scans[before:] + assert res.rows_updated == 1 + assert res.rows_inserted == 1 + + assert upsert_scans, "upsert path constructed no DataScan — projection contract regression" + selected = [s.get("selected_fields") for s in upsert_scans] + assert all(sf == ("*",) for sf in selected), ( + f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}" + ) + + def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None: + """End-to-end correctness pin: with ``when_matched_update_all=True`` + the destination scan must read non-key columns so + ``get_rows_to_update`` can detect ``order_type`` changes. A + regression that narrows projection unconditionally would skip + the comparison and silently miss updates whose non-key columns + differ. + """ + identifier = "default.test_upsert_update_mode_correctness" + table = self._build_partitioned_table(catalog, identifier) + self._seed(table) + # Source has the same (order_id, order_date) as one destination + # row but a different ``order_type``. Update path must detect + # the non-key change and overwrite. + upsert_df = pa.Table.from_pylist( + [{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}], + schema=self._arrow_schema(), + ) + res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True) + assert res.rows_updated == 1 + assert res.rows_inserted == 0 + + # Read back: the original 'A' must have been overwritten with 'CHANGED'. + rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()} + assert rows[2]["order_type"] == "CHANGED"