diff --git a/pyproject.toml b/pyproject.toml index b39e26fd5..c593549a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ "mypy~=1.10.0", "numpy==1.26.4", "pandas==1.4.1", + "snowflake-sqlalchemy>=1.7.3", ] [project.entry-points.databricks] diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index df9d678bb..13bca4107 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -8,6 +8,8 @@ from sqlalchemy import text from sqlalchemy.exc import OperationalError +from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource + logger = logging.getLogger(__name__) logger.setLevel("INFO") @@ -56,7 +58,31 @@ def _create_connector(db_type: str, config: dict[str, Any]) -> DatabaseConnector class SnowflakeConnector(_BaseConnector): def _connect(self) -> Engine: - raise NotImplementedError("Snowflake connector not implemented") + # pylint: disable=import-outside-toplevel + import snowflake.sqlalchemy # type: ignore + + # Snowflake does not follow a traditional SQL Alchemy connection string URL; they have their own. + # e.g., connection_string = (f"snowflake://{user}:{pw}@{account}") + # https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy + sqlalchemy_driver = "snowflake" + url_parts = self.config["server"].split(".") + parsed_url = f"{url_parts[0]}.{url_parts[1]}.{url_parts[2]}" + connection_string = snowflake.sqlalchemy.URL( + drivername=sqlalchemy_driver, + account=parsed_url, + user=self.config["user"], + database=self.config["database"], + schema=self.config["schema"], + warehouse=self.config["warehouse"], + ) + + # Users can optionally specify a private key to use + conn_args = {} + if "pem_private_key" in self.config: + private_key_bytes = SnowflakeDataSource.get_private_key(self.config["pem_private_key"]) + conn_args = {"private_key": private_key_bytes} + + return create_engine(connection_string, connect_args=conn_args) class MSSQLConnector(_BaseConnector): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e1cc5a8fd..2f5316cd9 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -65,6 +65,12 @@ def mock_credentials(): 'database': 'TEST_TSQL_JDBC', 'driver': 'ODBC Driver 18 for SQL Server', }, + 'snowflake': { + 'server': 'TEST_SNOWFLAKE_JDBC', + 'pem_private_key': 'TEST_SNOWFLAKE_PRIVATE_KEY', + 'database': 'TEST_SNOWFLAKE_DB', + 'schema': 'TEST_SNOWFLAKE_SCHEMA', + }, }, ): yield diff --git a/tests/integration/connections/helpers.py b/tests/integration/connections/helpers.py index 44e9c4e51..3a5ab766d 100644 --- a/tests/integration/connections/helpers.py +++ b/tests/integration/connections/helpers.py @@ -8,14 +8,34 @@ def get_db_manager(product_name: str, source: str) -> DatabaseManager: env = TestEnvGetter(True) config = create_credential_manager(product_name, env).get_credentials(source) - # since the kv has only URL so added explicit parse rules - base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1) + # Some JDBC connection strings separate hostname and query params + # by semicolons, while others use ampersands + if ";" in config["server"]: + base_url, params = config["server"].replace("jdbc:", "", 1).split(";", 1) + elif "?" in config["server"]: + base_url, params = config["server"].replace("jdbc:", "", 1).split("?", 1) + else: # There are no query params + base_url = config["server"].replace("jdbc:", "", 1) + params = None url_parts = urlparse(base_url) server = url_parts.hostname - query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param) - database = query_params.get("database", "") - config['server'] = server - config['database'] = database + config["server"] = server + + if params: + # Some JDBC connection strings separate params by semicolons + # while others separate by ampersands + if ";" in params: + query_param_sep = ";" + elif "&" in params: + query_param_sep = "&" + else: + raise ValueError("Unknown param separator in JDBC connection string.") + + for param in params.split(query_param_sep): + split_param = param.split("=", 1) + # Only pull out necessary connection params + if split_param[0] in {"database", "warehouse", "user"}: + config[split_param[0]] = split_param[1] return DatabaseManager(source, config) diff --git a/tests/integration/connections/test_snowflake_connector.py b/tests/integration/connections/test_snowflake_connector.py new file mode 100644 index 000000000..f2db0bc1a --- /dev/null +++ b/tests/integration/connections/test_snowflake_connector.py @@ -0,0 +1,25 @@ +import pytest + +from databricks.labs.remorph.connections.database_manager import SnowflakeConnector +from .helpers import get_db_manager + + +@pytest.fixture() +def db_manager(mock_credentials): + return get_db_manager("remorph", "snowflake") + + +def test_snowflake_connector_connection(db_manager): + assert isinstance(db_manager.connector, SnowflakeConnector) + + +def test_snowflake_connector_execute_query(db_manager): + # Execute a sample query + query = "SELECT 'Hello, World!' AS message" + result = db_manager.execute_query(query) + row = result.fetchone() + assert row[0] == "Hello, World!" + + +def test_connection_test(db_manager): + assert db_manager.check_connection()