Skip to content

Commit

Permalink
Native Vector support
Browse files Browse the repository at this point in the history
  • Loading branch information
daimor committed Feb 13, 2024
1 parent 122c1b6 commit 55b7c01
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 4 deletions.
2 changes: 2 additions & 0 deletions sqlalchemy_iris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .base import VARBINARY
from .base import VARCHAR
from .base import IRISListBuild
from .base import IRISVector

base.dialect = dialect = iris.dialect

Expand All @@ -47,5 +48,6 @@
"VARBINARY",
"VARCHAR",
"IRISListBuild",
"IRISVector",
"dialect",
]
23 changes: 21 additions & 2 deletions sqlalchemy_iris/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import intersystems_iris.dbapi._DBAPI as dbapi
import intersystems_iris._IRISNative as IRISNative
from . import information_schema as ischema
from sqlalchemy import exc
from sqlalchemy.orm import aliased
Expand Down Expand Up @@ -91,7 +92,8 @@ def check_constraints(cls):
from .types import IRISDate
from .types import IRISDateTime
from .types import IRISUniqueIdentifier
from .types import IRISListBuild
from .types import IRISListBuild # noqa
from .types import IRISVector # noqa


ischema_names = {
Expand Down Expand Up @@ -398,7 +400,9 @@ def check_constraints(cls):
class IRISCompiler(sql.compiler.SQLCompiler):
"""IRIS specific idiosyncrasies"""

def visit_exists_unary_operator(self, element, operator, within_columns_clause=False, **kw):
def visit_exists_unary_operator(
self, element, operator, within_columns_clause=False, **kw
):
if within_columns_clause:
return "(SELECT 1 WHERE EXISTS(%s))" % self.process(element.element, **kw)
else:
Expand Down Expand Up @@ -853,6 +857,8 @@ class IRISDialect(default.DefaultDialect):
supports_empty_insert = False
supports_is_distinct_from = False

supports_vectors = None

colspecs = colspecs

ischema_names = ischema_names
Expand All @@ -870,6 +876,11 @@ class IRISDialect(default.DefaultDialect):
def __init__(self, **kwargs):
default.DefaultDialect.__init__(self, **kwargs)

def _get_server_version_info(self, connection):
server_version = connection.connection._connection_info._server_version
server_version = server_version[server_version.find("Version") + 8:].split(" ")[0].split(".")
return tuple([int(''.join(filter(str.isdigit, v))) for v in server_version])

_isolation_lookup = set(
[
"READ UNCOMMITTED",
Expand All @@ -888,6 +899,14 @@ def on_connect(conn):
if super_ is not None:
super_(conn)

iris = IRISNative.createIRIS(conn)
self.supports_vectors = iris.classMethodBoolean("%SYSTEM.License", "GetFeature", 28)
if self.supports_vectors:
with conn.cursor() as cursor:
# Distance or similarity
cursor.execute("select vector_cosine(to_vector('1'), to_vector('1'))")
self.vector_cosine_similarity = cursor.fetchone()[0] == 0

self._dictionary_access = False
with conn.cursor() as cursor:
cursor.execute("%CHECKPRIV SELECT ON %Dictionary.PropertyDefinition")
Expand Down
12 changes: 12 additions & 0 deletions sqlalchemy_iris/requirements.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from sqlalchemy.testing.requirements import SuiteRequirements
from sqlalchemy.testing.exclusions import against
from sqlalchemy.testing.exclusions import only_on

try:
from alembic.testing.requirements import SuiteRequirements as AlembicRequirements
Expand Down Expand Up @@ -257,3 +259,13 @@ def fk_onupdate_restrict(self):
@property
def fk_ondelete_restrict(self):
return exclusions.closed()

def _iris_vector(self, config):
if not against(config, "iris >= 2024.1"):
return False
else:
return config.db.dialect.supports_vectors

@property
def iris_vector(self):
return only_on(lambda config: self._iris_vector(config))
70 changes: 68 additions & 2 deletions sqlalchemy_iris/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datetime
from decimal import Decimal
from sqlalchemy import func
from sqlalchemy import func, text
from sqlalchemy.sql import sqltypes
from sqlalchemy.types import UserDefinedType
from sqlalchemy.types import UserDefinedType, Float
from uuid import UUID as _python_UUID
from intersystems_iris import IRISList

Expand Down Expand Up @@ -247,6 +247,72 @@ def func(self, funcname: str, other):
return getattr(func, funcname)(self, irislist.getBuffer())


class IRISVector(UserDefinedType):
cache_ok = True

def __init__(self, max_items: int = None, item_type: type = float):
super(UserDefinedType, self).__init__()
if item_type not in [float, int, Decimal]:
raise TypeError(
f"IRISVector expected int, float or Decimal; got {type.__name__}; expected: int, float, Decimal"
)
self.max_items = max_items
self.item_type = item_type
item_type_server = (
"decimal"
if self.item_type is float
else "float"
if self.item_type is Decimal
else "int"
)
self.item_type_server = item_type_server

def get_col_spec(self, **kw):
if self.max_items is None and self.item_type is None:
return "VECTOR"
len = str(self.max_items or "")
return f"VECTOR({self.item_type_server}, {len})"

def bind_processor(self, dialect):
def process(value):
if not value:
return value
if not isinstance(value, list) and not isinstance(value, tuple):
raise ValueError("expected list or tuple, got '%s'" % type(value))
return f"[{','.join([str(v) for v in value])}]"

return process

def result_processor(self, dialect, coltype):
def process(value):
if not value:
return value
vals = value.split(",")
vals = [self.item_type(v) for v in vals]
return vals

return process

class comparator_factory(UserDefinedType.Comparator):
# def l2_distance(self, other):
# return self.func('vector_l2', other)

def max_inner_product(self, other):
return self.func('vector_dot_product', other)

def cosine_distance(self, other):
return self.func('vector_cosine', other)

def cosine(self, other):
return (1 - self.func('vector_cosine', other))

def func(self, funcname: str, other):
if not isinstance(other, list) and not isinstance(other, tuple):
raise ValueError("expected list or tuple, got '%s'" % type(other))
othervalue = f"[{','.join([str(v) for v in other])}]"
return getattr(func, funcname)(self, func.to_vector(othervalue, text(self.type.item_type_server)))


class BIT(sqltypes.TypeEngine):
__visit_name__ = "BIT"

Expand Down
96 changes: 96 additions & 0 deletions tests/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from sqlalchemy.types import VARBINARY
from sqlalchemy.types import BINARY
from sqlalchemy_iris import TINYINT
from sqlalchemy_iris import INTEGER
from sqlalchemy_iris import IRISListBuild
from sqlalchemy_iris import IRISVector
from sqlalchemy.exc import DatabaseError
import pytest

Expand Down Expand Up @@ -337,3 +339,97 @@ def test_listbuild(self):
([1.0] * 50, 1),
],
)


class IRISVectorTest(fixtures.TablesTest):
__backend__ = True

__requires__ = ("iris_vector",)

@classmethod
def define_tables(cls, metadata):
Table(
"data",
metadata,
Column("id", INTEGER),
Column("emb", IRISVector(3, float)),
)

@classmethod
def fixtures(cls):
return dict(
data=(
(
"id",
"emb",
),
(
1,
[1, 1, 1],
),
(
2,
[2, 2, 2],
),
(
3,
[1, 1, 2],
),
)
)

def _assert_result(self, select, result):
with config.db.connect() as conn:
eq_(conn.execute(select).fetchall(), result)

def test_vector(self):
self._assert_result(
select(self.tables.data.c.emb),
[
([1, 1, 1],),
([2, 2, 2],),
([1, 1, 2],),
],
)
self._assert_result(
select(self.tables.data.c.id).where(self.tables.data.c.emb == [2, 2, 2]),
[
(2,),
],
)

def test_cosine(self):
self._assert_result(
select(
self.tables.data.c.id,
).order_by(self.tables.data.c.emb.cosine([1, 1, 1])),
[
(1,),
(2,),
(3,),
],
)

def test_cosine_distance(self):
self._assert_result(
select(
self.tables.data.c.id,
).order_by(1 - self.tables.data.c.emb.cosine_distance([1, 1, 1])),
[
(1,),
(2,),
(3,),
],
)

def test_max_inner_product(self):
self._assert_result(
select(
self.tables.data.c.id,
).order_by(self.tables.data.c.emb.max_inner_product([1, 1, 1])),
[
(1,),
(3,),
(2,),
],
)

0 comments on commit 55b7c01

Please sign in to comment.