Skip to content

Commit 0f41252

Browse files
committed
feat: Added QdrantStore
1 parent f4b2472 commit 0f41252

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

datapipe/store/qdrant.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import pandas as pd
2+
import hashlib
3+
import uuid
4+
5+
from typing import Dict, List, Optional
6+
from qdrant_client import QdrantClient
7+
from qdrant_client.conversions import common_types as types
8+
from qdrant_client.http import models as rest
9+
from qdrant_client.http.exceptions import UnexpectedResponse
10+
11+
from datapipe.types import DataSchema, MetaSchema, IndexDF, DataDF, data_to_index
12+
from datapipe.store.table_store import TableStore
13+
14+
15+
class CollectionParams(rest.CreateCollection):
16+
pass
17+
18+
19+
class QdrantStore(TableStore):
20+
def __init__(
21+
self,
22+
name: str,
23+
host: str,
24+
port: int,
25+
schema: DataSchema,
26+
pk_field: str,
27+
embedding_field: str,
28+
collection_params: CollectionParams
29+
):
30+
super().__init__()
31+
self.name = name
32+
self.host = host
33+
self.port = port
34+
self.schema = schema
35+
self.pk_field = pk_field
36+
self.embedding_field = embedding_field
37+
self.collection_params = collection_params
38+
self.inited = False
39+
self.client: QdrantClient = None
40+
41+
pk_columns = [column for column in self.schema if column.primary_key]
42+
43+
if len(pk_columns) != 1 and pk_columns[0].name != pk_field:
44+
raise ValueError("Incorrect prymary key columns in schema")
45+
46+
self.paylods_filelds = [column.name for column in self.schema if column.name != self.embedding_field]
47+
48+
def __init(self):
49+
self.client = QdrantClient(host=self.host, port=self.port)
50+
51+
try:
52+
self.client.get_collection(self.name)
53+
except UnexpectedResponse as e:
54+
if e.status_code == 404:
55+
self.client.http.collections_api.create_collection(
56+
collection_name=self.name,
57+
create_collection=self.collection_params
58+
)
59+
60+
def __check_init(self):
61+
if not self.inited:
62+
self.__init()
63+
self.inited = True
64+
65+
def __get_ids(self, df):
66+
return df[self.pk_field].apply(
67+
lambda x: str(uuid.UUID(bytes=hashlib.md5(str(x).encode('utf-8')).digest()))
68+
).to_list()
69+
70+
def get_primary_schema(self) -> DataSchema:
71+
return [column for column in self.schema if column.primary_key]
72+
73+
def get_meta_schema(self) -> MetaSchema:
74+
return []
75+
76+
def insert_rows(self, df: DataDF) -> None:
77+
self.__check_init()
78+
79+
if len(df) == 0:
80+
return
81+
82+
self.client.upsert(
83+
self.name,
84+
rest.Batch(
85+
ids=self.__get_ids(df),
86+
vectors=df[self.embedding_field].apply(list).to_list(),
87+
payloads=df[self.paylods_filelds].to_dict(orient='records')
88+
),
89+
wait=True,
90+
)
91+
92+
def update_rows(self, df: DataDF) -> None:
93+
self.insert_rows(df)
94+
95+
def delete_rows(self, idx: IndexDF) -> None:
96+
self.__check_init()
97+
98+
if len(idx) == 0:
99+
return
100+
101+
self.client.delete(
102+
self.name,
103+
rest.PointIdsList(
104+
points=self.__get_ids(idx)
105+
),
106+
wait=True,
107+
)
108+
109+
def read_rows(self, idx: Optional[IndexDF] = None) -> DataDF:
110+
self.__check_init()
111+
112+
if not idx:
113+
raise Exception("Qrand doesn't support full store reading")
114+
115+
response = self.client.http.points_api.get_points(
116+
self.name,
117+
point_request=rest.PointRequest(
118+
ids=self.__get_ids(idx),
119+
with_payload=True,
120+
with_vector=True
121+
)
122+
)
123+
124+
records = []
125+
126+
for point in response.result:
127+
record = point.payload
128+
record[self.embedding_field] = point.vector
129+
130+
records.append(record)
131+
132+
return pd.DataFrame.from_records(records)[[column.name for column in self.schema]]

0 commit comments

Comments
 (0)