Skip to content

Commit 4bc975b

Browse files
committed
add UpdateSchema and fix integration test
1 parent f25cc42 commit 4bc975b

File tree

7 files changed

+523
-38
lines changed

7 files changed

+523
-38
lines changed

pyiceberg/table/update/schema.py

Lines changed: 125 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
9090
_allow_incompatible_changes: bool
9191
_case_sensitive: bool
9292

93+
# Store user intent for retry support
94+
_column_additions: list[tuple[str | tuple[str, ...], IcebergType, str | None, bool, L | None]]
95+
_column_updates: list[tuple[str | tuple[str, ...], IcebergType | None, bool | None, str | None]]
96+
_column_deletions: list[str | tuple[str, ...]]
97+
_column_renames: list[tuple[str | tuple[str, ...], str]]
98+
_move_operations: list[tuple[str, str | tuple[str, ...], str | tuple[str, ...] | None]]
99+
_optional_columns: list[str | tuple[str, ...]]
100+
_default_value_updates: list[tuple[str | tuple[str, ...], L | None]]
101+
93102
def __init__(
94103
self,
95104
transaction: Transaction,
@@ -99,22 +108,40 @@ def __init__(
99108
name_mapping: NameMapping | None = None,
100109
) -> None:
101110
super().__init__(transaction)
102-
111+
self._transaction = transaction
112+
self._allow_incompatible_changes = allow_incompatible_changes
113+
self._case_sensitive = case_sensitive
114+
self._name_mapping = name_mapping
115+
self._provided_schema = schema # Store for _reset_state
116+
117+
# Initialize user intent storage
118+
self._column_additions = []
119+
self._column_updates = []
120+
self._column_deletions = []
121+
self._column_renames = []
122+
self._move_operations = []
123+
self._optional_columns = []
124+
self._default_value_updates = []
125+
self._identifier_field_updates: set[str] | None = None
126+
127+
# Initialize state from metadata
128+
self._init_state_from_metadata(schema)
129+
130+
def _init_state_from_metadata(self, schema: Schema | None = None) -> None:
131+
"""Initialize or reinitialize state from current transaction metadata."""
103132
if isinstance(schema, Schema):
104133
self._schema = schema
105134
self._last_column_id = itertools.count(1 + schema.highest_field_id)
106135
else:
107136
self._schema = self._transaction.table_metadata.schema()
108137
self._last_column_id = itertools.count(1 + self._transaction.table_metadata.last_column_id)
109138

110-
self._name_mapping = name_mapping
111139
self._identifier_field_names = self._schema.identifier_field_names()
112140

113141
self._adds = {}
114142
self._updates = {}
115143
self._deletes = set()
116144
self._moves = {}
117-
118145
self._added_name_to_id = {}
119146

120147
def get_column_name(field_id: int) -> str:
@@ -127,9 +154,39 @@ def get_column_name(field_id: int) -> str:
127154
field_id: get_column_name(parent_field_id) for field_id, parent_field_id in self._schema._lazy_id_to_parent.items()
128155
}
129156

130-
self._allow_incompatible_changes = allow_incompatible_changes
131-
self._case_sensitive = case_sensitive
132-
self._transaction = transaction
157+
def _reset_state(self) -> None:
158+
"""Reset state for retry, rebuilding from refreshed metadata."""
159+
self._init_state_from_metadata(self._provided_schema)
160+
161+
for path, field_type, doc, required, default_value in self._column_additions:
162+
self._do_add_column(path, field_type, doc, required, default_value)
163+
164+
for path in self._column_deletions:
165+
self._do_delete_column(path)
166+
167+
for path_from, new_name in self._column_renames:
168+
self._do_rename_column(path_from, new_name)
169+
170+
for path, field_type, required, doc in self._column_updates:
171+
self._do_update_column(path, field_type, required, doc)
172+
173+
for path in self._optional_columns:
174+
self._set_column_requirement(path, required=False)
175+
176+
for path, default_value in self._default_value_updates:
177+
self._set_column_default_value(path, default_value)
178+
179+
for op, path, other_path in self._move_operations:
180+
if op == "first":
181+
self._do_move_first(path)
182+
elif op == "before":
183+
self._do_move_before(path, other_path) # type: ignore
184+
elif op == "after":
185+
self._do_move_after(path, other_path) # type: ignore
186+
187+
# Restore identifier fields if they were explicitly set
188+
if self._identifier_field_updates is not None:
189+
self._identifier_field_names = self._identifier_field_updates.copy()
133190

134191
def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
135192
"""Determine if the case of schema needs to be considered when comparing column names.
@@ -186,6 +243,19 @@ def add_column(
186243
Returns:
187244
This for method chaining.
188245
"""
246+
self._column_additions.append((path, field_type, doc, required, default_value))
247+
self._do_add_column(path, field_type, doc, required, default_value)
248+
return self
249+
250+
def _do_add_column(
251+
self,
252+
path: str | tuple[str, ...],
253+
field_type: IcebergType,
254+
doc: str | None,
255+
required: bool,
256+
default_value: L | None,
257+
) -> None:
258+
"""Internal method to add a column. Used by add_column and _reset_state."""
189259
if isinstance(path, str):
190260
if "." in path:
191261
raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead")
@@ -256,8 +326,6 @@ def add_column(
256326
else:
257327
self._adds[parent_id] = [field]
258328

259-
return self
260-
261329
def delete_column(self, path: str | tuple[str, ...]) -> UpdateSchema:
262330
"""Delete a column from a table.
263331
@@ -267,6 +335,12 @@ def delete_column(self, path: str | tuple[str, ...]) -> UpdateSchema:
267335
Returns:
268336
The UpdateSchema with the delete operation staged.
269337
"""
338+
self._column_deletions.append(path)
339+
self._do_delete_column(path)
340+
return self
341+
342+
def _do_delete_column(self, path: str | tuple[str, ...]) -> None:
343+
"""Internal method to delete a column. Used by delete_column and _reset_state."""
270344
name = (path,) if isinstance(path, str) else path
271345
full_name = ".".join(name)
272346

@@ -279,8 +353,6 @@ def delete_column(self, path: str | tuple[str, ...]) -> UpdateSchema:
279353

280354
self._deletes.add(field.field_id)
281355

282-
return self
283-
284356
def set_default_value(self, path: str | tuple[str, ...], default_value: L | None) -> UpdateSchema:
285357
"""Set the default value of a column.
286358
@@ -290,8 +362,8 @@ def set_default_value(self, path: str | tuple[str, ...], default_value: L | None
290362
Returns:
291363
The UpdateSchema with the delete operation staged.
292364
"""
365+
self._default_value_updates.append((path, default_value))
293366
self._set_column_default_value(path, default_value)
294-
295367
return self
296368

297369
def rename_column(self, path_from: str | tuple[str, ...], new_name: str) -> UpdateSchema:
@@ -304,6 +376,12 @@ def rename_column(self, path_from: str | tuple[str, ...], new_name: str) -> Upda
304376
Returns:
305377
The UpdateSchema with the rename operation staged.
306378
"""
379+
self._column_renames.append((path_from, new_name))
380+
self._do_rename_column(path_from, new_name)
381+
return self
382+
383+
def _do_rename_column(self, path_from: str | tuple[str, ...], new_name: str) -> None:
384+
"""Internal method to rename a column. Used by rename_column and _reset_state."""
307385
path_from = ".".join(path_from) if isinstance(path_from, tuple) else path_from
308386
field_from = self._schema.find_field(path_from, self._case_sensitive)
309387

@@ -338,8 +416,6 @@ def rename_column(self, path_from: str | tuple[str, ...], new_name: str) -> Upda
338416
new_identifier_path = f"{from_field_correct_casing[: -len(field_from.name)]}{new_name}"
339417
self._identifier_field_names.add(new_identifier_path)
340418

341-
return self
342-
343419
def make_column_optional(self, path: str | tuple[str, ...]) -> UpdateSchema:
344420
"""Make a column optional.
345421
@@ -349,10 +425,12 @@ def make_column_optional(self, path: str | tuple[str, ...]) -> UpdateSchema:
349425
Returns:
350426
The UpdateSchema with the requirement change staged.
351427
"""
428+
self._optional_columns.append(path)
352429
self._set_column_requirement(path, required=False)
353430
return self
354431

355432
def set_identifier_fields(self, *fields: str) -> None:
433+
self._identifier_field_updates = set(fields)
356434
self._identifier_field_names = set(fields)
357435

358436
def _set_column_requirement(self, path: str | tuple[str, ...], required: bool) -> None:
@@ -454,12 +532,25 @@ def update_column(
454532
Returns:
455533
The UpdateSchema with the type update staged.
456534
"""
457-
path = (path,) if isinstance(path, str) else path
458-
full_name = ".".join(path)
459-
460535
if field_type is None and required is None and doc is None:
461536
return self
462537

538+
# Store intent for retry support
539+
self._column_updates.append((path, field_type, required, doc))
540+
self._do_update_column(path, field_type, required, doc)
541+
return self
542+
543+
def _do_update_column(
544+
self,
545+
path: str | tuple[str, ...],
546+
field_type: IcebergType | None,
547+
required: bool | None,
548+
doc: str | None,
549+
) -> None:
550+
"""Internal method to update a column. Used by update_column and _reset_state."""
551+
path = (path,) if isinstance(path, str) else path
552+
full_name = ".".join(path)
553+
463554
field = self._schema.find_field(full_name, self._case_sensitive)
464555

465556
if field.field_id in self._deletes:
@@ -500,8 +591,6 @@ def update_column(
500591
if required is not None:
501592
self._set_column_requirement(path, required=required)
502593

503-
return self
504-
505594
def _find_for_move(self, name: str) -> int | None:
506595
try:
507596
return self._schema.find_field(name, self._case_sensitive).field_id
@@ -544,6 +633,12 @@ def move_first(self, path: str | tuple[str, ...]) -> UpdateSchema:
544633
Returns:
545634
The UpdateSchema with the move operation staged.
546635
"""
636+
self._move_operations.append(("first", path, None))
637+
self._do_move_first(path)
638+
return self
639+
640+
def _do_move_first(self, path: str | tuple[str, ...]) -> None:
641+
"""Internal method to move a field to first position. Used by move_first and _reset_state."""
547642
full_name = ".".join(path) if isinstance(path, tuple) else path
548643

549644
field_id = self._find_for_move(full_name)
@@ -553,8 +648,6 @@ def move_first(self, path: str | tuple[str, ...]) -> UpdateSchema:
553648

554649
self._move(_Move(field_id=field_id, full_name=full_name, op=_MoveOperation.First))
555650

556-
return self
557-
558651
def move_before(self, path: str | tuple[str, ...], before_path: str | tuple[str, ...]) -> UpdateSchema:
559652
"""Move the field to before another field.
560653
@@ -564,6 +657,12 @@ def move_before(self, path: str | tuple[str, ...], before_path: str | tuple[str,
564657
Returns:
565658
The UpdateSchema with the move operation staged.
566659
"""
660+
self._move_operations.append(("before", path, before_path))
661+
self._do_move_before(path, before_path)
662+
return self
663+
664+
def _do_move_before(self, path: str | tuple[str, ...], before_path: str | tuple[str, ...]) -> None:
665+
"""Internal method to move a field before another. Used by move_before and _reset_state."""
567666
full_name = ".".join(path) if isinstance(path, tuple) else path
568667
field_id = self._find_for_move(full_name)
569668

@@ -587,8 +686,6 @@ def move_before(self, path: str | tuple[str, ...], before_path: str | tuple[str,
587686

588687
self._move(_Move(field_id=field_id, full_name=full_name, other_field_id=before_field_id, op=_MoveOperation.Before))
589688

590-
return self
591-
592689
def move_after(self, path: str | tuple[str, ...], after_name: str | tuple[str, ...]) -> UpdateSchema:
593690
"""Move the field to after another field.
594691
@@ -598,6 +695,12 @@ def move_after(self, path: str | tuple[str, ...], after_name: str | tuple[str, .
598695
Returns:
599696
The UpdateSchema with the move operation staged.
600697
"""
698+
self._move_operations.append(("after", path, after_name))
699+
self._do_move_after(path, after_name)
700+
return self
701+
702+
def _do_move_after(self, path: str | tuple[str, ...], after_name: str | tuple[str, ...]) -> None:
703+
"""Internal method to move a field after another. Used by move_after and _reset_state."""
601704
full_name = ".".join(path) if isinstance(path, tuple) else path
602705

603706
field_id = self._find_for_move(full_name)
@@ -616,8 +719,6 @@ def move_after(self, path: str | tuple[str, ...], after_name: str | tuple[str, .
616719

617720
self._move(_Move(field_id=field_id, full_name=full_name, other_field_id=after_field_id, op=_MoveOperation.After))
618721

619-
return self
620-
621722
def _commit(self) -> UpdatesAndRequirements:
622723
"""Apply the pending changes and commit."""
623724
from pyiceberg.table import TableProperties

tests/catalog/test_sql.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,15 @@ def test_append_table(catalog: SqlCatalog, table_schema_simple: Schema, table_id
14231423
],
14241424
)
14251425
def test_concurrent_commit_table(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None:
1426+
"""Test that concurrent schema updates are resolved via retry (Java-like behavior).
1427+
1428+
When two writers concurrently update the schema:
1429+
1. First writer succeeds
1430+
2. Second writer's first attempt fails (stale schema_id)
1431+
3. Second writer retries with refreshed metadata and succeeds
1432+
1433+
This matches Java's BaseTransaction.applyUpdates() behavior.
1434+
"""
14261435
namespace = Catalog.namespace_from(table_identifier)
14271436
catalog.create_namespace(namespace)
14281437
table_a = catalog.create_table(table_identifier, table_schema_simple)
@@ -1431,10 +1440,17 @@ def test_concurrent_commit_table(catalog: SqlCatalog, table_schema_simple: Schem
14311440
with table_a.update_schema() as update:
14321441
update.add_column(path="b", field_type=IntegerType())
14331442

1434-
with pytest.raises(CommitFailedException, match="Requirement failed: current schema id has changed: expected 0, found 1"):
1435-
# This one should fail since it already has been updated
1436-
with table_b.update_schema() as update:
1437-
update.add_column(path="c", field_type=IntegerType())
1443+
# With retry support, this now succeeds after refreshing metadata
1444+
# (Java-like behavior: retry resolves conflicts)
1445+
with table_b.update_schema() as update:
1446+
update.add_column(path="c", field_type=IntegerType())
1447+
1448+
# Verify both columns were added
1449+
table_a.refresh()
1450+
field_names = [f.name for f in table_a.schema().fields]
1451+
assert "b" in field_names
1452+
assert "c" in field_names
1453+
assert table_a.schema().schema_id == 2
14381454

14391455

14401456
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)