Skip to content

Add Snowflake Connector Implementation #1559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ dependencies = [
"mypy~=1.10.0",
"numpy==1.26.4",
"pandas==1.4.1",
"snowflake-sqlalchemy>=1.7.3",
]

[project.entry-points.databricks]
Expand Down
28 changes: 27 additions & 1 deletion src/databricks/labs/remorph/connections/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from sqlalchemy import text
from sqlalchemy.exc import OperationalError

from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource

logger = logging.getLogger(__name__)
logger.setLevel("INFO")

Expand Down Expand Up @@ -56,7 +58,31 @@ def _create_connector(db_type: str, config: dict[str, Any]) -> DatabaseConnector

class SnowflakeConnector(_BaseConnector):
    """Connector that builds a SQLAlchemy engine for Snowflake from ``self.config``.

    Keys read from ``self.config``: ``server``, ``user``, ``database``,
    ``schema``, ``warehouse`` (all required) and ``pem_private_key`` (optional).
    """

    # Domain suffix carried by Snowflake hostnames; it must not be part of the
    # account identifier handed to the Snowflake SQLAlchemy URL helper.
    _SNOWFLAKE_DOMAIN = ".snowflakecomputing.com"

    @classmethod
    def _account_identifier(cls, server: str) -> str:
        """Derive the account identifier from the configured ``server`` value.

        A trailing ``.snowflakecomputing.com`` (matched case-insensitively) is
        removed, so ``acct.region.cloud.snowflakecomputing.com`` becomes
        ``acct.region.cloud`` and ``org-acct.snowflakecomputing.com`` becomes
        ``org-acct``. A value without that suffix is returned unchanged apart
        from surrounding whitespace and a trailing dot.

        Raises:
            ValueError: if nothing is left after stripping.
        """
        account = server.strip().rstrip(".")
        if account.lower().endswith(cls._SNOWFLAKE_DOMAIN):
            account = account[: -len(cls._SNOWFLAKE_DOMAIN)]
        if not account:
            raise ValueError(f"Cannot derive a Snowflake account identifier from server {server!r}")
        return account

    def _connect(self) -> Engine:
        """Create the SQLAlchemy engine for the configured Snowflake account.

        Returns:
            The engine produced by ``create_engine`` for the Snowflake URL.

        Raises:
            KeyError: if a required key is missing from ``self.config``.
            ValueError: if ``server`` yields an empty account identifier.
        """
        # pylint: disable=import-outside-toplevel
        import snowflake.sqlalchemy  # type: ignore

        # Snowflake does not follow a traditional SQL Alchemy connection string URL; they have their own.
        # e.g., connection_string = (f"snowflake://{user}:{pw}@{account}")
        # https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy
        sqlalchemy_driver = "snowflake"
        account = self._account_identifier(self.config["server"])
        # NOTE(review): `drivername` is kept from the original call; confirm that
        # snowflake.sqlalchemy.URL expects it rather than forwarding it as a query param.
        connection_string = snowflake.sqlalchemy.URL(
            drivername=sqlalchemy_driver,
            account=account,
            user=self.config["user"],
            database=self.config["database"],
            schema=self.config["schema"],
            warehouse=self.config["warehouse"],
        )

        # Users can optionally specify a private key to use
        conn_args = {}
        if "pem_private_key" in self.config:
            private_key_bytes = SnowflakeDataSource.get_private_key(self.config["pem_private_key"])
            conn_args = {"private_key": private_key_bytes}

        return create_engine(connection_string, connect_args=conn_args)


class MSSQLConnector(_BaseConnector):
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def mock_credentials():
'database': 'TEST_TSQL_JDBC',
'driver': 'ODBC Driver 18 for SQL Server',
},
'snowflake': {
'server': 'TEST_SNOWFLAKE_JDBC',
'pem_private_key': 'TEST_SNOWFLAKE_PRIVATE_KEY',
'database': 'TEST_SNOWFLAKE_DB',
'schema': 'TEST_SNOWFLAKE_SCHEMA',
},
},
):
yield
32 changes: 26 additions & 6 deletions tests/integration/connections/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,34 @@ def get_db_manager(product_name: str, source: str) -> DatabaseManager:
env = TestEnvGetter(True)
config = create_credential_manager(product_name, env).get_credentials(source)

# since the kv has only URL so added explicit parse rules
base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1)
# Some JDBC connection strings separate hostname and query params
# by a semicolon, while others use a question mark
if ";" in config["server"]:
base_url, params = config["server"].replace("jdbc:", "", 1).split(";", 1)
elif "?" in config["server"]:
base_url, params = config["server"].replace("jdbc:", "", 1).split("?", 1)
else: # There are no query params
base_url = config["server"].replace("jdbc:", "", 1)
params = None

url_parts = urlparse(base_url)
server = url_parts.hostname
query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param)
database = query_params.get("database", "")
config['server'] = server
config['database'] = database
config["server"] = server

if params:
# Some JDBC connection strings separate params by semicolons
# while others separate by ampersands
if ";" in params:
query_param_sep = ";"
elif "&" in params:
query_param_sep = "&"
else:
raise ValueError("Unknown param separator in JDBC connection string.")

for param in params.split(query_param_sep):
split_param = param.split("=", 1)
# Only pull out necessary connection params
if split_param[0] in {"database", "warehouse", "user"}:
config[split_param[0]] = split_param[1]

return DatabaseManager(source, config)
25 changes: 25 additions & 0 deletions tests/integration/connections/test_snowflake_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest

from databricks.labs.remorph.connections.database_manager import SnowflakeConnector
from .helpers import get_db_manager


@pytest.fixture()
def db_manager(mock_credentials):
    """Provide the manager built by ``get_db_manager`` for the "snowflake" source.

    ``mock_credentials`` is requested so that fixture is set up before the manager is built.
    """
    manager = get_db_manager("remorph", "snowflake")
    return manager


def test_snowflake_connector_connection(db_manager):
    """The manager's connector must be a ``SnowflakeConnector`` instance."""
    connector = db_manager.connector
    assert isinstance(connector, SnowflakeConnector)


def test_snowflake_connector_execute_query(db_manager):
    """Run a constant SELECT and check that its single value comes back."""
    # First column of the first row should be the literal that was selected
    first_row = db_manager.execute_query("SELECT 'Hello, World!' AS message").fetchone()
    assert first_row[0] == "Hello, World!"


def test_connection_test(db_manager):
    """``check_connection()`` must return a truthy value."""
    is_connected = db_manager.check_connection()
    assert is_connected
Loading