Skip to content
This repository was archived by the owner on May 5, 2022. It is now read-only.

Commit

Permalink
test: add datatype.parse_sqltype testcases
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <[email protected]>
  • Loading branch information
dungdm93 committed Jan 24, 2021
1 parent b8e0dcd commit 565ef0f
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 14 deletions.
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import tests.assertions # noqa
59 changes: 45 additions & 14 deletions sqlalchemy_trino/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,17 @@

# === Date and time ===
'date': sqltypes.DATE,
'time': sqltypes.Time,
'time with time zone': sqltypes.Time,
'time': sqltypes.TIME,
'timestamp': sqltypes.TIMESTAMP,
'timestamp with time zone': sqltypes.TIMESTAMP,

# 'interval year to month': IntervalOfYear, # TODO
'interval day to second': sqltypes.Interval,

# 'interval year to month':
# 'interval day to second':
#
# === Structural ===
'array': sqltypes.ARRAY,
# 'map': MAP
# 'row': ROW

# 'array': ARRAY,
# 'map': MAP
# 'row': ROW
#
# === Mixed ===
# 'ipaddress': IPADDRESS
# 'uuid': UUID,
Expand All @@ -53,13 +51,39 @@
# 'tdigest': TDIGEST,
}

SQLType = Union[TypeEngine, Type[TypeEngine]]


class MAP(TypeEngine):
pass
__visit_name__ = "MAP"

def __init__(self, key_type: SQLType, value_type: SQLType):
if isinstance(key_type, type):
key_type = key_type()
self.key_type: TypeEngine = key_type

if isinstance(value_type, type):
value_type = value_type()
self.value_type: TypeEngine = value_type

@property
def python_type(self):
return dict


class ROW(TypeEngine):
pass
__visit_name__ = "ROW"

def __init__(self, attr_types: Dict[str, SQLType]):
for name, attr_type in attr_types.items():
if isinstance(attr_type, type):
attr_type = attr_type()
attr_types[name] = attr_type
self.attr_types: Dict[str, TypeEngine] = attr_types

@property
def python_type(self):
return dict


def split(string: str, delimiter: str = ',',
Expand Down Expand Up @@ -106,15 +130,22 @@ def parse_sqltype(type_str: str) -> TypeEngine:

if type_name == "array":
item_type = parse_sqltype(type_opts)
if isinstance(item_type, sqltypes.ARRAY):
dimensions = (item_type.dimensions or 1) + 1
return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions)
return sqltypes.ARRAY(item_type)
elif type_name == "map":
key_type_str, value_type_str = split(type_opts)
key_type = parse_sqltype(key_type_str)
value_type = parse_sqltype(value_type_str)
return MAP(key_type, value_type)
elif type_name == "row":
attr_types = split(type_opts)
return ROW() # TODO
attr_types: Dict[str, SQLType] = {}
for attr_str in split(type_opts):
name, attr_type_str = split(attr_str.strip(), delimiter=' ')
attr_type = parse_sqltype(attr_type_str)
attr_types[name] = attr_type
return ROW(attr_types)

if type_name not in _type_map:
util.warn(f"Did not recognize type '{type_name}'")
Expand Down
34 changes: 34 additions & 0 deletions tests/assertions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from assertpy import add_extension, assert_that
from sqlalchemy.sql.sqltypes import ARRAY

from sqlalchemy_trino.datatype import SQLType, MAP, ROW


def assert_sqltype(this: SQLType, that: SQLType):
if isinstance(this, type):
this = this()
if isinstance(that, type):
that = that()
assert_that(type(this)).is_same_as(type(that))
if isinstance(this, ARRAY):
assert_sqltype(this.item_type, that.item_type)
if this.dimensions is None or this.dimensions == 1:
assert_that(that.dimensions).is_in(None, 1)
else:
assert_that(this.dimensions).is_equal_to(this.dimensions)
elif isinstance(this, MAP):
assert_sqltype(this.key_type, that.key_type)
assert_sqltype(this.value_type, that.value_type)
elif isinstance(this, ROW):
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
for name, this_attr in this.attr_types.items():
that_attr = this.attr_types[name]
assert_sqltype(this_attr, that_attr)
else:
assert_that(str(this)).is_equal_to(str(that))


@add_extension
def is_sqltype(self, that):
this = self.val
assert_sqltype(this, that)
111 changes: 111 additions & 0 deletions tests/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import pytest
from assertpy import assert_that
from sqlalchemy.sql.sqltypes import *
from sqlalchemy.sql.type_api import TypeEngine

from sqlalchemy_trino import datatype
from sqlalchemy_trino.datatype import MAP, ROW


@pytest.mark.parametrize(
'type_str, sql_type',
datatype._type_map.items(),
ids=datatype._type_map.keys()
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
if not isinstance(actual_type, type):
actual_type = type(actual_type)
assert_that(actual_type).is_equal_to(sql_type)


parse_type_options_testcases = {
'VARCHAR(10)': VARCHAR(10),
'DECIMAL(20)': DECIMAL(20),
'DECIMAL(20, 3)': DECIMAL(20, 3),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_type_options_testcases.items(),
ids=parse_type_options_testcases.keys()
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_array_testcases = {
'array(integer)': ARRAY(INTEGER()),
'array(varchar(10))': ARRAY(VARCHAR(10)),
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_array_testcases.items(),
ids=parse_array_testcases.keys()
)
def test_parse_array(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_map_testcases = {
'map(char, integer)': MAP(CHAR(), INTEGER()),
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_map_testcases.items(),
ids=parse_map_testcases.keys()
)
def test_parse_map(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_row_testcases = {
'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())),
'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))),
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_row_testcases.items(),
ids=parse_row_testcases.keys()
)
def test_parse_row(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_datetime_testcases = {
'date': DATE(),
'time': TIME(),
'time with time zone': TIME(timezone=True),
'timestamp': TIMESTAMP(),
'timestamp with time zone': TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys()
)
def test_parse_datetime(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)

0 comments on commit 565ef0f

Please sign in to comment.