Skip to content

Commit

Permalink
First draft of relationship embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Vasilije1990 committed Dec 17, 2024
1 parent 92ecd8a commit 4545f22
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 1 deletion.
45 changes: 44 additions & 1 deletion cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from lancedb.pydantic import Vector, LanceModel

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine import DataPoint, Relationship
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult
Expand Down Expand Up @@ -72,6 +72,49 @@ class LanceDataPoint(LanceModel):
exist_ok = True,
)



async def create_relationships(self, collection_name: str, relationships: list[Relationship]):
    """Embed relationships and upsert them into a LanceDB collection.

    The text embedded for each relationship is the space-joined string form
    of the values returned by ``rel.get_embeddable_properties()``. Rows are
    keyed on ``id``: an existing row with the same id is fully updated,
    otherwise a new row is inserted.

    Args:
        collection_name: Name of the LanceDB table to write to. It is
            created (with ``Relationship`` as the payload schema) when it
            does not exist yet.
        relationships: Relationships to embed and store. An empty list is
            a no-op: no collection is created and the embedding engine is
            not called.
    """
    import logging

    # Guard against empty input: there is nothing to embed, and sending an
    # empty batch to the embedding engine / merge-insert is wasted work at
    # best and an error at worst.
    if not relationships:
        return

    connection = await self.get_connection()

    # Ensure the collection exists before opening it.
    if not await self.has_collection(collection_name):
        await self.create_collection(collection_name, Relationship)

    collection = await connection.open_table(collection_name)

    # One text per relationship, in the same order as `relationships`, so the
    # returned vectors can be matched back by position.
    data_vectors = await self.embed_data([
        " ".join(str(value) for value in rel.get_embeddable_properties().values())
        for rel in relationships
    ])

    # The row model is built per call because the vector width is only known
    # from the embedding engine at runtime.
    vector_size = self.embedding_engine.get_vector_size()

    class LanceRelationship(LanceModel):
        id: str
        vector: Vector(vector_size)
        payload: dict

    # Indexing (rather than zip) keeps a hard failure if the embedding engine
    # returns fewer vectors than inputs.
    lance_relationships = [
        LanceRelationship(
            id=str(rel.id),
            vector=data_vectors[index],
            payload=rel.to_dict(),
        )
        for index, rel in enumerate(relationships)
    ]

    # Upsert keyed on "id": update matching rows, insert the rest.
    await collection.merge_insert("id") \
        .when_matched_update_all() \
        .when_not_matched_insert_all() \
        .execute(lance_relationships)

    logging.getLogger(__name__).info(
        "Inserted %d relationships into LanceDB", len(relationships)
    )

async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
connection = await self.get_connection()

Expand Down
1 change: 1 addition & 0 deletions cognee/infrastructure/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .models.DataPoint import DataPoint
from .models.Relationship import Relationship
70 changes: 70 additions & 0 deletions cognee/infrastructure/engine/models/Relationship.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle

# Define metadata type
class RelationshipMetaData(TypedDict):
    """Shape of the metadata dictionary attached to a ``Relationship``."""

    # Names of the model fields whose values are used to build embeddings.
    index_fields: list[str]
    # Type tag stored alongside the index fields (e.g. "Relationship").
    type: str


def _utc_now_ms() -> int:
    """Return the current UTC time as integer milliseconds since the epoch."""
    return int(datetime.now(timezone.utc).timestamp() * 1000)


class Relationship(BaseModel):
    """An edge between two nodes, with versioning and embedding metadata.

    Serialization helpers use the Pydantic v2 API throughout
    (``model_dump`` / ``model_dump_json`` / ``model_validate`` /
    ``model_validate_json``).
    """

    __tablename__ = "relationship"

    id: UUID = Field(default_factory=uuid4)
    source_id: UUID  # ID of the source node
    target_id: UUID  # ID of the target node
    relationship_type: str  # Type of relationship
    weight: Optional[float] = None  # Weight of the edge (optional)
    # Creation / last-update times in milliseconds since the epoch (UTC).
    created_at: int = Field(default_factory=_utc_now_ms)
    updated_at: int = Field(default_factory=_utc_now_ms)
    version: str = "0.1"
    # Underscore-prefixed attributes are private in Pydantic v2 by default, so
    # the v1-only `Config.underscore_attrs_are_private` setting is not needed.
    # NOTE(review): `index_fields` is empty, so no property is embedded by
    # default — confirm which fields should be indexed.
    _metadata: Optional[RelationshipMetaData] = {
        "index_fields": [],
        "type": "Relationship",
    }

    def update_version(self, new_version: str) -> None:
        """Set a new version and refresh the ``updated_at`` timestamp."""
        self.version = new_version
        self.updated_at = _utc_now_ms()

    def to_json(self) -> str:
        """Serialize the instance to a JSON string."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, json_str: str) -> "Relationship":
        """Deserialize an instance from a JSON string."""
        return cls.model_validate_json(json_str)

    def to_pickle(self) -> bytes:
        """Serialize the instance's field dictionary to pickle bytes."""
        return pickle.dumps(self.model_dump())

    @classmethod
    def from_pickle(cls, pickled_data: bytes) -> "Relationship":
        """Deserialize an instance from bytes produced by ``to_pickle``.

        WARNING: ``pickle.loads`` can execute arbitrary code. Only call this
        with data from a trusted source; prefer ``from_json`` otherwise.
        """
        data = pickle.loads(pickled_data)
        return cls(**data)

    def to_dict(self, **kwargs) -> Dict[str, Any]:
        """Serialize the model to a dictionary.

        Keyword arguments are forwarded to ``model_dump``.
        """
        return self.model_dump(**kwargs)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
        """Deserialize the model from a dictionary."""
        return cls.model_validate(data)

    def get_embeddable_properties(self) -> Dict[str, Any]:
        """Return a mapping of index-field name to its current value.

        Fields listed in the metadata but missing on the instance map to
        ``None``.
        """
        return {
            field: getattr(self, field, None)
            for field in self.get_embeddable_property_names()
        }

    def get_embeddable_property_names(self) -> list:
        """Return the names of the fields used for embeddings.

        Returns an empty list when the metadata is unset or has no
        ``index_fields`` entry.
        """
        # `_metadata` is declared Optional, so guard against None.
        return (self._metadata or {}).get("index_fields", [])

0 comments on commit 4545f22

Please sign in to comment.