diff --git a/tests/integration/test_catalog.py b/tests/integration/test_catalog.py
index 587c13b35b..a4b0a82583 100644
--- a/tests/integration/test_catalog.py
+++ b/tests/integration/test_catalog.py
@@ -26,6 +26,7 @@
 from pyiceberg.catalog.rest import RestCatalog
 from pyiceberg.catalog.sql import SqlCatalog
 from pyiceberg.exceptions import (
+    CommitFailedException,
     NamespaceAlreadyExistsError,
     NamespaceNotEmptyError,
     NoSuchNamespaceError,
@@ -34,6 +35,7 @@
 )
 from pyiceberg.io import WAREHOUSE
 from pyiceberg.schema import Schema
+from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, LongType, MapType, NestedField, StringType, StructType
 from tests.conftest import clean_up
 
 
@@ -343,3 +345,107 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
         else:
             assert k in update_report.removed
     assert "updated test description" == test_catalog.load_namespace_properties(database_name)["comment"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("test_catalog", CATALOGS)
+def test_update_table_schema(test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str) -> None:
+    """Check that a column added via the schema-update context manager shows up in the table schema."""
+    test_catalog.create_namespace(database_name)
+    tbl = test_catalog.create_table((database_name, table_name), table_schema_nested)
+
+    # The assertion below expects the pending change to be applied once the block exits.
+    with tbl.update_schema() as schema_update:
+        schema_update.add_column("new_col", LongType())
+
+    assert tbl.schema().find_field("new_col", case_sensitive=False)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("test_catalog", CATALOGS)
+def test_update_table_schema_conflict(
+    test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str
+) -> None:
+    """Check that a schema update staged before a concurrent commit is rejected, leaving the concurrent schema."""
+    table_id = (database_name, table_name)
+    test_catalog.create_namespace(database_name)
+    stale_table = test_catalog.create_table(table_id, table_schema_nested)
+
+    # stage (but do not commit yet) an update against the initial schema
+    pending_update = stale_table.update_schema()
+    pending_update = pending_update.add_column("new_col", LongType())
+
+    # commit a different "new_col" through a second handle so that the staged update goes stale
+    concurrent_table = test_catalog.load_table(table_id)
+    with concurrent_table.update_schema(allow_incompatible_changes=True) as concurrent_update:
+        concurrent_update.add_column("new_col", StringType())
+
+    # committing the staged update must now fail
+    with pytest.raises(CommitFailedException, match="Requirement failed: current schema"):
+        pending_update.commit()
+
+    # building blocks of the schema expected after the concurrent commit
+    qux_type = ListType(type="list", element_id=8, element_type=StringType(), element_required=True)
+    quux_inner_type = MapType(
+        type="map", key_id=11, key_type=StringType(), value_id=12, value_type=IntegerType(), value_required=True
+    )
+    quux_type = MapType(
+        type="map", key_id=9, key_type=StringType(), value_id=10, value_type=quux_inner_type, value_required=True
+    )
+    coordinates_type = StructType(
+        fields=(
+            NestedField(field_id=14, name="latitude", field_type=FloatType(), required=False),
+            NestedField(field_id=15, name="longitude", field_type=FloatType(), required=False),
+        )
+    )
+    location_type = ListType(type="list", element_id=13, element_type=coordinates_type, element_required=True)
+    person_type = StructType(
+        fields=(
+            NestedField(field_id=16, name="name", field_type=StringType(), required=False),
+            NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
+        )
+    )
+    expected_schema = Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        NestedField(field_id=4, name="qux", field_type=qux_type, required=True),
+        NestedField(field_id=5, name="quux", field_type=quux_type, required=True),
+        NestedField(field_id=6, name="location", field_type=location_type, required=True),
+        NestedField(field_id=7, name="person", field_type=person_type, required=False),
+        # the string-typed column committed by the concurrent update
+        NestedField(field_id=18, name="new_col", field_type=StringType(), required=False),
+        schema_id=1,
+        identifier_field_ids=[2],
+    )
+
+    assert concurrent_table.schema() == expected_schema
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("test_catalog", CATALOGS)
+def test_update_table_schema_then_change_back(
+    test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str
+) -> None:
+    """Check that adding three columns and then deleting them restores the original schema struct."""
+    table_id = (database_name, table_name)
+    test_catalog.create_namespace(database_name)
+    table = test_catalog.create_table(table_id, table_schema_nested)
+    original_schema_struct = table.schema().as_struct()
+    temp_columns = ("col1", "col2", "col3")
+
+    # first commit: add all temporary columns in one schema update
+    add_update = table.update_schema()
+    for column in temp_columns:
+        add_update = add_update.add_column(column, StringType())
+    add_update.commit()
+
+    # second commit: drop them again, working from a freshly loaded table
+    table_with_cols = test_catalog.load_table(table_id)
+    drop_update = table_with_cols.update_schema()
+    for column in temp_columns:
+        drop_update = drop_update.delete_column(column)
+    drop_update.commit()
+
+    reverted_table = test_catalog.load_table(table_id)
+    assert reverted_table.schema().as_struct() == original_schema_struct