diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 78676a774a..e97375f2c2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -291,8 +291,6 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ if self._autocommit: self.commit_transaction() - self._updates = () - self._requirements = () return self @@ -890,13 +888,15 @@ def commit_transaction(self) -> Table: updates=self._updates, requirements=self._requirements, ) - return self._table - else: - return self._table + + self._updates = () + self._requirements = () + + return self._table class CreateTableTransaction(Transaction): - """A transaction that involves the creation of a a new table.""" + """A transaction that involves the creation of a new table.""" def _initial_changes(self, table_metadata: TableMetadata) -> None: """Set the initial changes that can reconstruct the initial table metadata when creating the CreateTableTransaction.""" @@ -941,11 +941,16 @@ def commit_transaction(self) -> Table: Returns: The table with the updates applied. """ - self._requirements = (AssertCreate(),) - self._table._do_commit( # pylint: disable=W0212 - updates=self._updates, - requirements=self._requirements, - ) + if len(self._updates) > 0: + self._requirements += (AssertCreate(),) + self._table._do_commit( # pylint: disable=W0212 + updates=self._updates, + requirements=self._requirements, + ) + + self._updates = () + self._requirements = () + return self._table diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 150d2b750c..8c3f1c29ef 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1780,6 +1780,23 @@ def test_write_optional_list(session_catalog: Catalog) -> None: assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_double_commit_transaction( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.arrow_data_files" + tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, []) + + assert len(tbl.metadata.metadata_log) == 0 + + with tbl.transaction() as tx: + tx.append(arrow_table_with_null) + tx.commit_transaction() + + assert len(tbl.metadata.metadata_log) == 1 + + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_evolve_and_write(