Support for complex params #30
Draft · wants to merge 4 commits into main
10 changes: 8 additions & 2 deletions src/databricks/sqlalchemy/__init__.py
@@ -1,4 +1,10 @@
 from databricks.sqlalchemy.base import DatabricksDialect
-from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
+from databricks.sqlalchemy._types import (
+    TINYINT,
+    TIMESTAMP,
+    TIMESTAMP_NTZ,
+    DatabricksArray,
+    DatabricksMap,
+)
 
-__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"]
+__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
44 changes: 44 additions & 0 deletions src/databricks/sqlalchemy/_types.py
@@ -5,6 +5,7 @@
 import sqlalchemy
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.types import TypeDecorator, UserDefinedType
 
 from databricks.sql.utils import ParamEscaper
 
@@ -321,3 +322,46 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
 @compiles(TINYINT, "databricks")
 def compile_tinyint(type_, compiler, **kw):
     return "TINYINT"
+
+
+class DatabricksArray(UserDefinedType):
+    """
+    A custom array type that can wrap any other SQLAlchemy type.
+
+    Examples:
+        DatabricksArray(String)     -> ARRAY<STRING>
+        DatabricksArray(Integer)    -> ARRAY<INT>
+        DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
+    """
+
+    def __init__(self, item_type):
+        # Accept either a type class (e.g. String) or an instance (e.g. String())
+        self.item_type = item_type() if isinstance(item_type, type) else item_type
+
+
+@compiles(DatabricksArray, "databricks")
+def compile_databricks_array(type_, compiler, **kw):
+    # Recursively compile the wrapped type, so arrays can nest arbitrarily
+    inner = compiler.process(type_.item_type, **kw)
+    return f"ARRAY<{inner}>"
+
+
+class DatabricksMap(UserDefinedType):
+    """
+    A custom map type that can wrap any other SQLAlchemy types for both key and value.
+
+    Examples:
+        DatabricksMap(String, String)                   -> MAP<STRING,STRING>
+        DatabricksMap(Integer, String)                  -> MAP<INT,STRING>
+        DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
+    """
+
+    def __init__(self, key_type, value_type):
+        # Accept either type classes or instances for both key and value
+        self.key_type = key_type() if isinstance(key_type, type) else key_type
+        self.value_type = value_type() if isinstance(value_type, type) else value_type
+
+
+@compiles(DatabricksMap, "databricks")
+def compile_databricks_map(type_, compiler, **kw):
+    key_type = compiler.process(type_.key_type, **kw)
+    value_type = compiler.process(type_.value_type, **kw)
+    return f"MAP<{key_type},{value_type}>"
20 changes: 19 additions & 1 deletion tests/test_local/test_ddl.py
@@ -1,12 +1,13 @@
 import pytest
-from sqlalchemy import Column, MetaData, String, Table, create_engine
+from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine
 from sqlalchemy.schema import (
     CreateTable,
     DropColumnComment,
     DropTableComment,
     SetColumnComment,
     SetTableComment,
 )
+from databricks.sqlalchemy import DatabricksArray, DatabricksMap
 
 
 class DDLTestBase:
@@ -94,3 +95,20 @@ def test_alter_table_drop_comment(self, table_with_comment):
         stmt = DropTableComment(table_with_comment)
         output = self.compile(stmt)
         assert output == "COMMENT ON TABLE martin IS NULL"
+
+
+class TestTableComplexTypeDDL(DDLTestBase):
+    @pytest.fixture(scope="class")
+    def metadata(self) -> MetaData:
+        metadata = MetaData()
+        col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
+        col2 = Column("map_string_string", DatabricksMap(String, String))
+        table = Table("complex_type", metadata, col1, col2)
+        return metadata
+
+    def test_create_table_with_complex_type(self, metadata):
+        stmt = CreateTable(metadata.tables["complex_type"])
+        output = self.compile(stmt)
+
+        assert "array_array_string ARRAY<ARRAY<STRING>>" in output
+        assert "map_string_string MAP<STRING,STRING>" in output
84 changes: 84 additions & 0 deletions tests/test_local/test_parsing.py
@@ -9,7 +9,28 @@
     get_comment_from_dte_output,
     DatabricksSqlAlchemyParseException,
 )
+from sqlalchemy import (
+    BigInteger,
+    Boolean,
+    Date,
+    DateTime,
+    Integer,
+    Numeric,
+    String,
+    Time,
+    Uuid,
+)
+
+from databricks.sqlalchemy import (
+    DatabricksArray,
+    TIMESTAMP,
+    TINYINT,
+    DatabricksMap,
+    TIMESTAMP_NTZ,
+)
+from databricks.sqlalchemy import DatabricksDialect
 
+dialect = DatabricksDialect()
 
 # These are outputs from DESCRIBE TABLE EXTENDED
 @pytest.mark.parametrize(
@@ -158,3 +179,66 @@ def test_filter_dict_by_value(match, output):

 def test_get_comment_from_dte_output():
     assert get_comment_from_dte_output(FMT_SAMPLE_DT_OUTPUT) == "some comment"
+
+
+def get_databricks_non_compound_types():
+    return [
+        Integer,
+        String,
+        Boolean,
+        Date,
+        DateTime,
+        Time,
+        Uuid,
+        Numeric,
+        TINYINT,
+        TIMESTAMP,
+        TIMESTAMP_NTZ,
+        BigInteger,
+    ]
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_array_parsing(internal_type):
+    array_type = DatabricksArray(internal_type())
+
+    actual_parsed = array_type.compile(dialect=dialect)
+    expected_parsed = "ARRAY<{}>".format(internal_type().compile(dialect=dialect))
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type_1", get_databricks_non_compound_types())
+@pytest.mark.parametrize("internal_type_2", get_databricks_non_compound_types())
+def test_map_parsing(internal_type_1, internal_type_2):
+    map_type = DatabricksMap(internal_type_1(), internal_type_2())
+
+    actual_parsed = map_type.compile(dialect=dialect)
+    expected_parsed = "MAP<{},{}>".format(
+        internal_type_1().compile(dialect=dialect),
+        internal_type_2().compile(dialect=dialect),
+    )
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_multilevel_array_type_parsing(internal_type):
+    array_type = DatabricksArray(DatabricksArray(DatabricksArray(internal_type())))
+
+    actual_parsed = array_type.compile(dialect=dialect)
+    expected_parsed = "ARRAY<ARRAY<ARRAY<{}>>>".format(
+        internal_type().compile(dialect=dialect)
+    )
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_multilevel_map_type_parsing(internal_type):
+    map_type = DatabricksMap(
+        String, DatabricksMap(String, DatabricksMap(String, internal_type()))
+    )
+
+    actual_parsed = map_type.compile(dialect=dialect)
+    expected_parsed = "MAP<STRING,MAP<STRING,MAP<STRING,{}>>>".format(
+        internal_type().compile(dialect=dialect)
+    )
+    assert actual_parsed == expected_parsed
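
To make the parametrization concrete, here is one expanded instance of test_array_parsing (assuming TINYINT compiles to TINYINT under this dialect):

def test_array_parsing_tinyint_expanded():
    # Equivalent to test_array_parsing with internal_type=TINYINT
    array_type = DatabricksArray(TINYINT())
    assert array_type.compile(dialect=dialect) == "ARRAY<TINYINT>"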