diff --git a/airbyte_cdk/sql/shared/catalog_providers.py b/airbyte_cdk/sql/shared/catalog_providers.py index 80713a35a..d9016a37d 100644 --- a/airbyte_cdk/sql/shared/catalog_providers.py +++ b/airbyte_cdk/sql/shared/catalog_providers.py @@ -119,11 +119,21 @@ def get_stream_properties( def get_primary_keys( self, stream_name: str, - ) -> list[str]: - """Return the primary keys for the given stream.""" - pks = self.get_configured_stream_info(stream_name).primary_key + ) -> list[str] | None: + """Return the primary key column names for the given stream. + + We return `source_defined_primary_key` if set, or `primary_key` otherwise. If both are set, + we assume they should not should differ, since Airbyte data integrity constraints do not + permit overruling a source's pre-defined primary keys. If neither is set, we return `None`. + + Returns: + A list of column names that constitute the primary key, or None if no primary key is defined. + """ + configured_stream = self.get_configured_stream_info(stream_name) + pks = configured_stream.stream.source_defined_primary_key or configured_stream.primary_key + if not pks: - return [] + return None normalized_pks: list[list[str]] = [ [LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks diff --git a/airbyte_cdk/sql/shared/sql_processor.py b/airbyte_cdk/sql/shared/sql_processor.py index a53925206..238ff6c69 100644 --- a/airbyte_cdk/sql/shared/sql_processor.py +++ b/airbyte_cdk/sql/shared/sql_processor.py @@ -666,9 +666,13 @@ def _merge_temp_table_to_final_table( """ nl = "\n" columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)} - pk_columns = { - self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name) - } + primary_keys = self.catalog_provider.get_primary_keys(stream_name) + if not primary_keys: + raise exc.AirbyteInternalError( + message="Cannot merge tables without primary keys. Primary keys are required for merge operations.", + context={"stream_name": stream_name}, + ) + pk_columns = {self._quote_identifier(c) for c in primary_keys} non_pk_columns = columns - pk_columns join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns) set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns) @@ -725,6 +729,11 @@ def _emulated_merge_temp_table_to_final_table( final_table = self._get_table_by_name(final_table_name) temp_table = self._get_table_by_name(temp_table_name) pk_columns = self.catalog_provider.get_primary_keys(stream_name) + if not pk_columns: + raise exc.AirbyteInternalError( + message="Cannot merge tables without primary keys. Primary keys are required for merge operations.", + context={"stream_name": stream_name}, + ) columns_to_update: set[str] = self._get_sql_column_definitions( stream_name=stream_name diff --git a/airbyte_cdk/test/standard_tests/connector_base.py b/airbyte_cdk/test/standard_tests/connector_base.py index 588b7d0bd..b945f1572 100644 --- a/airbyte_cdk/test/standard_tests/connector_base.py +++ b/airbyte_cdk/test/standard_tests/connector_base.py @@ -59,7 +59,16 @@ def connector(cls) -> type[IConnector] | Callable[[], IConnector] | None: try: module = importlib.import_module(expected_module_name) except ModuleNotFoundError as e: - raise ImportError(f"Could not import module '{expected_module_name}'.") from e + raise ImportError( + f"Could not import module '{expected_module_name}'. " + "Please ensure you are running from within the connector's virtual environment, " + "for instance by running `poetry run airbyte-cdk connector test` from the " + "connector directory. If the issue persists, check that the connector " + f"module matches the expected module name '{expected_module_name}' and that the " + f"connector class matches the expected class name '{expected_class_name}'. " + "Alternatively, you can run `airbyte-cdk image test` to run a subset of tests " + "against the connector's image." + ) from e finally: # Change back to the original working directory os.chdir(cwd_snapshot) diff --git a/unit_tests/sql/shared/test_catalog_providers.py b/unit_tests/sql/shared/test_catalog_providers.py new file mode 100644 index 000000000..af66ba974 --- /dev/null +++ b/unit_tests/sql/shared/test_catalog_providers.py @@ -0,0 +1,78 @@ +from unittest.mock import Mock + +import pytest + +from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream +from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider + + +class TestCatalogProvider: + """Test cases for CatalogProvider.get_primary_keys() method.""" + + @pytest.mark.parametrize( + "configured_primary_key,source_defined_primary_key,expected_result,test_description", + [ + (["configured_id"], ["source_id"], ["source_id"], "prioritizes source when both set"), + ([], ["source_id"], ["source_id"], "uses source when configured empty"), + (None, ["source_id"], ["source_id"], "uses source when configured None"), + ( + ["configured_id"], + [], + ["configured_id"], + "falls back to configured when source empty", + ), + ( + ["configured_id"], + None, + ["configured_id"], + "falls back to configured when source None", + ), + ([], [], None, "returns None when both empty"), + (None, None, None, "returns None when both None"), + ([], ["id1", "id2"], ["id1", "id2"], "handles composite keys from source"), + ], + ) + def test_get_primary_keys_parametrized( + self, configured_primary_key, source_defined_primary_key, expected_result, test_description + ): + """Test primary key fallback logic with various input combinations.""" + configured_pk_wrapped = ( + None + if configured_primary_key is None + else [[pk] for pk in configured_primary_key] + if configured_primary_key + else [] + ) + source_pk_wrapped = ( + None + if source_defined_primary_key is None + else [[pk] for pk in source_defined_primary_key] + if source_defined_primary_key + else [] + ) + + stream = AirbyteStream( + name="test_stream", + json_schema={ + "type": "object", + "properties": { + "id": {"type": "string"}, + "id1": {"type": "string"}, + "id2": {"type": "string"}, + }, + }, + supported_sync_modes=["full_refresh"], + source_defined_primary_key=source_pk_wrapped, + ) + configured_stream = ConfiguredAirbyteStream( + stream=stream, + sync_mode="full_refresh", + destination_sync_mode="overwrite", + primary_key=configured_pk_wrapped, + ) + catalog = ConfiguredAirbyteCatalog(streams=[configured_stream]) + + provider = CatalogProvider(catalog) + result = provider.get_primary_keys("test_stream") + + assert result == expected_result