Skip to content

Commit b81990d

Browse files
Paul MathewPaul Mathew
authored andcommitted
perf(upsert): project join_cols only on destination scan when when_matched_update_all=False
Adds a ``selected_fields`` projection to the destination match ``DataScan`` in ``Transaction.upsert`` so the insert-on-no-match branch reads only ``join_cols`` from each destination batch instead of every column. The consumer loop in that branch never reads non-key columns — they're fetched, decoded, materialised into PyArrow batches, and discarded. For ``when_matched_update_all=True`` the projection stays at ``("*",)`` because ``upsert_util.get_rows_to_update`` reads every non-key column off the destination row to detect value drift and skip no-op overwrites. Narrowing there would silently regress that optimisation. Benchmarked against a representative event-log table (``hours(created_at)`` partition, composite UUID key, wide JSON payload), 1,000 destination- sampled rows across 168h: - Files planned: 2 / 2 (unchanged — projection is independent of manifest pruning). - Parquet column bytes off S3: 75.0 MiB → 16.4 MiB (78.2% reduction). - Arrow bytes after decompression: 452.4 KiB → 61.8 KiB (86.3% reduction). - to_arrow() wall time: 3.84s → 1.66s (57% reduction). A previous version of this PR additionally introduced an ``augment_filter_with_partition_ranges`` helper that AND'd ``[min, max]`` predicates on partition source columns onto the row filter. Review surfaced that for the safe-shape case (partition source columns ⊆ ``join_cols``), ``create_match_filter``'s per-disjunct projection already enables manifest pruning, and ``inclusive_projection`` walks the full N-disjunct tree regardless of the augmentation — so it added work without substituting for anything. It's been dropped from this PR. Related: #2138, #2159, #3129. Co-authored-by: Cursor <[email protected]>
1 parent d339391 commit b81990d

2 files changed

Lines changed: 183 additions & 1 deletion

File tree

pyiceberg/table/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,12 +877,17 @@ def upsert(
877877
# get list of rows that exist so we don't have to load the entire target table
878878
matched_predicate = upsert_util.create_match_filter(df, join_cols)
879879

880+
# When ``when_matched_update_all=False`` the consumer loop below
881+
# only ever reads ``join_cols`` off each destination batch.
882+
selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols)
883+
880884
# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
881885

882886
matched_iceberg_record_batches_scan = DataScan(
883887
table_metadata=self.table_metadata,
884888
io=self._table.io,
885889
row_filter=matched_predicate,
890+
selected_fields=selected_fields,
886891
case_sensitive=case_sensitive,
887892
)
888893

tests/table/test_upsert.py

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import datetime
1718
from pathlib import PosixPath
19+
from typing import Any
1820

1921
import pyarrow as pa
2022
import pytest
@@ -26,11 +28,13 @@
2628
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
2729
from pyiceberg.expressions.literals import LongLiteral
2830
from pyiceberg.io.pyarrow import schema_to_pyarrow
31+
from pyiceberg.partitioning import PartitionField, PartitionSpec
2932
from pyiceberg.schema import Schema
3033
from pyiceberg.table import Table, UpsertResult
3134
from pyiceberg.table.snapshots import Operation
3235
from pyiceberg.table.upsert_util import create_match_filter
33-
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
36+
from pyiceberg.transforms import IdentityTransform
37+
from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType
3438
from tests.catalog.test_base import InMemoryCatalog
3539

3640

@@ -888,3 +892,176 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None:
888892
for snapshot in snapshots[initial_snapshot_count:]:
889893
assert snapshot.summary is not None
890894
assert snapshot.summary.additional_properties.get("test_prop") == "test_value"
895+
896+
897+
class TestUpsertScanProjection:
898+
"""``Transaction.upsert`` narrows the destination scan's
899+
``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``.
900+
901+
Rationale: the insert-on-no-match branch only reads ``join_cols``
902+
off each destination batch (to feed ``create_match_filter``); every
903+
other column is unused. Projection at the scan boundary lets the
904+
parquet reader prune wide non-key columns at the file level —
905+
significant for tables whose payload column (e.g. a JSON ``log``)
906+
dominates file bytes. ``_projected_field_ids`` auto-unions the
907+
row-filter's column ids back in, so any column referenced by the
908+
join-key predicate is still readable for filter evaluation without
909+
needing to list it explicitly.
910+
911+
Falls back to ``("*",)`` when ``when_matched_update_all=True``
912+
because ``get_rows_to_update`` reads every non-key column off the
913+
destination row to detect value drift — narrowing would break the
914+
no-op-write skip.
915+
"""
916+
917+
@staticmethod
918+
def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table:
919+
_drop_table(catalog, identifier)
920+
schema = Schema(
921+
NestedField(1, "order_id", IntegerType(), required=True),
922+
NestedField(2, "order_date", DateType(), required=True),
923+
NestedField(3, "order_type", StringType(), required=True),
924+
)
925+
spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date"))
926+
return catalog.create_table(identifier, schema=schema, partition_spec=spec)
927+
928+
@staticmethod
929+
def _arrow_schema() -> pa.Schema:
930+
return pa.schema(
931+
[
932+
pa.field("order_id", pa.int32(), nullable=False),
933+
pa.field("order_date", pa.date32(), nullable=False),
934+
pa.field("order_type", pa.string(), nullable=False),
935+
]
936+
)
937+
938+
def _seed(self, table: Table) -> None:
939+
table.append(
940+
pa.Table.from_pylist(
941+
[
942+
{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"},
943+
{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"},
944+
],
945+
schema=self._arrow_schema(),
946+
)
947+
)
948+
949+
@pytest.fixture
950+
def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
951+
"""Spy on ``DataScan.__init__`` to capture every kwargs dict.
952+
953+
Lets the tests pin which ``selected_fields`` the upsert path
954+
actually passes — assertions on the surfaced batch schema alone
955+
would miss the case where the underlying projection contract
956+
regresses but the test data happens to have only join_cols
957+
anyway.
958+
959+
The spy preserves ``__init__``'s signature via
960+
:func:`functools.wraps` so ``DataScan.update()``'s reflective
961+
``inspect.signature(type(self).__init__).parameters`` lookup
962+
(used by ``use_ref``) still resolves to the real parameter
963+
names, not the spy's ``**kwargs``.
964+
"""
965+
import functools
966+
967+
from pyiceberg.table import DataScan
968+
969+
captured: list[dict[str, Any]] = []
970+
original_init = DataScan.__init__
971+
972+
@functools.wraps(original_init)
973+
def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
974+
captured.append(dict(kwargs))
975+
original_init(self, *args, **kwargs)
976+
977+
monkeypatch.setattr(DataScan, "__init__", _spy)
978+
return captured
979+
980+
def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
981+
"""The insert-on-no-match branch never reads non-key destination
982+
columns, so the scan must narrow the projection to ``join_cols``
983+
— saving the parquet reader from materialising wide payload
984+
columns just to be discarded."""
985+
table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only")
986+
self._seed(table)
987+
upsert_df = pa.Table.from_pylist(
988+
[
989+
{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"},
990+
{"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
991+
],
992+
schema=self._arrow_schema(),
993+
)
994+
995+
# Snapshot only the scans constructed during the upsert (the
996+
# seed append above may have created its own).
997+
before = len(captured_scans)
998+
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False)
999+
upsert_scans = captured_scans[before:]
1000+
assert res.rows_inserted == 1
1001+
assert res.rows_updated == 0
1002+
1003+
# The upsert constructs one DataScan for the destination match.
1004+
# ``use_ref`` may construct a second DataScan as an inherited
1005+
# copy (via ``self.update``), which carries the same
1006+
# ``selected_fields`` through. Pin both: at least one scan was
1007+
# constructed during the upsert, and every scan that ran
1008+
# carries the narrowed projection.
1009+
assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
1010+
selected = [s.get("selected_fields") for s in upsert_scans]
1011+
assert all(sf == ("order_id",) for sf in selected), (
1012+
f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}"
1013+
)
1014+
1015+
def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
1016+
"""The update branch's ``get_rows_to_update`` compares non-key
1017+
columns to detect actual value changes — projecting only
1018+
``join_cols`` would feed it data with no non-key columns to
1019+
compare and silently turn every match into a write-back. Must
1020+
keep ``("*",)``."""
1021+
table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode")
1022+
self._seed(table)
1023+
upsert_df = pa.Table.from_pylist(
1024+
[
1025+
{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"},
1026+
{"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
1027+
],
1028+
schema=self._arrow_schema(),
1029+
)
1030+
1031+
before = len(captured_scans)
1032+
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
1033+
upsert_scans = captured_scans[before:]
1034+
assert res.rows_updated == 1
1035+
assert res.rows_inserted == 1
1036+
1037+
assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
1038+
selected = [s.get("selected_fields") for s in upsert_scans]
1039+
assert all(sf == ("*",) for sf in selected), (
1040+
f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}"
1041+
)
1042+
1043+
def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None:
1044+
"""End-to-end correctness pin: with ``when_matched_update_all=True``
1045+
the destination scan must read non-key columns so
1046+
``get_rows_to_update`` can detect ``order_type`` changes. A
1047+
regression that narrows projection unconditionally would skip
1048+
the comparison and silently miss updates whose non-key columns
1049+
differ.
1050+
"""
1051+
identifier = "default.test_upsert_update_mode_correctness"
1052+
table = self._build_partitioned_table(catalog, identifier)
1053+
self._seed(table)
1054+
# Source has the same (order_id, order_date) as one destination
1055+
# row but a different ``order_type``. Update path must detect
1056+
# the non-key change and overwrite.
1057+
upsert_df = pa.Table.from_pylist(
1058+
[{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}],
1059+
schema=self._arrow_schema(),
1060+
)
1061+
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
1062+
assert res.rows_updated == 1
1063+
assert res.rows_inserted == 0
1064+
1065+
# Read back: the original 'A' must have been overwritten with 'CHANGED'.
1066+
rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()}
1067+
assert rows[2]["order_type"] == "CHANGED"

0 commit comments

Comments
 (0)