From 4545f22ed6a0b49330f50e45a555e32662799926 Mon Sep 17 00:00:00 2001
From: vasilije
Date: Tue, 17 Dec 2024 20:58:03 +0100
Subject: [PATCH] First draft of relationship embeddings

---
 .../vector/lancedb/LanceDBAdapter.py          | 62 ++++++++++++-
 cognee/infrastructure/engine/__init__.py      |  1 +
 .../engine/models/Relationship.py             | 86 +++++++++++++++++++
 3 files changed, 148 insertions(+), 1 deletion(-)
 create mode 100644 cognee/infrastructure/engine/models/Relationship.py

diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
index 37d340004..49412d19a 100644
--- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
+++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -6,7 +6,7 @@ from lancedb.pydantic import Vector, LanceModel
 
 
 from cognee.exceptions import InvalidValueError
-from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine import DataPoint, Relationship
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.modules.storage.utils import copy_model, get_own_properties
 from ..models.ScoredResult import ScoredResult
@@ -72,6 +72,66 @@ class LanceDataPoint(LanceModel):
                 exist_ok = True,
             )
 
+    async def create_relationships(self, collection_name: str, relationships: list[Relationship]):
+        """Embed relationships and merge them into a LanceDB collection.
+
+        The collection is created first if it does not exist. Rows are merged
+        on "id": matching rows are updated and the rest are inserted.
+
+        Args:
+            collection_name: Name of the LanceDB table to write to.
+            relationships: Relationships to embed and store.
+        """
+        # Imported here to keep this method self-contained.
+        import logging
+
+        connection = await self.get_connection()
+
+        # Ensure collection exists
+        if not await self.has_collection(collection_name):
+            await self.create_collection(collection_name, Relationship)
+
+        # Nothing to embed or merge, so skip the embedding and LanceDB calls.
+        if not relationships:
+            return
+
+        collection = await connection.open_table(collection_name)
+
+        # One text per relationship: its embeddable values joined by spaces.
+        data_vectors = await self.embed_data([
+            " ".join(str(value) for value in rel.get_embeddable_properties().values())
+            for rel in relationships
+        ])
+
+        # The row model is built per call from the engine's vector size.
+        vector_size = self.embedding_engine.get_vector_size()
+
+        # NOTE(review): the collection is created with Relationship as its payload
+        # schema while rows carry a plain dict payload - confirm LanceDB accepts it.
+        class LanceRelationship(LanceModel):
+            id: str
+            vector: Vector(vector_size)
+            payload: dict
+
+        # Prepare LanceDB-compatible data points
+        lance_relationships = [
+            LanceRelationship(
+                id=str(rel.id),
+                vector=data_vectors[index],
+                payload=rel.to_dict(),
+            )
+            for index, rel in enumerate(relationships)
+        ]
+
+        # Insert relationships into LanceDB
+        await collection.merge_insert("id") \
+            .when_matched_update_all() \
+            .when_not_matched_insert_all() \
+            .execute(lance_relationships)
+        logging.getLogger(__name__).info(
+            "Inserted %d relationships into LanceDB", len(relationships)
+        )
+
     async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
         connection = await self.get_connection()
 
diff --git a/cognee/infrastructure/engine/__init__.py b/cognee/infrastructure/engine/__init__.py
index 26f567da9..d65a3a2bd 100644
--- a/cognee/infrastructure/engine/__init__.py
+++ b/cognee/infrastructure/engine/__init__.py
@@ -1 +1,2 @@
 from .models.DataPoint import DataPoint
+from .models.Relationship import Relationship
diff --git a/cognee/infrastructure/engine/models/Relationship.py b/cognee/infrastructure/engine/models/Relationship.py
new file mode 100644
index 000000000..b2ca6ff1a
--- /dev/null
+++ b/cognee/infrastructure/engine/models/Relationship.py
@@ -0,0 +1,86 @@
+import pickle
+from datetime import datetime, timezone
+from typing import Any, Dict, Optional
+from uuid import UUID, uuid4
+
+from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
+
+
+def _now_ms() -> int:
+    """Return the current UTC time as integer milliseconds since the epoch."""
+    return int(datetime.now(timezone.utc).timestamp() * 1000)
+
+
+class RelationshipMetaData(TypedDict):
+    """Shape of the private metadata dictionary carried by a Relationship."""
+
+    index_fields: list[str]  # names of the fields whose values get embedded
+    type: str
+
+
+class Relationship(BaseModel):
+    """Edge between a source node and a target node, with type, weight and timestamps."""
+
+    __tablename__ = "relationship"
+    id: UUID = Field(default_factory=uuid4)
+    source_id: UUID  # ID of the source node
+    target_id: UUID  # ID of the target node
+    relationship_type: str  # Type of relationship
+    weight: Optional[float] = None  # Weight of the edge (optional)
+    created_at: int = Field(default_factory=_now_ms)
+    updated_at: int = Field(default_factory=_now_ms)
+    version: str = "0.1"
+    # The leading underscore makes this a pydantic private attribute, so it is
+    # not included in model_dump() / model_dump_json() output.
+    # NOTE(review): index_fields is empty, so get_embeddable_properties()
+    # returns {} until field names are listed here.
+    _metadata: Optional[RelationshipMetaData] = {
+        "index_fields": [],
+        "type": "Relationship"
+    }
+
+    def update_version(self, new_version: str) -> None:
+        """Set a new version string and refresh the updated_at timestamp."""
+        self.version = new_version
+        self.updated_at = _now_ms()
+
+    def to_json(self) -> str:
+        """Serialize the instance to a JSON string."""
+        return self.model_dump_json()
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "Relationship":
+        """Deserialize the instance from a JSON string."""
+        return cls.model_validate_json(json_str)
+
+    def to_pickle(self) -> bytes:
+        """Serialize the instance's field dictionary to pickle bytes."""
+        return pickle.dumps(self.model_dump())
+
+    @classmethod
+    def from_pickle(cls, pickled_data: bytes) -> "Relationship":
+        """Deserialize the instance from bytes produced by to_pickle().
+
+        SECURITY: pickle.loads can execute arbitrary code. Only pass bytes
+        from a trusted source; use from_json() for external input.
+        """
+        data = pickle.loads(pickled_data)
+        return cls(**data)
+
+    def to_dict(self, **kwargs) -> Dict[str, Any]:
+        """Serialize model to a dictionary; kwargs are passed to model_dump()."""
+        return self.model_dump(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
+        """Deserialize model from a dictionary."""
+        return cls.model_validate(data)
+
+    def get_embeddable_properties(self) -> Dict[str, Any]:
+        """Map each field named in index_fields to its value (None if missing)."""
+        return {field: getattr(self, field, None) for field in self._metadata["index_fields"]}
+
+    def get_embeddable_property_names(self) -> list[str]:
+        """Return the names of the fields listed in index_fields."""
+        return self._metadata["index_fields"]