Skip to content

Commit

Permalink
IRIS ListBuild as a Column type
Browse files Browse the repository at this point in the history
  • Loading branch information
daimor committed Dec 20, 2023
1 parent f471175 commit 6732a07
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sqlalchemy_iris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .base import TINYINT
from .base import VARBINARY
from .base import VARCHAR
from .base import IRISListBuild

base.dialect = dialect = iris.dialect

Expand All @@ -45,5 +46,6 @@
"TINYINT",
"VARBINARY",
"VARCHAR",
"IRISListBuild",
"dialect",
]
1 change: 1 addition & 0 deletions sqlalchemy_iris/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def check_constraints(cls):
from .types import IRISDate
from .types import IRISDateTime
from .types import IRISUniqueIdentifier
from .types import IRISListBuild


ischema_names = {
Expand Down
47 changes: 47 additions & 0 deletions sqlalchemy_iris/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime
from decimal import Decimal
from sqlalchemy.sql import sqltypes
from sqlalchemy.types import UserDefinedType
from uuid import UUID as _python_UUID
from intersystems_iris import IRISList

HOROLOG_ORDINAL = datetime.date(1840, 12, 31).toordinal()

Expand Down Expand Up @@ -194,6 +196,47 @@ def process(value):
return None


class IRISListBuild(UserDefinedType):
cache_ok = True

def __init__(self, max_items: int = None, item_type: type = float):
super(UserDefinedType, self).__init__()
self.max_items = max_items
max_length = None
if type is float or type is int:
max_length = max_items * 10
elif max_items:
max_length = 65535
self.max_length = max_length

def get_col_spec(self, **kw):
if self.max_length is None:
return "VARBINARY(65535)"
return "VARBINARY(%d)" % self.max_length

def bind_processor(self, dialect):
def process(value):
irislist = IRISList()
if not value:
return value
if not isinstance(value, list) and not isinstance(value, tuple):
raise ValueError("expected list or tuple, got '%s'" % type(value))
for item in value:
irislist.add(item)
return irislist.getBuffer()

return process

def result_processor(self, dialect, coltype):
def process(value):
if value:
irislist = IRISList(value)
return irislist._list_data
return value

return process


class BIT(sqltypes.TypeEngine):
__visit_name__ = "BIT"

Expand All @@ -212,3 +255,7 @@ class LONGVARCHAR(sqltypes.VARCHAR):

class LONGVARBINARY(sqltypes.VARBINARY):
__visit_name__ = "LONGVARBINARY"


class LISTBUILD(sqltypes.VARBINARY):
__visit_name__ = "VARCHAR"
40 changes: 40 additions & 0 deletions tests/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlalchemy.types import VARBINARY
from sqlalchemy.types import BINARY
from sqlalchemy_iris import TINYINT
from sqlalchemy_iris import IRISListBuild
from sqlalchemy.exc import DatabaseError
import pytest

Expand Down Expand Up @@ -281,3 +282,42 @@ class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest):
)
def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname):
super().test_fk_ref(connection, metadata, use_composite, tablename, columnname)


class IRISListBuildTest(fixtures.TablesTest):
__backend__ = True

@classmethod
def define_tables(cls, metadata):
Table(
"data",
metadata,
Column("val", IRISListBuild(10, float)),
)

@classmethod
def fixtures(cls):
return dict(
data=(
("val",),
([1.0] * 50,),
([1.23] * 50,),
([i for i in range(0, 50)],),
(None,),
)
)

def _assert_result(self, select, result):
with config.db.connect() as conn:
eq_(conn.execute(select).fetchall(), result)

def test_listbuild(self):
self._assert_result(
select(self.tables.data),
[
([1.0] * 50,),
([1.23] * 50,),
([i for i in range(0, 50)],),
(None,),
],
)

0 comments on commit 6732a07

Please sign in to comment.