-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: First draft of relationship embeddings #379
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from lancedb.pydantic import Vector, LanceModel | ||
|
||
from cognee.exceptions import InvalidValueError | ||
from cognee.infrastructure.engine import DataPoint | ||
from cognee.infrastructure.engine import DataPoint, Relationship | ||
from cognee.infrastructure.files.storage import LocalStorage | ||
from cognee.modules.storage.utils import copy_model, get_own_properties | ||
from ..models.ScoredResult import ScoredResult | ||
|
@@ -72,6 +72,49 @@ class LanceDataPoint(LanceModel): | |
exist_ok = True, | ||
) | ||
|
||
|
||
|
||
async def create_relationships(self, collection_name: str, relationships: list[Relationship]): | ||
"""Create and store Relationship embeddings in LanceDB.""" | ||
connection = await self.get_connection() | ||
|
||
# Ensure collection exists | ||
if not await self.has_collection(collection_name): | ||
await self.create_collection(collection_name, Relationship) | ||
|
||
collection = await connection.open_table(collection_name) | ||
|
||
# Generate embeddings | ||
data_vectors = await self.embed_data([ | ||
" ".join([str(v) for v in rel.get_embeddable_properties().values()]) | ||
for rel in relationships | ||
]) | ||
|
||
# Dynamic LanceDataPoint class for Relationship | ||
vector_size = self.embedding_engine.get_vector_size() | ||
|
||
class LanceRelationship(LanceModel): | ||
id: str | ||
vector: Vector(vector_size) | ||
payload: dict | ||
|
||
# Prepare LanceDB-compatible data points | ||
lance_relationships = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think for this, we have the index datapoints method, which indexes the datapoint pydantic object based on its index_fields property. |
||
LanceRelationship( | ||
id=str(rel.id), | ||
vector=data_vectors[index], | ||
payload=rel.to_dict() | ||
) | ||
for index, rel in enumerate(relationships) | ||
] | ||
|
||
# Insert relationships into LanceDB | ||
await collection.merge_insert("id") \ | ||
.when_matched_update_all() \ | ||
.when_not_matched_insert_all() \ | ||
.execute(lance_relationships) | ||
print(f"Inserted {len(relationships)} relationships into LanceDB") | ||
|
||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): | ||
connection = await self.get_connection() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .models.DataPoint import DataPoint | ||
from .models.Relationship import Relationship |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,70 @@ | ||||||||||||||||
from datetime import datetime, timezone | ||||||||||||||||
from typing import Optional, Any, Dict | ||||||||||||||||
from uuid import UUID, uuid4 | ||||||||||||||||
from pydantic import BaseModel, Field | ||||||||||||||||
from typing_extensions import TypedDict | ||||||||||||||||
import pickle | ||||||||||||||||
|
||||||||||||||||
# Define metadata type | ||||||||||||||||
class RelationshipMetaData(TypedDict): | ||||||||||||||||
index_fields: list[str] | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class Relationship(BaseModel): | ||||||||||||||||
__tablename__ = "relationship" | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edge-type embeddings are already implemented here: This collected the distinct edge types, creates a pedantic model for edge types: https://github.com/topoteretes/cognee/pull/251/files#diff-676c14525be801de96c5734e9d56bb784f3aaf40fc60d4f030a16e06f17317f9 and embed the relationship name into the edge_type collection. |
||||||||||||||||
id: UUID = Field(default_factory=uuid4) | ||||||||||||||||
source_id: UUID # ID of the source node | ||||||||||||||||
target_id: UUID # ID of the target node | ||||||||||||||||
relationship_type: str # Type of relationship | ||||||||||||||||
weight: Optional[float] = None # Weight of the edge (optional) | ||||||||||||||||
created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) | ||||||||||||||||
updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) | ||||||||||||||||
version: str = "0.1" | ||||||||||||||||
_metadata: Optional[RelationshipMetaData] = { | ||||||||||||||||
"index_fields": [], | ||||||||||||||||
"type": "Relationship" | ||||||||||||||||
} | ||||||||||||||||
Comment on lines
+23
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct the The Option 1: Update class RelationshipMetaData(TypedDict):
index_fields: list[str]
+ type: str Option 2: Adjust the _metadata: Optional[RelationshipMetaData] = {
"index_fields": [],
- "type": "Relationship"
} Ensure that the
|
||||||||||||||||
|
||||||||||||||||
class Config: | ||||||||||||||||
underscore_attrs_are_private = True | ||||||||||||||||
|
||||||||||||||||
def update_version(self, new_version: str): | ||||||||||||||||
"""Update the version and updated_at timestamp.""" | ||||||||||||||||
self.version = new_version | ||||||||||||||||
self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000) | ||||||||||||||||
|
||||||||||||||||
def to_json(self) -> str: | ||||||||||||||||
"""Serialize the instance to a JSON string.""" | ||||||||||||||||
return self.json() | ||||||||||||||||
|
||||||||||||||||
@classmethod | ||||||||||||||||
def from_json(cls, json_str: str): | ||||||||||||||||
"""Deserialize the instance from a JSON string.""" | ||||||||||||||||
return cls.model_validate_json(json_str) | ||||||||||||||||
|
||||||||||||||||
def to_pickle(self) -> bytes: | ||||||||||||||||
"""Serialize the instance to pickle-compatible bytes.""" | ||||||||||||||||
return pickle.dumps(self.dict()) | ||||||||||||||||
|
||||||||||||||||
@classmethod | ||||||||||||||||
def from_pickle(cls, pickled_data: bytes): | ||||||||||||||||
"""Deserialize the instance from pickled bytes.""" | ||||||||||||||||
data = pickle.loads(pickled_data) | ||||||||||||||||
return cls(**data) | ||||||||||||||||
|
||||||||||||||||
def to_dict(self, **kwargs) -> Dict[str, Any]: | ||||||||||||||||
"""Serialize model to a dictionary.""" | ||||||||||||||||
return self.model_dump(**kwargs) | ||||||||||||||||
|
||||||||||||||||
@classmethod | ||||||||||||||||
def from_dict(cls, data: Dict[str, Any]) -> "Relationship": | ||||||||||||||||
"""Deserialize model from a dictionary.""" | ||||||||||||||||
return cls.model_validate(data) | ||||||||||||||||
|
||||||||||||||||
def get_embeddable_properties(self): | ||||||||||||||||
"""Retrieve embeddable properties for edge embeddings.""" | ||||||||||||||||
return {field: getattr(self, field, None) for field in self._metadata["index_fields"]} | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure In the Apply this diff to add a check: def get_embeddable_properties(self):
+ if self._metadata and "index_fields" in self._metadata:
return {field: getattr(self, field, None) for field in self._metadata["index_fields"]}
+ return {} This ensures that you only access 📝 Committable suggestion
Suggested change
|
||||||||||||||||
|
||||||||||||||||
def get_embeddable_property_names(self): | ||||||||||||||||
"""Retrieve names of embeddable properties.""" | ||||||||||||||||
return self._metadata["index_fields"] | ||||||||||||||||
Comment on lines
+69
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle potential Similarly, in the Apply this diff to add a check: def get_embeddable_property_names(self):
+ if self._metadata and "index_fields" in self._metadata:
return self._metadata["index_fields"]
+ return [] This prevents potential 📝 Committable suggestion
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle potential
None
values in embeddable propertiesWhen generating embeddings, if any embeddable property of a relationship is
None
, converting it to a string will result in the string"None"
, which might negatively impact the embeddings.Apply this diff to filter out
None
values:This ensures that only valid property values are included in the embeddings.
📝 Committable suggestion