Fix types with stubs
Askir committed Oct 24, 2024
1 parent 1df8ac8 commit 08828ff
Showing 22 changed files with 322 additions and 61 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -45,6 +45,10 @@ strict = true
ignore_missing_imports = true
namespace_packages = true

[tool.pyright]
typeCheckingMode = "strict"
stubPath = "timescale_vector/typings"

[tool.ruff]
line-length = 120
indent-width = 4
@@ -105,12 +109,12 @@ select = [
[tool.uv]
dev-dependencies = [
"mypy>=1.12.0",
"types-psycopg2>=2.9.21.20240819",
"ruff>=0.6.9",
"pytest>=8.3.3",
"langchain>=0.3.3",
"langchain-openai>=0.2.2",
"langchain-community>=0.3.2",
"pandas>=2.2.3",
"pytest-asyncio>=0.24.0",
"pyright>=1.1.386",
]
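
The new [tool.pyright] table drives most of the annotation changes in the tests and client code below: in strict mode, pyright cannot infer an element type for an empty dict or list literal, so the container has to be annotated at the assignment. A minimal standalone illustration of the rule (variable names invented for this sketch, not taken from the client):

    from datetime import datetime

    # Without the annotation, pyright's strict mode reports the dict's value type as partially unknown.
    search_filter: dict[str, str | datetime] = {}
    search_filter["__start_date"] = datetime(2024, 1, 1)  # OK: value matches the declared union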
4 changes: 2 additions & 2 deletions tests/async_client_test.py
@@ -322,7 +322,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st
assert len(rec) == expected

# using filters
filter = {}
filter: dict[str, str | datetime] = {}
if start_date is not None:
filter["__start_date"] = start_date
if end_date is not None:
@@ -338,7 +338,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st
rec = await vec.search([1.0, 2.0], limit=4, filter=filter)
assert len(rec) == expected
# using predicates
predicates = []
predicates: list[tuple[str, str, str|datetime]] = []
if start_date is not None:
predicates.append(("__uuid_timestamp", ">=", start_date))
if end_date is not None:
4 changes: 2 additions & 2 deletions tests/pg_vectorizer_test.py
@@ -17,7 +17,7 @@ def get_document(blog: dict[str, Any]) -> list[Document]:
chunk_size=1000,
chunk_overlap=200,
)
docs = []
docs: list[Document] = []
for chunk in text_splitter.split_text(blog["contents"]):
content = f"Author {blog['author']}, title: {blog['title']}, contents:{chunk}"
metadata = {
@@ -71,7 +71,7 @@ def embed_and_write(blog_instances: list[Any], vectorizer: Vectorize) -> None:
metadata_for_delete = [{"blog_id": blog["locked_id"]} for blog in blog_instances]
vector_store.delete_by_metadata(metadata_for_delete)

documents = []
documents: list[Document] = []
for blog in blog_instances:
# skip blogs that are not published yet, or are deleted (will be None because of left join)
if blog["published_time"] is not None:
4 changes: 2 additions & 2 deletions tests/sync_client_test.py
@@ -234,7 +234,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
assert len(rec) == expected

# using filters
filter = {}
filter: dict[str, str|datetime] = {}
if start_date is not None:
filter["__start_date"] = start_date
if end_date is not None:
@@ -250,7 +250,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
rec = vec.search([1.0, 2.0], limit=4, filter=filter)
assert len(rec) == expected
# using predicates
predicates = []
predicates: list[tuple[str, str, str|datetime]] = []
if start_date is not None:
predicates.append(("__uuid_timestamp", ">=", start_date))
if end_date is not None:
34 changes: 19 additions & 15 deletions timescale_vector/client/async_client.py
@@ -1,12 +1,12 @@
import json
import uuid
from collections.abc import Iterable, Mapping
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Any, Literal
from typing import Any, Literal, cast

from asyncpg import Connection, Pool, Record, connect, create_pool
from asyncpg.pool import PoolAcquireContext
from pgvector.asyncpg import register_vector
from pgvector.asyncpg import register_vector # type: ignore

from timescale_vector.client.index import BaseIndex, QueryParams
from timescale_vector.client.predicates import Predicates
@@ -77,7 +77,7 @@ async def _default_max_db_connections(self) -> int:
await conn.close()
if num_connections is None:
return 10
return num_connections # type: ignore
return cast(int, num_connections)

async def connect(self) -> PoolAcquireContext:
"""
@@ -94,7 +94,12 @@ async def connect(self) -> PoolAcquireContext:
async def init(conn: Connection) -> None:
await register_vector(conn)
# decode to a dict, but accept a string as input in upsert
await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog")
await conn.set_type_codec(
"jsonb",
encoder=str,
decoder=json.loads,
schema="pg_catalog"
)

self.pool = await create_pool(
dsn=self.service_url,
@@ -122,12 +127,12 @@ async def table_is_empty(self) -> bool:
rec = await pool.fetchrow(query)
return rec is None

def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]:

def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]:
metadata_is_dict = isinstance(records[0][1], dict)
if metadata_is_dict:
munged_records = map(lambda item: Async._convert_record_meta_to_json(item), records)

return munged_records if metadata_is_dict else records
return list(map(lambda item: Async._convert_record_meta_to_json(item), records))
return records

@staticmethod
def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]:
@@ -188,15 +193,15 @@ async def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> list[Record]:
"""
(query, params) = self.builder.delete_by_ids_query(ids)
async with await self.connect() as pool:
return await pool.fetch(query, *params) # type: ignore
return await pool.fetch(query, *params)

async def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]) -> list[Record]:
"""
Delete records by metadata filters.
"""
(query, params) = self.builder.delete_by_metadata_query(filter)
async with await self.connect() as pool:
return await pool.fetch(query, *params) # type: ignore
return await pool.fetch(query, *params)

async def drop_table(self) -> None:
"""
@@ -221,7 +226,7 @@ async def _get_approx_count(self) -> int:
query = self.builder.get_approx_count_query()
async with await self.connect() as pool:
rec = await pool.fetchrow(query)
return rec[0] if rec is not None else 0
return cast(int, rec[0] if rec is not None else 0)

async def drop_embedding_index(self) -> None:
"""
@@ -248,7 +253,6 @@ async def create_embedding_index(self, index: BaseIndex) -> None:
-------
None
"""
# todo: can we make geting the records lazy?
num_records = await self._get_approx_count()
query = self.builder.create_embedding_index_query(index, lambda: num_records)

@@ -294,7 +298,7 @@ async def search(
statements = query_params.get_statements()
for statement in statements:
await pool.execute(statement)
return await pool.fetch(query, *params) # type: ignore
return await pool.fetch(query, *params)
else:
async with await self.connect() as pool:
return await pool.fetch(query, *params) # type: ignore
return await pool.fetch(query, *params)
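
Where asyncpg reports a looser type than the code actually relies on, the pattern above swaps `# type: ignore` for `typing.cast`, which records the assumed concrete type instead of silencing the checker. A rough standalone sketch of the same idea (function and parameter types invented, not the client's actual code):

    from collections.abc import Sequence
    from typing import Any, cast

    def unwrap_count(row: Sequence[Any] | None) -> int:
        # The first column is known to hold an integer count at runtime;
        # cast() documents that assumption for the type checker.
        return cast(int, row[0] if row is not None else 0)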
14 changes: 4 additions & 10 deletions timescale_vector/client/predicates.py
@@ -21,7 +21,7 @@ class Predicates:
"@>": "@>", # array contains
}

PredicateValue = str | int | float | datetime | list | tuple # type: ignore
PredicateValue = str | int | float | datetime | list[Any] | tuple[Any]

def __init__(
self,
@@ -53,13 +53,7 @@ def __init__(
if isinstance(clauses[0], str):
if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):
raise ValueError(f"Invalid clause format: {clauses}")
self.clauses: list[
Predicates
| tuple[str, Predicates.PredicateValue]
| tuple[str, str, Predicates.PredicateValue]
| str
| Predicates.PredicateValue
] = [clauses]
self.clauses = [clauses]
else:
self.clauses = list(clauses)

@@ -85,9 +79,9 @@ def add_clause(
if isinstance(clause[0], str):
if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
raise ValueError(f"Invalid clause format: {clause}")
self.clauses.append(clause)
self.clauses.append(clause) # type: ignore
else:
self.clauses.extend(list(clause))
self.clauses.extend(list(clause)) # type: ignore

def __and__(self, other: "Predicates") -> "Predicates":
new_predicates = Predicates(self, other, operator="AND")
5 changes: 3 additions & 2 deletions timescale_vector/client/query_builder.py
@@ -1,3 +1,4 @@
# pyright: reportPrivateUsage=false
import json
import uuid
from collections.abc import Callable, Mapping
@@ -261,7 +262,7 @@ def _where_clause_for_filter(
json_object = json.dumps(filter)
params = params + [json_object]
elif isinstance(filter, list):
any_params = []
any_params: list[str] = []
for _idx, filter_dict in enumerate(filter, start=len(params) + 1):
any_params.append(json.dumps(filter_dict))
where = f"metadata @> ANY(${len(params) + 1}::jsonb[])"
@@ -310,7 +311,7 @@ def search_query(
if end_date is not None:
del filter["__end_date"]

where_clauses = []
where_clauses: list[str] = []
if filter is not None:
(where_filter, params) = self._where_clause_for_filter(params, filter)
where_clauses.append(where_filter)
32 changes: 17 additions & 15 deletions timescale_vector/client/sync_client.py
@@ -1,16 +1,18 @@
import json
import re
import uuid
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Literal

import numpy as np
import pgvector.psycopg2
import psycopg2.extras
import psycopg2.pool
from numpy import ndarray
from pgvector.psycopg2 import register_vector # type: ignore
from psycopg2 import connect
from psycopg2.extensions import connection as PSYConnection
from psycopg2.extras import DictCursor, register_uuid
from psycopg2.pool import SimpleConnectionPool

from timescale_vector.client.index import BaseIndex, QueryParams
from timescale_vector.client.predicates import Predicates
@@ -65,25 +67,25 @@ def __init__(
schema_name,
)
self.service_url: str = service_url
self.pool: psycopg2.pool.SimpleConnectionPool | None = None
self.pool: SimpleConnectionPool | None = None
self.max_db_connections: int | None = max_db_connections
self.time_partition_interval: timedelta | None = time_partition_interval
psycopg2.extras.register_uuid()
register_uuid()

def default_max_db_connections(self) -> int:
"""
Gets a default value for the number of max db connections to use.
"""
query = self.builder.default_max_db_connection_query()
conn = psycopg2.connect(dsn=self.service_url)
conn = connect(dsn=self.service_url)
with conn.cursor() as cur:
cur.execute(query)
num_connections = cur.fetchone()
conn.close()
return num_connections[0] # type: ignore

@contextmanager
def connect(self) -> Iterator[psycopg2.extensions.connection]:
def connect(self) -> Iterator[PSYConnection]:
"""
Establishes a connection to a PostgreSQL database using psycopg2 and allows its
use in a context manager.
@@ -92,15 +94,15 @@ def connect(self) -> Iterator[PSYConnection]:
if self.max_db_connections is None:
self.max_db_connections = self.default_max_db_connections()

self.pool = psycopg2.pool.SimpleConnectionPool(
self.pool = SimpleConnectionPool(
1,
self.max_db_connections,
dsn=self.service_url,
cursor_factory=psycopg2.extras.DictCursor,
cursor_factory=DictCursor,
)

connection = self.pool.getconn()
pgvector.psycopg2.register_vector(connection)
register_vector(connection)
try:
yield connection
connection.commit()
@@ -157,12 +159,12 @@ def table_is_empty(self) -> bool:
rec = cur.fetchone()
return rec is None

def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]:
def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]:
metadata_is_dict = isinstance(records[0][1], dict)
if metadata_is_dict:
munged_records = map(lambda item: Sync._convert_record_meta_to_json(item), records)
return list(map(lambda item: Sync._convert_record_meta_to_json(item), records))

return munged_records if metadata_is_dict else records
return records

@staticmethod
def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]:
@@ -200,7 +202,7 @@ def create_tables(self) -> None:
query = self.builder.get_create_query()
# don't use a connection pool for this because the vector extension may not be installed yet
# and if it's not installed, register_vector will fail.
conn = psycopg2.connect(dsn=self.service_url)
conn = connect(dsn=self.service_url)
with conn.cursor() as cur:
cur.execute(query)
conn.commit()
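The import reshuffle above pairs with the types-psycopg2 stubs added to the dev dependencies: binding connect, DictCursor, and SimpleConnectionPool by name gives pyright concrete classes to check annotations against, instead of attribute lookups on the psycopg2 package. A trimmed sketch of the resulting pattern (DSN handling invented for illustration, not the client's code):

    from psycopg2.extras import DictCursor
    from psycopg2.pool import SimpleConnectionPool

    def make_pool(dsn: str, max_connections: int) -> SimpleConnectionPool:
        # Rows come back as dict-like objects thanks to the DictCursor factory.
        return SimpleConnectionPool(1, max_connections, dsn=dsn, cursor_factory=DictCursor)
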
2 changes: 1 addition & 1 deletion timescale_vector/client/utils.py
@@ -31,7 +31,7 @@ def uuid_from_time(
"""
if time_arg is None:
return uuid.uuid1(node, clock_seq)
if hasattr(time_arg, "utctimetuple"):
if isinstance(time_arg, datetime):
# this is different from the Cassandra version,
# we assume that a naive datetime is in system time and convert it to UTC
# we do this because naive datetimes are interpreted as timestamps (without timezone) in postgres
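The hasattr-to-isinstance swap in uuid_from_time is not only stylistic: isinstance is a narrowing construct, so after the check both pyright and mypy treat time_arg as a datetime and accept datetime-only attribute access, which a hasattr test cannot guarantee. A minimal standalone sketch of the difference (function name invented):

    from datetime import datetime, timezone

    def to_utc(time_arg: datetime | float) -> datetime:
        if isinstance(time_arg, datetime):
            # Narrowed to datetime: .astimezone() type-checks here.
            return time_arg.astimezone(timezone.utc)
        # Narrowed to float: treat the value as a Unix timestamp.
        return datetime.fromtimestamp(time_arg, tz=timezone.utc)
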
1 change: 1 addition & 0 deletions timescale_vector/pgvectorizer.py
@@ -1,3 +1,4 @@
# pyright: reportPrivateUsage=false
__all__ = ["Vectorize"]

import re
Empty file.
46 changes: 46 additions & 0 deletions timescale_vector/typings/asyncpg/__init__.pyi
@@ -0,0 +1,46 @@
from typing import Any, Protocol, TypeVar, Sequence
from . import pool, connection

# Core types
T = TypeVar('T')

class Record(Protocol):
def __getitem__(self, key: int | str) -> Any: ...
def __iter__(self) -> Any: ...
def __len__(self) -> int: ...
def get(self, key: str, default: T = None) -> T | None: ...
def keys(self) -> Sequence[str]: ...
def values(self) -> Sequence[Any]: ...
def items(self) -> Sequence[tuple[str, Any]]: ...

# Allow dictionary-style access to fields
def __getattr__(self, name: str) -> Any: ...

# Re-exports
Connection = connection.Connection
Pool = pool.Pool
Record = Record

# Functions
async def connect(
dsn: str | None = None,
*,
host: str | None = None,
port: int | None = None,
user: str | None = None,
password: str | None = None,
database: str | None = None,
timeout: int = 60
) -> Connection: ...

async def create_pool(
dsn: str | None = None,
*,
min_size: int = 10,
max_size: int = 10,
max_queries: int = 50000,
max_inactive_connection_lifetime: float = 300.0,
setup: Any | None = None,
init: Any | None = None,
**connect_kwargs: Any
) -> Pool: ...
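
Because stubPath points at timescale_vector/typings, pyright resolves this .pyi ahead of the installed asyncpg package; the Connection and Pool details live in the accompanying connection and pool stubs, which are among the 22 changed files but not shown in this excerpt. A small usage sketch relying only on the Protocol members declared above (function name invented):

    from asyncpg import Record

    def record_to_dict(rec: Record) -> dict[str, object]:
        # Uses only keys() and __getitem__ from the Record protocol.
        return {key: rec[key] for key in rec.keys()}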