Skip to content

Commit 7c4ecf8

Browse files
committed
add UpdateSortOrder
1 parent 8d3497f commit 7c4ecf8

File tree

6 files changed

+251
-37
lines changed

6 files changed

+251
-37
lines changed

pyiceberg/table/__init__.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ def __init__(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRe
282282
self._updates = updates
283283
self._requirements = requirements
284284

285+
def _reset_state(self) -> None:
286+
"""No-op for static updates that don't cache metadata-derived state."""
287+
285288
def _commit(self) -> UpdatesAndRequirements:
286289
"""Return the stored updates and requirements."""
287290
return self._updates, self._requirements
@@ -1090,14 +1093,8 @@ def _reapply_updates(self) -> None:
10901093
This is called on retry to regenerate all pending updates
10911094
based on the latest table metadata, similar to Java's BaseTransaction.applyUpdates().
10921095
1093-
All updates are rebuilt from _pending_updates to ensure consistency.
1094-
This includes:
1095-
- Static updates (properties, format version) via _StaticUpdate
1096-
- Snapshot operations via snapshot producers with _reset_state()
1097-
1098-
NOTE: Every operation that should survive retry must be tracked in _pending_updates.
1099-
Simple operations use _StaticUpdate wrapper, complex operations (like snapshot
1100-
producers) implement _reset_state() and _commit() directly.
1096+
NOTE: When adding new cached properties to UpdateTableMetadata subclasses,
1097+
ensure they are cleared in _reset_state() to avoid stale data on retry.
11011098
"""
11021099
self._table.refresh()
11031100

@@ -1106,11 +1103,7 @@ def _reapply_updates(self) -> None:
11061103
self._working_metadata = self._table.metadata
11071104

11081105
for pending_update in self._pending_updates:
1109-
# NOTE: When adding new cached properties to snapshot producers,
1110-
# ensure they are cleared in _reset_state() to avoid stale data on retry
1111-
if hasattr(pending_update, "_reset_state"):
1112-
pending_update._reset_state()
1113-
1106+
pending_update._reset_state()
11141107
updates, requirements = pending_update._commit()
11151108
self._updates += updates
11161109
self._working_metadata = update_table_metadata(self._working_metadata, updates)

pyiceberg/table/update/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,19 @@ def __init__(self, transaction: Transaction) -> None:
6666
self._transaction = transaction
6767

6868
@abstractmethod
69-
def _commit(self) -> UpdatesAndRequirements: ...
69+
def _commit(self) -> UpdatesAndRequirements:
70+
"""Generate the table updates and requirements for this operation."""
71+
...
72+
73+
@abstractmethod
74+
def _reset_state(self) -> None:
75+
"""Reset internal state for retry after table metadata refresh.
76+
77+
This is called by Transaction._reapply_updates() when retrying after a
78+
CommitFailedException. Implementations should rebuild any cached state
79+
from self._transaction.table_metadata.
80+
"""
81+
...
7082

7183
def commit(self) -> None:
7284
updates, requirements = self._commit()

pyiceberg/table/update/snapshot.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,9 @@ def __init__(self, transaction: Transaction) -> None:
972972
self._updates = ()
973973
self._requirements = ()
974974

975+
def _reset_state(self) -> None:
976+
"""No-op: updates contain user-provided snapshot IDs that don't need refresh."""
977+
975978
def _commit(self) -> UpdatesAndRequirements:
976979
"""Apply the pending changes and commit."""
977980
return self._updates, self._requirements
@@ -1093,6 +1096,15 @@ def __init__(self, transaction: Transaction) -> None:
10931096
self._requirements = ()
10941097
self._snapshot_ids_to_expire = set()
10951098

1099+
def _reset_state(self) -> None:
1100+
"""Clear accumulated updates for retry.
1101+
1102+
The _snapshot_ids_to_expire are user-provided and preserved.
1103+
The _updates are cleared so _commit() can rebuild them with
1104+
refreshed protected snapshot IDs.
1105+
"""
1106+
self._updates = ()
1107+
10961108
def _commit(self) -> UpdatesAndRequirements:
10971109
"""
10981110
Commit the staged updates and requirements.

pyiceberg/table/update/sorting.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,22 @@ class UpdateSortOrder(UpdateTableMetadata["UpdateSortOrder"]):
3939
_last_assigned_order_id: int | None
4040
_case_sensitive: bool
4141
_fields: list[SortField]
42+
# Store user intent for retry support: (column_name, transform, direction, null_order)
43+
_field_additions: list[tuple[str, Transform[Any, Any], SortDirection, NullOrder]]
4244

4345
def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None:
4446
super().__init__(transaction)
45-
self._fields: list[SortField] = []
46-
self._case_sensitive: bool = case_sensitive
47-
self._last_assigned_order_id: int | None = None
47+
self._fields = []
48+
self._case_sensitive = case_sensitive
49+
self._last_assigned_order_id = None
50+
self._field_additions = []
51+
52+
def _reset_state(self) -> None:
53+
"""Reset state for retry, re-resolving column names from refreshed metadata."""
54+
self._fields = []
55+
self._last_assigned_order_id = None
56+
for column_name, transform, direction, null_order in self._field_additions:
57+
self._do_add_sort_field(column_name, transform, direction, null_order)
4858

4959
def _column_name_to_id(self, column_name: str) -> int:
5060
"""Map the column name to the column field id."""
@@ -57,23 +67,22 @@ def _column_name_to_id(self, column_name: str) -> int:
5767
.field_id
5868
)
5969

60-
def _add_sort_field(
70+
def _do_add_sort_field(
6171
self,
62-
source_id: int,
72+
source_column_name: str,
6373
transform: Transform[Any, Any],
6474
direction: SortDirection,
6575
null_order: NullOrder,
66-
) -> UpdateSortOrder:
67-
"""Add a sort field to the sort order list."""
76+
) -> None:
77+
"""Add a sort field to the sort order list (internal implementation for retry support)."""
6878
self._fields.append(
6979
SortField(
70-
source_id=source_id,
80+
source_id=self._column_name_to_id(source_column_name),
7181
transform=transform,
7282
direction=direction,
7383
null_order=null_order,
7484
)
7585
)
76-
return self
7786

7887
def _reuse_or_create_sort_order_id(self) -> int:
7988
"""Return the last assigned sort order id or create a new one."""
@@ -90,23 +99,17 @@ def asc(
9099
self, source_column_name: str, transform: Transform[Any, Any], null_order: NullOrder = NullOrder.NULLS_LAST
91100
) -> UpdateSortOrder:
92101
"""Add a sort field with ascending order."""
93-
return self._add_sort_field(
94-
source_id=self._column_name_to_id(source_column_name),
95-
transform=transform,
96-
direction=SortDirection.ASC,
97-
null_order=null_order,
98-
)
102+
self._field_additions.append((source_column_name, transform, SortDirection.ASC, null_order))
103+
self._do_add_sort_field(source_column_name, transform, SortDirection.ASC, null_order)
104+
return self
99105

100106
def desc(
101107
self, source_column_name: str, transform: Transform[Any, Any], null_order: NullOrder = NullOrder.NULLS_LAST
102108
) -> UpdateSortOrder:
103109
"""Add a sort field with descending order."""
104-
return self._add_sort_field(
105-
source_id=self._column_name_to_id(source_column_name),
106-
transform=transform,
107-
direction=SortDirection.DESC,
108-
null_order=null_order,
109-
)
110+
self._field_additions.append((source_column_name, transform, SortDirection.DESC, null_order))
111+
self._do_add_sort_field(source_column_name, transform, SortDirection.DESC, null_order)
112+
return self
110113

111114
def _apply(self) -> SortOrder:
112115
"""Return the sort order."""

pyiceberg/table/update/statistics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class UpdateStatistics(UpdateTableMetadata["UpdateStatistics"]):
5252
def __init__(self, transaction: "Transaction") -> None:
5353
super().__init__(transaction)
5454

55+
def _reset_state(self) -> None:
56+
"""No-op: updates contain user-provided data that doesn't need refresh."""
57+
5558
def set_statistics(self, statistics_file: StatisticsFile) -> "UpdateStatistics":
5659
self._updates += (
5760
SetStatisticsUpdate(

tests/table/test_commit_retry.py

Lines changed: 193 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,6 @@ def mock_commit(
11261126

11271127

11281128
class TestUpdateSchemaRetry:
1129-
"""Tests for UpdateSchema retry behavior (Java-like behavior)."""
1130-
11311129
def test_update_schema_retried_on_conflict(self, catalog: SqlCatalog, schema: Schema) -> None:
11321130
"""Test that UpdateSchema operations are retried on CommitFailedException."""
11331131
from pyiceberg.types import StringType
@@ -1273,3 +1271,196 @@ def mock_commit(
12731271
assert table.schema().schema_id == 1
12741272
assert len(table.schema().fields) == 2
12751273
assert table.schema().find_field("new_col").field_type == StringType()
1274+
1275+
1276+
class TestUpdateSortOrderRetry:
1277+
def test_update_sort_order_retried_on_conflict(self, catalog: SqlCatalog, schema: Schema) -> None:
1278+
"""Test that UpdateSortOrder operations are retried on CommitFailedException."""
1279+
from pyiceberg.transforms import IdentityTransform
1280+
1281+
table = catalog.create_table(
1282+
"default.test_sort_order_retry",
1283+
schema=schema,
1284+
properties={
1285+
TableProperties.COMMIT_NUM_RETRIES: "3",
1286+
TableProperties.COMMIT_MIN_RETRY_WAIT_MS: "1",
1287+
TableProperties.COMMIT_MAX_RETRY_WAIT_MS: "10",
1288+
},
1289+
)
1290+
1291+
original_commit = catalog.commit_table
1292+
commit_count = 0
1293+
1294+
def mock_commit(
1295+
tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
1296+
) -> CommitTableResponse:
1297+
nonlocal commit_count
1298+
commit_count += 1
1299+
if commit_count == 1:
1300+
raise CommitFailedException("Simulated sort order conflict")
1301+
return original_commit(tbl, requirements, updates)
1302+
1303+
with patch.object(catalog, "commit_table", side_effect=mock_commit):
1304+
with table.update_sort_order() as update_sort_order:
1305+
update_sort_order.asc("id", IdentityTransform())
1306+
1307+
assert commit_count == 2
1308+
1309+
# Verify sort order was updated
1310+
table.refresh()
1311+
sort_order = table.sort_order()
1312+
assert sort_order.order_id == 1
1313+
assert len(sort_order.fields) == 1
1314+
assert sort_order.fields[0].source_id == 1 # "id" column
1315+
1316+
def test_update_sort_order_resolves_conflict_on_retry(self, catalog: SqlCatalog, schema: Schema) -> None:
1317+
"""Test that sort order update can resolve conflicts via retry."""
1318+
from pyiceberg.table.sorting import SortDirection
1319+
from pyiceberg.transforms import IdentityTransform
1320+
1321+
table = catalog.create_table(
1322+
"default.test_sort_order_conflict_resolved",
1323+
schema=schema,
1324+
properties={
1325+
TableProperties.COMMIT_NUM_RETRIES: "5",
1326+
TableProperties.COMMIT_MIN_RETRY_WAIT_MS: "1",
1327+
TableProperties.COMMIT_MAX_RETRY_WAIT_MS: "10",
1328+
},
1329+
)
1330+
1331+
with table.update_sort_order() as update_sort_order:
1332+
update_sort_order.asc("id", IdentityTransform())
1333+
1334+
table2 = catalog.load_table("default.test_sort_order_conflict_resolved")
1335+
with table2.update_sort_order() as update_sort_order2:
1336+
update_sort_order2.desc("id", IdentityTransform())
1337+
1338+
assert table.sort_order().order_id == 1
1339+
assert table2.sort_order().order_id == 2
1340+
1341+
original_commit = catalog.commit_table
1342+
commit_count = 0
1343+
1344+
def mock_commit(
1345+
tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
1346+
) -> CommitTableResponse:
1347+
nonlocal commit_count
1348+
commit_count += 1
1349+
return original_commit(tbl, requirements, updates)
1350+
1351+
with patch.object(catalog, "commit_table", side_effect=mock_commit):
1352+
with table.update_sort_order() as update_sort_order:
1353+
update_sort_order.asc("id", IdentityTransform())
1354+
1355+
assert commit_count == 2
1356+
1357+
table.refresh()
1358+
sort_order = table.sort_order()
1359+
assert sort_order.order_id == 1 # Reused existing order with same fields
1360+
assert len(sort_order.fields) == 1
1361+
assert sort_order.fields[0].direction == SortDirection.ASC
1362+
1363+
def test_transaction_with_sort_order_change_and_append_retries(
1364+
self, catalog: SqlCatalog, schema: Schema, arrow_table: pa.Table
1365+
) -> None:
1366+
"""Test that a transaction with sort order change and append handles retry correctly."""
1367+
from pyiceberg.transforms import IdentityTransform
1368+
1369+
table = catalog.create_table(
1370+
"default.test_transaction_sort_order_and_append",
1371+
schema=schema,
1372+
properties={
1373+
TableProperties.COMMIT_NUM_RETRIES: "3",
1374+
TableProperties.COMMIT_MIN_RETRY_WAIT_MS: "1",
1375+
TableProperties.COMMIT_MAX_RETRY_WAIT_MS: "10",
1376+
},
1377+
)
1378+
1379+
original_commit = catalog.commit_table
1380+
commit_count = 0
1381+
captured_updates: list[tuple[TableUpdate, ...]] = []
1382+
1383+
def mock_commit(
1384+
tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
1385+
) -> CommitTableResponse:
1386+
nonlocal commit_count
1387+
commit_count += 1
1388+
captured_updates.append(updates)
1389+
if commit_count == 1:
1390+
raise CommitFailedException("Simulated conflict")
1391+
return original_commit(tbl, requirements, updates)
1392+
1393+
with patch.object(catalog, "commit_table", side_effect=mock_commit):
1394+
with table.transaction() as txn:
1395+
with txn.update_sort_order() as update_sort_order:
1396+
update_sort_order.asc("id", IdentityTransform())
1397+
txn.append(arrow_table)
1398+
1399+
assert commit_count == 2
1400+
1401+
first_attempt_update_types = [type(u).__name__ for u in captured_updates[0]]
1402+
assert "AddSortOrderUpdate" in first_attempt_update_types
1403+
assert "AddSnapshotUpdate" in first_attempt_update_types
1404+
1405+
retry_attempt_update_types = [type(u).__name__ for u in captured_updates[1]]
1406+
assert "AddSortOrderUpdate" in retry_attempt_update_types
1407+
assert "AddSnapshotUpdate" in retry_attempt_update_types
1408+
1409+
assert len(table.scan().to_arrow()) == 3
1410+
1411+
sort_order = table.sort_order()
1412+
assert sort_order.order_id == 1
1413+
assert len(sort_order.fields) == 1
1414+
assert sort_order.fields[0].source_id == 1 # "id" column
1415+
1416+
def test_sort_order_column_name_re_resolved_on_retry(self, catalog: SqlCatalog, schema: Schema) -> None:
1417+
"""Test that column names are re-resolved from refreshed schema on retry.
1418+
1419+
This ensures that if the schema changes between retries (e.g., column ID changes),
1420+
the sort order will use the correct field ID from the refreshed schema.
1421+
"""
1422+
from pyiceberg.transforms import IdentityTransform
1423+
1424+
table = catalog.create_table(
1425+
"default.test_sort_order_column_re_resolved",
1426+
schema=schema,
1427+
properties={
1428+
TableProperties.COMMIT_NUM_RETRIES: "3",
1429+
TableProperties.COMMIT_MIN_RETRY_WAIT_MS: "1",
1430+
TableProperties.COMMIT_MAX_RETRY_WAIT_MS: "10",
1431+
},
1432+
)
1433+
1434+
original_commit = catalog.commit_table
1435+
commit_count = 0
1436+
captured_sort_fields: list[list[int]] = []
1437+
1438+
def mock_commit(
1439+
tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
1440+
) -> CommitTableResponse:
1441+
nonlocal commit_count
1442+
commit_count += 1
1443+
1444+
# Extract sort field source IDs from updates
1445+
from pyiceberg.table.update import AddSortOrderUpdate
1446+
1447+
for update in updates:
1448+
if isinstance(update, AddSortOrderUpdate):
1449+
source_ids = [f.source_id for f in update.sort_order.fields]
1450+
captured_sort_fields.append(source_ids)
1451+
1452+
if commit_count == 1:
1453+
raise CommitFailedException("Simulated conflict")
1454+
return original_commit(tbl, requirements, updates)
1455+
1456+
with patch.object(catalog, "commit_table", side_effect=mock_commit):
1457+
with table.update_sort_order() as update_sort_order:
1458+
update_sort_order.asc("id", IdentityTransform())
1459+
1460+
assert commit_count == 2
1461+
assert len(captured_sort_fields) == 2
1462+
1463+
# Both attempts should resolve "id" to the same source_id (1)
1464+
# This verifies the column name is being re-resolved correctly
1465+
assert captured_sort_fields[0] == [1] # First attempt
1466+
assert captured_sort_fields[1] == [1] # Retry attempt

0 commit comments

Comments
 (0)