diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 29d09e2604..fa331772f8 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -185,6 +185,39 @@ with catalog.create_table_transaction(identifier="docs_example.bids", schema=sch txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c") ``` +## Replace a table + +Atomically replace an existing table's schema, partition spec, sort order, location, and properties. The table UUID and history (snapshots, schemas, specs, sort orders, metadata log) are preserved; the current snapshot is cleared (the `main` branch ref is removed). `replace_table` redefines the table in this way; `replace_table_transaction` lets you write new data alongside this change to permit RTAS (replace-table-as-select) workflows. + +```python +from pyiceberg.schema import Schema +from pyiceberg.types import NestedField, LongType, StringType, BooleanType + +new_schema = Schema( + NestedField(field_id=1, name="datetime", field_type=LongType(), required=False), + NestedField(field_id=2, name="symbol", field_type=StringType(), required=False), + NestedField(field_id=3, name="active", field_type=BooleanType(), required=False), +) +catalog.replace_table(identifier="docs_example.bids", schema=new_schema) +``` + +Field IDs are reused by name from the previous schema; new columns get fresh IDs above `last-column-id`. + +Unlike the other fields, table properties are *merged* on replace: properties you don't pass are preserved on the table. To remove a property as part of the replace, use `replace_table_transaction` and drop it explicitly within the transaction. 
def replace_table(
    self,
    identifier: str | Identifier,
    schema: Schema | pa.Schema,
    location: str | None = None,
    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
    sort_order: SortOrder = UNSORTED_SORT_ORDER,
    properties: Properties = EMPTY_DICT,
) -> Table:
    """Replace an existing table's definition in a single commit.

    Builds a replace transaction via `replace_table_transaction` and commits it
    immediately, without staging any further changes.

    Args:
        identifier (str | Identifier): Table identifier.
        schema (Schema): New table schema.
        location (str | None): New table location. The existing location is kept
            when this is omitted.
        partition_spec (PartitionSpec): New partition spec.
        sort_order (SortOrder): New sort order.
        properties (Properties): Properties to apply on top of the existing table
            properties. Keys given here override existing values; existing keys
            that are not given here are kept.

    Returns:
        Table: the table returned by committing the replace transaction.

    Raises:
        NoSuchTableError: If the table does not exist.
    """
    replace_txn = self.replace_table_transaction(
        identifier, schema, location, partition_spec, sort_order, properties
    )
    return replace_txn.commit_transaction()


@abstractmethod
def replace_table_transaction(
    self,
    identifier: str | Identifier,
    schema: Schema | pa.Schema,
    location: str | None = None,
    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
    sort_order: SortOrder = UNSORTED_SORT_ORDER,
    properties: Properties = EMPTY_DICT,
) -> ReplaceTableTransaction:
    """Open a transaction that replaces an existing table.

    Nothing is committed by this call; the caller may stage more changes on the
    returned transaction before committing it.

    Args:
        identifier (str | Identifier): Table identifier.
        schema (Schema): New table schema.
        location (str | None): New table location. The existing location is kept
            when this is omitted.
        partition_spec (PartitionSpec): New partition spec.
        sort_order (SortOrder): New sort order.
        properties (Properties): Properties to apply on top of the existing table
            properties. Keys given here override existing values; existing keys
            that are not given here are kept.

    Returns:
        ReplaceTableTransaction: A transaction for the replace operation.

    Raises:
        NoSuchTableError: If the table does not exist.
    """


def _replace_staged_table(
    self,
    identifier: str | Identifier,
    schema: Schema | pa.Schema,
    location: str | None,
    partition_spec: PartitionSpec,
    sort_order: SortOrder,
    properties: Properties,
) -> tuple[StagedTable, Schema, PartitionSpec, SortOrder, str]:
    """Load the current table and derive everything a replace transaction needs.

    Steps, in order:

    1. Load the table and read its metadata.
    2. Resolve the target format version from the `format-version` property
       (falling back to the table's current version) and reject downgrades.
    3. Convert the requested schema and check it against that format version.
    4. Re-number schema, partition-spec and sort-order IDs against the existing
       metadata via the `assign_fresh_*` helpers.
    5. Resolve the location (trailing slashes stripped; existing location when omitted).

    Returns:
        A tuple `(staged_table, fresh_schema, fresh_partition_spec, fresh_sort_order, resolved_location)`.
        The staged table still carries the *existing* metadata.

    Raises:
        ValueError: If `properties` requests a format version lower than the table's.
    """
    current_table = self.load_table(identifier)
    base_metadata = current_table.metadata

    version_property = properties.get(TableProperties.FORMAT_VERSION)
    if version_property is None:
        target_version = base_metadata.format_version
    else:
        target_version = int(version_property)
        if target_version < base_metadata.format_version:
            raise ValueError(
                f"Cannot downgrade format-version from {base_metadata.format_version} to {version_property}"
            )

    table_version = cast(TableVersion, target_version)
    converted_schema = self._convert_schema_if_needed(schema, table_version)
    converted_schema.check_format_version_compatibility(table_version)

    # The second element (new last-column-id) is not needed here.
    fresh_schema, _ = assign_fresh_schema_ids_for_replace(
        converted_schema, base_metadata.schema(), base_metadata.last_column_id
    )

    # NOTE(review): the helper is given the table's *existing* format version, not
    # `target_version`, so a v1 -> v2 upgrade in the same replace still takes the
    # v1 path of the helper. Confirm this is intended.
    fresh_partition_spec, _ = assign_fresh_partition_spec_ids_for_replace(
        partition_spec,
        converted_schema,
        fresh_schema,
        base_metadata.partition_specs,
        base_metadata.last_partition_id,
        format_version=base_metadata.format_version,
        current_spec=base_metadata.spec(),
    )

    fresh_sort_order = assign_fresh_sort_order_ids(sort_order, converted_schema, fresh_schema)

    if location:
        resolved_location = location.rstrip("/")
    else:
        resolved_location = base_metadata.location

    staged_table = StagedTable(
        identifier=current_table.name(),
        metadata=base_metadata,
        metadata_location=current_table.metadata_location,
        io=current_table.io,
        catalog=self,
    )
    return staged_table, fresh_schema, fresh_partition_spec, fresh_sort_order, resolved_location
@override
def replace_table_transaction(
    self,
    identifier: str | Identifier,
    schema: Schema | pa.Schema,
    location: str | None = None,
    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
    sort_order: SortOrder = UNSORTED_SORT_ORDER,
    properties: Properties = EMPTY_DICT,
) -> ReplaceTableTransaction:
    """Build a `ReplaceTableTransaction` for an existing table.

    The existing table is loaded and the new schema / spec / sort order are
    re-numbered by `_replace_staged_table`; the results are handed to the
    transaction unchanged, together with the caller's `properties`.
    """
    staged = self._replace_staged_table(identifier, schema, location, partition_spec, sort_order, properties)
    staged_table, fresh_schema, fresh_spec, fresh_sort_order, resolved_location = staged
    return ReplaceTableTransaction(
        table=staged_table,
        new_schema=fresh_schema,
        new_spec=fresh_spec,
        new_sort_order=fresh_sort_order,
        new_location=resolved_location,
        new_properties=properties,
    )
def assign_fresh_partition_spec_ids_for_replace(
    spec: PartitionSpec,
    old_schema: Schema,
    fresh_schema: Schema,
    existing_specs: list[PartitionSpec],
    last_partition_id: int | None,
    format_version: int = 2,
    current_spec: PartitionSpec | None = None,
) -> tuple[PartitionSpec, int]:
    """Assign partition field IDs for a replace operation, reusing IDs from existing specs.

    - For v2+, reuse partition field IDs by `(source_id, transform)` across all existing specs.
      New fields get IDs starting from `last_partition_id + 1`.
    - For v1, every field of the current spec is kept in place. Fields absent from the new
      spec are carried forward with a `VoidTransform`. Matching new fields reuse the existing
      partition field ID; remaining new fields are appended with fresh IDs.

    Args:
        spec: The new partition spec to assign IDs to. Its `source_id`s reference `old_schema`.
        old_schema: The schema that the new spec's `source_id`s reference.
        fresh_schema: The schema with freshly assigned field IDs.
        existing_specs: All partition specs from the existing table metadata.
        last_partition_id: The current table's `last_partition_id`. When `None`, numbering
            continues from `PARTITION_FIELD_ID_START - 1`.
        format_version: Table format version. Values `<= 1` select the v1 carry-forward path.
        current_spec: The current default partition spec. Required when `format_version <= 1`.

    Returns:
        A tuple of `(fresh_spec, new_last_partition_id)`.

    Raises:
        ValueError: If `current_spec` is missing on the v1 path, if a source column cannot be
            resolved in either schema, or if the new spec lists the same
            `(source column, transform)` more than once.
    """
    effective_last_partition_id = last_partition_id if last_partition_id is not None else PARTITION_FIELD_ID_START - 1

    if format_version <= 1:
        if current_spec is None:
            raise ValueError("current_spec is required for v1 replace_table")
        return _assign_fresh_partition_spec_ids_for_replace_v1(
            spec, old_schema, fresh_schema, current_spec, effective_last_partition_id
        )

    # v2+: index every historical partition field by (source_id, transform). When the same
    # key appears in several specs, the highest field_id wins.
    transform_to_field_id: dict[tuple[int, str], int] = {}
    for existing_spec in existing_specs:
        for field in existing_spec.fields:
            key = (field.source_id, str(field.transform))
            transform_to_field_id[key] = max(field.field_id, transform_to_field_id.get(key, field.field_id))

    next_id = effective_last_partition_id
    partition_fields = []
    for field, fresh_source_id in _resolve_fresh_source_ids(spec, old_schema, fresh_schema):
        partition_field_id = transform_to_field_id.get((fresh_source_id, str(field.transform)))
        if partition_field_id is None:
            # Never seen before: allocate the next free partition field ID.
            next_id += 1
            partition_field_id = next_id

        partition_fields.append(
            PartitionField(
                name=field.name,
                source_id=fresh_source_id,
                field_id=partition_field_id,
                transform=field.transform,
            )
        )

    # `next_id` starts at `effective_last_partition_id` and only increments, so it is the
    # new last partition id.
    return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID), next_id


def _resolve_fresh_source_ids(
    spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema
) -> list[tuple[PartitionField, int]]:
    """Map each field of `spec` to the ID of its source column in `fresh_schema`.

    The source column is looked up by name: `source_id` -> column name in `old_schema` ->
    field in `fresh_schema`. Each partition name is validated against `fresh_schema`.

    A spec that lists the same `(source column, transform)` twice is rejected. Without this
    check the v2 path would give both fields the same partition field ID, and the v1 path
    would silently keep only the last of them.

    Returns:
        `(partition_field, fresh_source_id)` pairs in the declared order of `spec.fields`.

    Raises:
        ValueError: If a column cannot be resolved or a field is redundant.
    """
    resolved: list[tuple[PartitionField, int]] = []
    seen_keys: set[tuple[int, str]] = set()
    for field in spec.fields:
        column_name = old_schema.find_column_name(field.source_id)
        if column_name is None:
            raise ValueError(f"Could not find in old schema: {field}")
        fresh_field = fresh_schema.find_field(column_name)
        if fresh_field is None:
            raise ValueError(f"Could not find field in fresh schema: {column_name}")

        validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema, set())

        key = (fresh_field.field_id, str(field.transform))
        if key in seen_keys:
            raise ValueError(f"Cannot add redundant partition field: {field}")
        seen_keys.add(key)
        resolved.append((field, fresh_field.field_id))
    return resolved


def _assign_fresh_partition_spec_ids_for_replace_v1(
    spec: PartitionSpec,
    old_schema: Schema,
    fresh_schema: Schema,
    current_spec: PartitionSpec,
    effective_last_partition_id: int,
) -> tuple[PartitionSpec, int]:
    """v1 branch of `assign_fresh_partition_spec_ids_for_replace`.

    Walks `current_spec` in order and keeps one output field per current field: either the
    matching new field (same `(source_id, transform)`, keeping the current partition field
    ID) or a `VoidTransform` placeholder. New fields without a match are appended afterwards
    with IDs above `effective_last_partition_id`.

    Returns:
        A tuple of `(fresh_spec, new_last_partition_id)`.
    """
    # (fresh_source_id, transform) -> (new_field, fresh_source_id). Dict insertion order keeps
    # leftover fields in their declared order when they are appended at the end.
    pending: dict[tuple[int, str], tuple[PartitionField, int]] = {
        (fresh_source_id, str(new_field.transform)): (new_field, fresh_source_id)
        for new_field, fresh_source_id in _resolve_fresh_source_ids(spec, old_schema, fresh_schema)
    }

    # Names of the new fields are reserved up front so void placeholders never collide with them.
    used_names: set[str] = {new_field.name for new_field in spec.fields}

    partition_fields = []
    for cur_field in current_spec.fields:
        match = pending.pop((cur_field.source_id, str(cur_field.transform)), None)
        if match is not None:
            new_field, fresh_source_id = match
            partition_fields.append(
                PartitionField(
                    name=new_field.name,
                    source_id=fresh_source_id,
                    field_id=cur_field.field_id,
                    transform=new_field.transform,
                )
            )
        else:
            # Not part of the new spec: keep the slot, but neutralise it with a void transform.
            void_name = _unique_void_name(cur_field.name, cur_field.field_id, used_names)
            used_names.add(void_name)
            partition_fields.append(
                PartitionField(
                    name=void_name,
                    source_id=cur_field.source_id,
                    field_id=cur_field.field_id,
                    transform=VoidTransform(),
                )
            )

    # Whatever is still pending had no counterpart in the current spec; append with fresh IDs.
    next_id = effective_last_partition_id
    for new_field, fresh_source_id in pending.values():
        next_id += 1
        partition_fields.append(
            PartitionField(
                name=new_field.name,
                source_id=fresh_source_id,
                field_id=next_id,
                transform=new_field.transform,
            )
        )

    # `next_id` starts at `effective_last_partition_id` and only increments, so it is the
    # new last partition id.
    return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID), next_id


def _unique_void_name(base_name: str, field_id: int, used_names: set[str]) -> str:
    """Pick a void-transform name that does not collide with already-used names.

    First tries `base_name`; if taken, tries `base_name_{field_id}`; if still taken,
    appends `_2`, `_3`, ... until unique. `used_names` is only read, never modified.
    """
    if base_name not in used_names:
        return base_name
    stem = f"{base_name}_{field_id}"
    candidate = stem
    suffix = 2
    while candidate in used_names:
        candidate = f"{stem}_{suffix}"
        suffix += 1
    return candidate
def assign_fresh_schema_ids_for_replace(schema: Schema, base_schema: Schema, last_column_id: int) -> tuple[Schema, int]:
    """Re-number a schema for a replace operation, keeping IDs of fields the table already has.

    Fields are matched by their full name as produced by `index_by_name` /
    `index_name_by_id`. A field of `schema` whose name also exists in `base_schema`
    takes over the base field's ID; every other field is numbered by
    `_SetFreshIDsForReplace`, which starts counting at `last_column_id + 1`.

    Args:
        schema: The new schema to assign IDs to.
        base_schema: The existing table's current schema (IDs are reused from here by name).
        last_column_id: The current table's last_column_id (new IDs start above this).

    Returns:
        A tuple of (fresh_schema, new_last_column_id), where the second element is the
        larger of the fresh schema's highest field ID and `last_column_id`.
    """
    base_ids_by_name = index_by_name(base_schema)

    # ID in the incoming schema -> ID of the same-named field in the base schema.
    old_id_to_base_id: dict[int, int] = {
        incoming_id: base_ids_by_name[full_name]
        for incoming_id, full_name in index_name_by_id(schema).items()
        if full_name in base_ids_by_name
    }

    fresh_schema = pre_order_visit(schema, _SetFreshIDsForReplace(old_id_to_base_id, last_column_id))
    return fresh_schema, max(fresh_schema.highest_field_id, last_column_id)
class ReplaceTableTransaction(Transaction):
    """Transaction that swaps an existing table's schema, spec, sort order, location and properties.

    All updates are staged at construction time. They remove the `main` branch ref and
    select (or add) the new schema, partition spec and sort order; nothing is staged that
    removes snapshots or earlier schemas/specs/sort orders. The commit asserts the table
    UUID so the replace applies to the same table it was built from.
    """

    def __init__(
        self,
        table: StagedTable,
        new_schema: Schema,
        new_spec: PartitionSpec,
        new_sort_order: SortOrder,
        new_location: str,
        new_properties: Properties,
    ) -> None:
        # Autocommit is off: updates accumulate until `commit_transaction` is called.
        super().__init__(table, autocommit=False)
        self._initial_changes(table.metadata, new_schema, new_spec, new_sort_order, new_location, new_properties)

    def _initial_changes(
        self,
        table_metadata: TableMetadata,
        new_schema: Schema,
        new_spec: PartitionSpec,
        new_sort_order: SortOrder,
        new_location: str,
        new_properties: Properties,
    ) -> None:
        """Stage the updates that turn the existing table into its replacement.

        A "set current/default" update is always staged for schema, partition spec and
        sort order, even when an existing ID is reused. When a new entry has to be added,
        the matching "set" update refers to it with `-1`.
        """
        # 1. Format-version upgrade, only when the properties ask for a higher version.
        version_property = new_properties.get(TableProperties.FORMAT_VERSION)
        if version_property is not None:
            target_version = int(version_property)
            if target_version > table_metadata.format_version:
                self._updates += (UpgradeFormatVersionUpdate(format_version=target_version),)

        # 2. Drop the main branch ref.
        self._updates += (RemoveSnapshotRefUpdate(ref_name=MAIN_BRANCH),)

        # 3. Schema: point at an equal existing schema, otherwise add one under max(schema_id) + 1.
        matched_schema_id = self._find_matching_schema_id(table_metadata, new_schema)
        if matched_schema_id is None:
            fresh_schema_id = max((s.schema_id for s in table_metadata.schemas), default=-1) + 1
            self._updates += (
                AddSchemaUpdate(schema_=new_schema.model_copy(update={"schema_id": fresh_schema_id})),
                SetCurrentSchemaUpdate(schema_id=-1),
            )
        else:
            self._updates += (SetCurrentSchemaUpdate(schema_id=matched_schema_id),)

        # 4. Partition spec: same reuse-or-add approach. An unpartitioned spec is normalised
        #    to the shared constant before matching.
        effective_spec = UNPARTITIONED_PARTITION_SPEC if new_spec.is_unpartitioned() else new_spec
        matched_spec_id = self._find_matching_spec_id(table_metadata, effective_spec)
        if matched_spec_id is None:
            fresh_spec_id = max((s.spec_id for s in table_metadata.partition_specs), default=-1) + 1
            self._updates += (
                AddPartitionSpecUpdate(spec=PartitionSpec(*effective_spec.fields, spec_id=fresh_spec_id)),
                SetDefaultSpecUpdate(spec_id=-1),
            )
        else:
            self._updates += (SetDefaultSpecUpdate(spec_id=matched_spec_id),)

        # 5. Sort order: same reuse-or-add approach, unsorted normalised to the shared constant.
        effective_sort_order = UNSORTED_SORT_ORDER if new_sort_order.is_unsorted else new_sort_order
        matched_order_id = self._find_matching_sort_order_id(table_metadata, effective_sort_order)
        if matched_order_id is None:
            fresh_order_id = max((o.order_id for o in table_metadata.sort_orders), default=-1) + 1
            self._updates += (
                AddSortOrderUpdate(sort_order=SortOrder(*effective_sort_order.fields, order_id=fresh_order_id)),
                SetDefaultSortOrderUpdate(sort_order_id=-1),
            )
        else:
            self._updates += (SetDefaultSortOrderUpdate(sort_order_id=matched_order_id),)

        # 6. Location, only when it actually differs.
        if new_location != table_metadata.location:
            self._updates += (SetLocationUpdate(location=new_location),)

        # 7. Properties. `format-version` is filtered out so it is not stored as an
        #    ordinary table property.
        persisted_properties = {
            name: value for name, value in new_properties.items() if name != TableProperties.FORMAT_VERSION
        }
        if persisted_properties:
            self._updates += (SetPropertiesUpdate(updates=persisted_properties),)

    @staticmethod
    def _find_matching_schema_id(table_metadata: TableMetadata, schema: Schema) -> int | None:
        """Return the schema_id of the first existing schema equal to `schema`, or None."""
        return next((known.schema_id for known in table_metadata.schemas if known == schema), None)

    @staticmethod
    def _find_matching_spec_id(table_metadata: TableMetadata, spec: PartitionSpec) -> int | None:
        """Return the spec_id of the first existing partition spec with the same fields, or None."""
        return next((known.spec_id for known in table_metadata.partition_specs if known.fields == spec.fields), None)

    @staticmethod
    def _find_matching_sort_order_id(table_metadata: TableMetadata, sort_order: SortOrder) -> int | None:
        """Return the order_id of the first existing sort order with the same fields, or None."""
        return next((known.order_id for known in table_metadata.sort_orders if known.fields == sort_order.fields), None)

    def commit_transaction(self) -> Table:
        """Commit the staged updates to the catalog.

        The commit asserts the table UUID and the last assigned field id of the metadata
        the transaction was built from, plus the last assigned partition id when the
        metadata has one.

        Returns:
            The table with the updates applied.
        """
        if self._updates:
            base = self._table.metadata
            requirements: tuple[TableRequirement, ...] = (
                AssertTableUUID(uuid=base.table_uuid),
                AssertLastAssignedFieldId(last_assigned_field_id=base.last_column_id),
            )
            if base.last_partition_id is not None:
                requirements += (AssertLastAssignedPartitionId(last_assigned_partition_id=base.last_partition_id),)
            self._table._do_commit(  # pylint: disable=W0212
                updates=self._updates,
                requirements=requirements,
            )

        # Reset staged state so a second commit is a no-op.
        self._updates = ()
        self._requirements = ()

        return self._table
schema: Schema = _SIMPLE_SCHEMA, + format_version: int = 2, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + properties: dict[str, str] | None = None, +) -> tuple[Identifier, Schema]: + namespace = Catalog.namespace_from(identifier) + catalog.create_namespace_if_not_exists(namespace) + merged_properties = {"format-version": str(format_version), **(properties or {})} + catalog.create_table(identifier, schema=schema, partition_spec=partition_spec, properties=merged_properties) + return identifier, schema + + +def _simple_data(num_rows: int = 2) -> pa.Table: + return pa.Table.from_pydict( + {"id": list(range(num_rows)), "data": [chr(ord("a") + i) for i in range(num_rows)]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + + +def test_replace_table_preserves_uuid_and_clears_current_snapshot(catalog: Catalog, test_table_identifier: Identifier) -> None: + _create_simple_table(catalog, test_table_identifier) + original = catalog.load_table(test_table_identifier) + original.append(_simple_data()) + after_append = catalog.load_table(test_table_identifier) + assert after_append.metadata.current_snapshot_id is not None, "fixture must produce a snapshot before we replace" + snapshots_before = len(after_append.metadata.snapshots) + assert snapshots_before >= 1 + + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + ) + replaced = catalog.replace_table(test_table_identifier, schema=new_schema) + + assert replaced.metadata.table_uuid == original.metadata.table_uuid + assert replaced.metadata.current_snapshot_id is None + assert len(replaced.metadata.snapshots) == snapshots_before + + +@pytest.mark.parametrize( + "extra_fields, expected_schema_ids, expected_current, expected_fields, expected_last_col_id", + [ + 
([], [0], 0, {"id": 1, "data": 2}, 2), + ( + [NestedField(field_id=99, name="extra", field_type=BooleanType(), required=False)], + [0, 1], + 1, + {"id": 1, "data": 2, "extra": 3}, + 3, + ), + ], + ids=["identical-reuses-id", "extended-adds-new-id"], +) +def test_replace_table_schema_id_reuse( + catalog: Catalog, + test_table_identifier: Identifier, + extra_fields: list[NestedField], + expected_schema_ids: list[int], + expected_current: int, + expected_fields: dict[str, int], + expected_last_col_id: int, +) -> None: + _, base_schema = _create_simple_table(catalog, test_table_identifier) + new_schema = Schema(*base_schema.fields, *extra_fields) + replaced = catalog.replace_table(test_table_identifier, schema=new_schema) + + assert sorted(s.schema_id for s in replaced.metadata.schemas) == expected_schema_ids + assert replaced.metadata.current_schema_id == expected_current + assert {f.name: f.field_id for f in replaced.metadata.schema().fields} == expected_fields + assert replaced.metadata.last_column_id == expected_last_col_id + + +def test_replace_table_preserves_identifier_field_ids(catalog: Catalog, test_table_identifier: Identifier) -> None: + schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + identifier_field_ids=[1], + ) + _create_simple_table(catalog, test_table_identifier, schema=schema) + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + identifier_field_ids=[1], + ) + replaced = catalog.replace_table(test_table_identifier, schema=new_schema) + assert list(replaced.schema().identifier_field_ids) == [1] + + +def test_replace_table_drops_identifier_field(catalog: Catalog, test_table_identifier: Identifier) -> None: + 
schema_with_id = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + identifier_field_ids=[1], + ) + _create_simple_table(catalog, test_table_identifier, schema=schema_with_id) + schema_without_id = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + replaced = catalog.replace_table(test_table_identifier, schema=schema_without_id) + assert list(replaced.schema().identifier_field_ids) == [] + + +def test_replace_table_reuses_partition_spec_id(catalog: Catalog, test_table_identifier: Identifier) -> None: + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + _, schema = _create_simple_table(catalog, test_table_identifier, partition_spec=spec) + replaced = catalog.replace_table(test_table_identifier, schema=schema, partition_spec=spec) + assert [s.spec_id for s in replaced.metadata.partition_specs] == [0] + assert replaced.metadata.default_spec_id == 0 + + +def test_replace_table_with_sort_order_changes(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + sort = SortOrder(SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC)) + + sorted_table = catalog.replace_table(test_table_identifier, schema=schema, sort_order=sort) + assert sorted_table.sort_order().fields == sort.fields + sorted_order_id = sorted_table.metadata.default_sort_order_id + assert sorted_order_id != 0 + + unsorted_table = catalog.replace_table(test_table_identifier, schema=schema) + assert unsorted_table.sort_order().is_unsorted + assert unsorted_table.metadata.default_sort_order_id == 0 + + replayed = catalog.replace_table(test_table_identifier, schema=schema, sort_order=sort) + assert 
replayed.metadata.default_sort_order_id == sorted_order_id + + +def test_replace_table_inherits_existing_location(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + existing = catalog.load_table(test_table_identifier).metadata.location + replaced = catalog.replace_table(test_table_identifier, schema=schema) + assert replaced.metadata.location == existing + + +@pytest.mark.parametrize("trailing_slash", [False, True], ids=["no-slash", "trailing-slash"]) +def test_replace_table_uses_explicit_location( + catalog: Catalog, test_table_identifier: Identifier, tmp_path: Path, trailing_slash: bool +) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + bare = f"file://{tmp_path}/relocated" + arg = bare + "/" if trailing_slash else bare + replaced = catalog.replace_table(test_table_identifier, schema=schema, location=arg) + assert replaced.metadata.location == bare + + +def test_replace_table_merges_properties_with_overrides_and_additions( + catalog: Catalog, test_table_identifier: Identifier +) -> None: + schema = Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=False)) + _create_simple_table(catalog, test_table_identifier, schema=schema, properties={"keep": "yes", "override": "old"}) + replaced = catalog.replace_table(test_table_identifier, schema=schema, properties={"override": "new", "new_key": "v"}) + assert replaced.properties["keep"] == "yes" + assert replaced.properties["override"] == "new" + assert replaced.properties["new_key"] == "v" + + +def test_replace_table_upgrades_format_version(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier, format_version=1) + assert catalog.load_table(test_table_identifier).format_version == 1 + replaced = catalog.replace_table(test_table_identifier, schema=schema, properties={"format-version": "2"}) + assert 
replaced.format_version == 2 + # `format-version` is a control input, not a regular property — must not leak into persisted properties. + assert "format-version" not in replaced.properties + + +def test_replace_table_rejects_format_version_downgrade(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier, format_version=2) + with pytest.raises(ValueError, match="Cannot downgrade format-version"): + catalog.replace_table(test_table_identifier, schema=schema, properties={"format-version": "1"}) + + +def test_replace_table_v1_carries_forward_partition_fields_as_void(catalog: Catalog, test_table_identifier: Identifier) -> None: + """v1 specs are append-only; dropped partition fields are carried forward as VoidTransform.""" + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + _, schema = _create_simple_table(catalog, test_table_identifier, partition_spec=spec, format_version=1) + + replaced = catalog.replace_table(test_table_identifier, schema=schema) + new_spec = replaced.spec() + void_field = next(f for f in new_spec.fields if f.field_id == 1000) + assert isinstance(void_field.transform, VoidTransform) + assert void_field.source_id == 1 + assert void_field.name == "id_part" + + +def test_replace_table_v2_does_not_carry_forward_void_field(catalog: Catalog, test_table_identifier: Identifier) -> None: + """v2 specs are not append-only — a dropped partition field is not carried forward (unlike v1).""" + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + _, schema = _create_simple_table(catalog, test_table_identifier, partition_spec=spec, format_version=2) + + replaced = catalog.replace_table(test_table_identifier, schema=schema) + new_spec = replaced.spec() + assert new_spec.is_unpartitioned() + assert all(not isinstance(f.transform, VoidTransform) for f in new_spec.fields) + + 
+def test_replace_after_format_version_upgrade(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier, format_version=1) + upgraded = catalog.replace_table(test_table_identifier, schema=schema, properties={"format-version": "2"}) + assert upgraded.format_version == 2 + + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + ) + replaced = catalog.replace_table(test_table_identifier, schema=new_schema) + assert replaced.format_version == 2 + assert {f.name for f in replaced.schema().fields} == {"id", "data", "extra"} + + +def test_replace_table_raises_when_table_does_not_exist(catalog: Catalog, test_table_identifier: Identifier) -> None: + schema = Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=False)) + with pytest.raises(NoSuchTableError): + catalog.replace_table(test_table_identifier, schema=schema) + + +def test_replace_table_transaction_can_stage_additional_changes(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + with catalog.replace_table_transaction(test_table_identifier, schema=schema) as txn: + txn.set_properties({"staged": "yes"}) + replaced = catalog.load_table(test_table_identifier) + assert replaced.properties.get("staged") == "yes" + + +def test_replace_table_transaction_with_write_atomic_rtas(catalog: Catalog, test_table_identifier: Identifier) -> None: + _create_simple_table(catalog, test_table_identifier) + catalog.load_table(test_table_identifier).append(_simple_data(num_rows=1)) + old_snapshot_id = catalog.load_table(test_table_identifier).current_snapshot().snapshot_id # type: ignore[union-attr] + + new_data = pa.Table.from_pydict( + {"id": [10, 20], "name": 
["alice", "bob"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.large_string())]), + ) + with catalog.replace_table_transaction(test_table_identifier, schema=new_data.schema) as txn: + txn.append(new_data) + + replaced = catalog.load_table(test_table_identifier) + new_snapshot = replaced.current_snapshot() + assert new_snapshot is not None + assert new_snapshot.snapshot_id != old_snapshot_id + assert new_snapshot.parent_snapshot_id is None, "fresh start: new snapshot must not inherit old lineage" + assert any(s.snapshot_id == old_snapshot_id for s in replaced.metadata.snapshots) + assert {f.name for f in replaced.schema().fields} == {"id", "name"} + assert replaced.scan().to_arrow().num_rows == 2 + + +def test_replace_table_transaction_rolls_back_on_failure(catalog: Catalog, test_table_identifier: Identifier) -> None: + _create_simple_table(catalog, test_table_identifier) + catalog.load_table(test_table_identifier).append(_simple_data()) + before = catalog.load_table(test_table_identifier).metadata + + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + ) + + def run_failing_replace() -> None: + with catalog.replace_table_transaction(test_table_identifier, schema=new_schema): + raise RuntimeError("simulated failure inside replace transaction") + + with pytest.raises(RuntimeError, match="simulated failure inside replace transaction"): + run_failing_replace() + + after = catalog.load_table(test_table_identifier).metadata + assert after.table_uuid == before.table_uuid + assert after.current_snapshot_id == before.current_snapshot_id + assert after.current_schema_id == before.current_schema_id + assert len(after.schemas) == len(before.schemas) + + +def test_concurrent_replace_table(catalog: Catalog, test_table_identifier: Identifier) -> 
None: + """`AssertLastAssignedFieldId` rejects the second of two replaces staged from the same base.""" + _create_simple_table(catalog, test_table_identifier) + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + ) + txn_a = catalog.replace_table_transaction(test_table_identifier, schema=new_schema) + txn_b = catalog.replace_table_transaction(test_table_identifier, schema=new_schema) + + txn_a.commit_transaction() + with pytest.raises(CommitFailedException, match="last assigned field id"): + txn_b.commit_transaction() + + +def test_replace_table_allows_subsequent_append(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + catalog.load_table(test_table_identifier).append(_simple_data(num_rows=3)) + + replaced = catalog.replace_table(test_table_identifier, schema=schema) + replaced.append( + pa.Table.from_pydict( + {"id": [42], "data": ["after-replace"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + ) + after = catalog.load_table(test_table_identifier) + assert after.scan().to_arrow().num_rows == 1 + + # Rename table tests diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 2adfe9f06e..309e288bc5 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -19,6 +19,7 @@ import base64 import os +import uuid from collections.abc import Callable from typing import Any, cast from unittest import mock @@ -64,7 +65,7 @@ from pyiceberg.table.sorting import SortField, SortOrder from pyiceberg.transforms import IdentityTransform, TruncateTransform from pyiceberg.typedef import RecursiveDict -from pyiceberg.types import StringType +from pyiceberg.types import BooleanType, IntegerType, NestedField, 
StringType from pyiceberg.utils.config import Config from pyiceberg.view import View from pyiceberg.view.metadata import ViewMetadata, ViewVersion @@ -2898,3 +2899,116 @@ def test_load_table_without_storage_credentials( ) assert actual.metadata.model_dump() == expected.metadata.model_dump() assert actual == expected + + +def _mock_replace_endpoints( + rest_mock: Mocker, + namespace: str, + table: str, + load_response: dict[str, Any], + commit_response: dict[str, Any], +) -> None: + rest_mock.get( + f"{TEST_URI}v1/namespaces/{namespace}/tables/{table}", + json=load_response, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/{namespace}/tables/{table}", + json=commit_response, + status_code=200, + request_headers=TEST_HEADERS, + ) + + +def test_replace_table_transaction_wire_payload( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + _mock_replace_endpoints( + rest_mock, + "fokko", + "fokko2", + example_table_metadata_with_snapshot_v1_rest_json, + example_table_metadata_no_snapshot_v1_rest_json, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="new_col", field_type=BooleanType(), required=False), + ) + catalog.replace_table_transaction(identifier=("fokko", "fokko2"), schema=new_schema).commit_transaction() + request = rest_mock.last_request.json() + + fixture_metadata = example_table_metadata_with_snapshot_v1_rest_json["metadata"] + assert request["requirements"] == [ + {"type": 
"assert-table-uuid", "uuid": table_uuid}, + {"type": "assert-last-assigned-field-id", "last-assigned-field-id": fixture_metadata["last-column-id"]}, + {"type": "assert-last-assigned-partition-id", "last-assigned-partition-id": fixture_metadata["last-partition-id"]}, + ] + + actions = [u["action"] for u in request["updates"]] + assert len(actions) == len(set(actions)), f"duplicate actions in request: {actions}" + updates_by_action = {u["action"]: u for u in request["updates"]} + + assert updates_by_action["remove-snapshot-ref"] == {"action": "remove-snapshot-ref", "ref-name": "main"} + added_schema = updates_by_action["add-schema"]["schema"] + assert {f["name"]: f["id"] for f in added_schema["fields"]} == {"id": 1, "data": 2, "new_col": 3} + # schema-id=-1 is the wire sentinel meaning "the schema we just added in this commit". + assert updates_by_action["set-current-schema"]["schema-id"] == -1 + assert updates_by_action["set-default-spec"]["spec-id"] == fixture_metadata["default-spec-id"] + assert updates_by_action["set-default-sort-order"]["sort-order-id"] == fixture_metadata["default-sort-order-id"] + + +def test_replace_table_transaction_404_raises( + rest_mock: Mocker, +) -> None: + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/missing", + json={"error": {"message": "Table not found", "type": "NoSuchTableException", "code": 404}}, + status_code=404, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + with pytest.raises(NoSuchTableError): + catalog.replace_table_transaction( + identifier=("fokko", "missing"), + schema=Schema(NestedField(field_id=1, name="id", field_type=IntegerType(), required=False)), + ) + + +def test_replace_table_issues_commit_post_immediately( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """`replace_table` commits during the call; `replace_table_transaction` 
defers the POST until the caller commits.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + _mock_replace_endpoints( + rest_mock, + "fokko", + "fokko2", + example_table_metadata_with_snapshot_v1_rest_json, + example_table_metadata_no_snapshot_v1_rest_json, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + + catalog.replace_table_transaction(identifier=("fokko", "fokko2"), schema=new_schema) + methods_after_open = [r.method for r in rest_mock.request_history] + assert "POST" not in methods_after_open + + replaced = catalog.replace_table(identifier=("fokko", "fokko2"), schema=new_schema) + assert replaced.metadata.table_uuid == uuid.UUID(table_uuid) + methods_after_replace = [r.method for r in rest_mock.request_history] + assert "POST" in methods_after_replace, "replace_table must commit immediately" diff --git a/tests/integration/test_catalog.py b/tests/integration/test_catalog.py index 751dbe0479..401f3418f9 100644 --- a/tests/integration/test_catalog.py +++ b/tests/integration/test_catalog.py @@ -20,6 +20,7 @@ from collections.abc import Generator from pathlib import Path, PosixPath +import pyarrow as pa import pytest from pytest_lazy_fixtures import lf @@ -43,7 +44,7 @@ from pyiceberg.table.metadata import INITIAL_SPEC_ID from pyiceberg.table.sorting import INITIAL_SORT_ORDER_ID, SortField, SortOrder from pyiceberg.transforms import BucketTransform, DayTransform, IdentityTransform -from pyiceberg.types import IntegerType, LongType, NestedField, TimestampType, UUIDType +from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField, StringType, TimestampType, UUIDType from tests.conftest import ( clean_up, 
does_support_atomic_concurrent_updates, @@ -85,7 +86,15 @@ def sqlite_catalog_file(warehouse: Path) -> Generator[Catalog, None, None]: @pytest.fixture(scope="function") def rest_catalog() -> Generator[Catalog, None, None]: - test_catalog = RestCatalog("rest", uri="http://localhost:8181") + test_catalog = RestCatalog( + "rest", + **{ + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ) yield test_catalog @@ -866,3 +875,63 @@ def test_load_missing_table(test_catalog: Catalog, database_name: str, table_nam with pytest.raises(NoSuchTableError): test_catalog.load_table(identifier) + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_replace_table(test_catalog: Catalog, database_name: str, table_name: str) -> None: + test_catalog.create_namespace(database_name) + identifier = (database_name, table_name) + + original_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + original = test_catalog.create_table(identifier, schema=original_schema) + original.append( + pa.Table.from_pydict( + {"id": [1, 2, 3], "data": ["a", "b", "c"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + ) + original.refresh() + original_snapshot_id = original.current_snapshot().snapshot_id # type: ignore[union-attr] + + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="name", field_type=StringType(), required=False), + NestedField(field_id=3, name="active", field_type=BooleanType(), required=False), + ) + replaced = test_catalog.replace_table(identifier, schema=new_schema) + + assert replaced.metadata.table_uuid == original.metadata.table_uuid + assert replaced.current_snapshot() is None + assert any(s.snapshot_id 
== original_snapshot_id for s in replaced.metadata.snapshots) + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_replace_table_transaction(test_catalog: Catalog, database_name: str, table_name: str) -> None: + test_catalog.create_namespace(database_name) + identifier = (database_name, table_name) + + old_data = pa.Table.from_pydict( + {"id": [1], "data": ["old"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + original = test_catalog.create_table(identifier, schema=old_data.schema) + original.append(old_data) + old_snapshot_id = test_catalog.load_table(identifier).current_snapshot().snapshot_id # type: ignore[union-attr] + + new_data = pa.Table.from_pydict( + {"id": [10, 20], "name": ["alice", "bob"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.large_string())]), + ) + with test_catalog.replace_table_transaction(identifier, schema=new_data.schema) as txn: + txn.append(new_data) + + replaced = test_catalog.load_table(identifier) + assert replaced.current_snapshot() is not None + assert replaced.current_snapshot().snapshot_id != old_snapshot_id # type: ignore[union-attr] + assert any(s.snapshot_id == old_snapshot_id for s in replaced.metadata.snapshots) + assert replaced.scan().to_arrow().num_rows == 2 diff --git a/tests/table/test_partitioning.py b/tests/table/test_partitioning.py index a27046ef30..5fb81dc4d4 100644 --- a/tests/table/test_partitioning.py +++ b/tests/table/test_partitioning.py @@ -22,7 +22,12 @@ import pytest from pyiceberg.exceptions import ValidationError -from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec +from pyiceberg.partitioning import ( + UNPARTITIONED_PARTITION_SPEC, + PartitionField, + PartitionSpec, + assign_fresh_partition_spec_ids_for_replace, +) from pyiceberg.schema import Schema from pyiceberg.transforms import ( BucketTransform, @@ -31,6 +36,7 @@ IdentityTransform, 
MonthTransform, TruncateTransform, + VoidTransform, YearTransform, ) from pyiceberg.typedef import Record @@ -298,3 +304,104 @@ def test_incompatible_transform_source_type() -> None: spec.check_compatible(schema) assert "Invalid source field foo with type int for transform: year" in str(exc.value) + + +_REPLACE_SCHEMA_FOR_PARTITION = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), +) + + +@pytest.mark.parametrize( + "new_spec, existing_specs, last_partition_id, expected_field_id, expected_last_partition_id", + [ + # Reuse-by-identity: same (source_id, IdentityTransform) already in an existing spec. + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id")), + [PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=0)], + 1000, + 1000, + 1000, + id="reuse-identity", + ), + # Reuse-by-(source,bucket): same source_id + same BucketTransform, even under a renamed field. + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=999, transform=BucketTransform(8), name="id_bucket_renamed")), + [ + PartitionSpec( + PartitionField(source_id=1, field_id=1042, transform=BucketTransform(8), name="id_bucket"), spec_id=0 + ) + ], + 1042, + 1042, + 1042, + id="reuse-bucket-under-rename", + ), + # No match: fresh id above last_partition_id. 
+ pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id")), + [PartitionSpec(spec_id=0)], + 999, + 1000, + 1000, + id="new-field-above-last-partition-id", + ), + ], +) +def test_assign_fresh_partition_spec_ids_for_replace_v2( + new_spec: PartitionSpec, + existing_specs: list[PartitionSpec], + last_partition_id: int, + expected_field_id: int, + expected_last_partition_id: int, +) -> None: + fresh_spec, new_last_pid = assign_fresh_partition_spec_ids_for_replace( + new_spec, _REPLACE_SCHEMA_FOR_PARTITION, _REPLACE_SCHEMA_FOR_PARTITION, existing_specs, last_partition_id + ) + assert fresh_spec.fields[0].field_id == expected_field_id + assert new_last_pid == expected_last_partition_id + + +def test_assign_fresh_partition_spec_ids_for_replace_v1_carries_forward_as_void() -> None: + """v1 specs are append-only: a field absent from the new spec is carried forward as void.""" + current_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=0) + # New spec drops "id" entirely, partitioned by "data" instead. + new_spec = PartitionSpec(PartitionField(source_id=2, field_id=999, transform=IdentityTransform(), name="data")) + fresh_spec, new_last_pid = assign_fresh_partition_spec_ids_for_replace( + new_spec, + _REPLACE_SCHEMA_FOR_PARTITION, + _REPLACE_SCHEMA_FOR_PARTITION, + existing_specs=[current_spec], + last_partition_id=1000, + format_version=1, + current_spec=current_spec, + ) + # Two fields: the carried-forward void at field_id=1000, and the new "data" field above it. 
+ fields_by_id = {f.field_id: f for f in fresh_spec.fields} + assert isinstance(fields_by_id[1000].transform, VoidTransform) + assert fields_by_id[1000].name == "id" + assert fields_by_id[1001].name == "data" + assert isinstance(fields_by_id[1001].transform, IdentityTransform) + assert new_last_pid == 1001 + + +def test_assign_fresh_partition_spec_ids_for_replace_v1_renames_void_on_name_collision() -> None: + """When a void field's name collides with a new field's name, a unique suffix is added.""" + current_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="data"), spec_id=0 + ) + # New spec reuses the partition name "data" on a different source column (source_id=2) — the OLD "data" must be voided + # under a different name to avoid collision with the NEW "data" partition. + new_spec = PartitionSpec(PartitionField(source_id=2, field_id=999, transform=IdentityTransform(), name="data")) + fresh_spec, _ = assign_fresh_partition_spec_ids_for_replace( + new_spec, + _REPLACE_SCHEMA_FOR_PARTITION, + _REPLACE_SCHEMA_FOR_PARTITION, + existing_specs=[current_spec], + last_partition_id=1000, + format_version=1, + current_spec=current_spec, + ) + void_field = next(f for f in fresh_spec.fields if isinstance(f.transform, VoidTransform)) + assert void_field.name != "data", "void name must not collide with active partition name" + assert void_field.name == "data_1000" diff --git a/tests/test_schema.py b/tests/test_schema.py index 93ddc16202..5f5368fda5 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -26,6 +26,7 @@ Accessor, Schema, _check_schema_compatible, + assign_fresh_schema_ids_for_replace, build_position_accessors, index_by_id, index_by_name, @@ -1815,3 +1816,114 @@ def test_check_schema_compatible_optional_map_field_present() -> None: ) # Should not raise - schemas match _check_schema_compatible(requested_schema, provided_schema) + + +@pytest.mark.parametrize( + "new_fields, expected_ids, expected_last_col_id", + [ + # All columns reused 
by name: IDs come from base, last_column_id unchanged. + ([("id", IntegerType()), ("data", StringType())], [1, 2], 2), + # Mix of reused and new: new column gets ID above last_column_id. + ([("id", IntegerType()), ("data", StringType()), ("new_col", BooleanType())], [1, 2, 3], 3), + # No column names match: all fresh IDs starting from last_column_id + 1. + ([("x", IntegerType()), ("y", IntegerType())], [3, 4], 4), + ], + ids=[ + "all-reused-keeps-last-col-id", + "new-field-bumps-last-col-id", + "no-name-overlap-bumps-from-base", + ], +) +def test_assign_fresh_schema_ids_for_replace_primitive_fields( + new_fields: list[tuple[str, IcebergType]], expected_ids: list[int], expected_last_col_id: int +) -> None: + """Replace schema field IDs are reused from the base schema by name; new fields get IDs above last_column_id.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + new_schema = Schema( + *( + NestedField(field_id=10 * (i + 1), name=name, field_type=field_type, required=False) + for i, (name, field_type) in enumerate(new_fields) + ) + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 2) + assert [f.field_id for f in fresh.fields] == expected_ids + assert last_col_id == expected_last_col_id + + +def test_assign_fresh_schema_ids_for_replace_with_nested_struct() -> None: + """Test that nested struct field IDs are reused by full path name.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField( + field_id=2, + name="location", + field_type=StructType( + NestedField(field_id=3, name="lat", field_type=FloatType(), required=False), + NestedField(field_id=4, name="lon", field_type=FloatType(), required=False), + ), + required=False, + ), + ) + new_schema = Schema( + NestedField(field_id=10, name="id", field_type=IntegerType(), 
required=False), + NestedField( + field_id=20, + name="location", + field_type=StructType( + NestedField(field_id=30, name="lat", field_type=FloatType(), required=False), + NestedField(field_id=40, name="lon", field_type=FloatType(), required=False), + NestedField(field_id=50, name="alt", field_type=FloatType(), required=False), + ), + required=False, + ), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 4) + assert fresh.fields[0].field_id == 1 # id reused + assert fresh.fields[1].field_id == 2 # location reused + loc_fields = fresh.fields[1].field_type.fields + assert loc_fields[0].field_id == 3 # location.lat reused + assert loc_fields[1].field_id == 4 # location.lon reused + assert loc_fields[2].field_id == 5 # location.alt is new + assert last_col_id == 5 + + +def test_assign_fresh_schema_ids_for_replace_with_list_and_map() -> None: + """`element_id`, `key_id`, and `value_id` are reused by name path (e.g. `tags.element`, `m.key`, `m.value`).""" + base_schema = Schema( + NestedField( + field_id=1, + name="tags", + field_type=ListType(element_id=2, element_type=StringType(), element_required=False), + required=False, + ), + NestedField( + field_id=3, + name="m", + field_type=MapType(key_id=4, key_type=StringType(), value_id=5, value_type=IntegerType(), value_required=False), + required=False, + ), + ) + new_schema = Schema( + NestedField( + field_id=10, + name="tags", + field_type=ListType(element_id=20, element_type=StringType(), element_required=False), + required=False, + ), + NestedField( + field_id=30, + name="m", + field_type=MapType(key_id=40, key_type=StringType(), value_id=50, value_type=IntegerType(), value_required=False), + required=False, + ), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 5) + assert fresh.fields[0].field_id == 1 + assert fresh.fields[0].field_type.element_id == 2 + assert fresh.fields[1].field_id == 3 + assert fresh.fields[1].field_type.key_id == 4 + 
assert fresh.fields[1].field_type.value_id == 5 + assert last_col_id == 5