Skip to content

fix(sql): Add fallback to source_defined_primary_key in CatalogProvider #627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions airbyte_cdk/sql/shared/catalog_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,21 @@ def get_stream_properties(
def get_primary_keys(
self,
stream_name: str,
) -> list[str]:
"""Return the primary keys for the given stream."""
pks = self.get_configured_stream_info(stream_name).primary_key
) -> list[str] | None:
"""Return the primary key column names for the given stream.

We return `source_defined_primary_key` if set, or `primary_key` otherwise. If both are set,
we assume they should not should differ, since Airbyte data integrity constraints do not
permit overruling a source's pre-defined primary keys. If neither is set, we return `None`.

Returns:
A list of column names that constitute the primary key, or None if no primary key is defined.
"""
configured_stream = self.get_configured_stream_info(stream_name)
pks = configured_stream.stream.source_defined_primary_key or configured_stream.primary_key

if not pks:
return []
return None

normalized_pks: list[list[str]] = [
[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks
Expand Down
15 changes: 12 additions & 3 deletions airbyte_cdk/sql/shared/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,13 @@ def _merge_temp_table_to_final_table(
"""
nl = "\n"
columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
pk_columns = {
self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name)
}
primary_keys = self.catalog_provider.get_primary_keys(stream_name)
if not primary_keys:
raise exc.AirbyteInternalError(
message="Cannot merge tables without primary keys. Primary keys are required for merge operations.",
context={"stream_name": stream_name},
)
pk_columns = {self._quote_identifier(c) for c in primary_keys}
non_pk_columns = columns - pk_columns
join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
Expand Down Expand Up @@ -725,6 +729,11 @@ def _emulated_merge_temp_table_to_final_table(
final_table = self._get_table_by_name(final_table_name)
temp_table = self._get_table_by_name(temp_table_name)
pk_columns = self.catalog_provider.get_primary_keys(stream_name)
if not pk_columns:
raise exc.AirbyteInternalError(
message="Cannot merge tables without primary keys. Primary keys are required for merge operations.",
context={"stream_name": stream_name},
)

columns_to_update: set[str] = self._get_sql_column_definitions(
stream_name=stream_name
Expand Down
11 changes: 10 additions & 1 deletion airbyte_cdk/test/standard_tests/connector_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,16 @@ def connector(cls) -> type[IConnector] | Callable[[], IConnector] | None:
try:
module = importlib.import_module(expected_module_name)
except ModuleNotFoundError as e:
raise ImportError(f"Could not import module '{expected_module_name}'.") from e
raise ImportError(
f"Could not import module '{expected_module_name}'. "
"Please ensure you are running from within the connector's virtual environment, "
"for instance by running `poetry run airbyte-cdk connector test` from the "
"connector directory. If the issue persists, check that the connector "
f"module matches the expected module name '{expected_module_name}' and that the "
f"connector class matches the expected class name '{expected_class_name}'. "
"Alternatively, you can run `airbyte-cdk image test` to run a subset of tests "
"against the connector's image."
) from e
finally:
# Change back to the original working directory
os.chdir(cwd_snapshot)
Expand Down
78 changes: 78 additions & 0 deletions unit_tests/sql/shared/test_catalog_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from unittest.mock import Mock

import pytest

from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider


class TestCatalogProvider:
"""Test cases for CatalogProvider.get_primary_keys() method."""

@pytest.mark.parametrize(
"configured_primary_key,source_defined_primary_key,expected_result,test_description",
[
(["configured_id"], ["source_id"], ["source_id"], "prioritizes source when both set"),
([], ["source_id"], ["source_id"], "uses source when configured empty"),
(None, ["source_id"], ["source_id"], "uses source when configured None"),
(
["configured_id"],
[],
["configured_id"],
"falls back to configured when source empty",
),
(
["configured_id"],
None,
["configured_id"],
"falls back to configured when source None",
),
([], [], None, "returns None when both empty"),
(None, None, None, "returns None when both None"),
([], ["id1", "id2"], ["id1", "id2"], "handles composite keys from source"),
],
)
def test_get_primary_keys_parametrized(
self, configured_primary_key, source_defined_primary_key, expected_result, test_description
):
"""Test primary key fallback logic with various input combinations."""
configured_pk_wrapped = (
None
if configured_primary_key is None
else [[pk] for pk in configured_primary_key]
if configured_primary_key
else []
)
source_pk_wrapped = (
None
if source_defined_primary_key is None
else [[pk] for pk in source_defined_primary_key]
if source_defined_primary_key
else []
)

stream = AirbyteStream(
name="test_stream",
json_schema={
"type": "object",
"properties": {
"id": {"type": "string"},
"id1": {"type": "string"},
"id2": {"type": "string"},
},
},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=source_pk_wrapped,
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=configured_pk_wrapped,
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == expected_result
Loading