feat: First draft of relationship embeddings #379

Open: wants to merge 1 commit into base: dev
45 changes: 44 additions & 1 deletion cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -6,7 +6,7 @@
from lancedb.pydantic import Vector, LanceModel

from cognee.exceptions import InvalidValueError
-from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine import DataPoint, Relationship
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult
@@ -72,6 +72,49 @@ class LanceDataPoint(LanceModel):
exist_ok = True,
)



    async def create_relationships(self, collection_name: str, relationships: list[Relationship]):
        """Create and store Relationship embeddings in LanceDB."""
        connection = await self.get_connection()

        # Ensure collection exists
        if not await self.has_collection(collection_name):
            await self.create_collection(collection_name, Relationship)

        collection = await connection.open_table(collection_name)

        # Generate embeddings
        data_vectors = await self.embed_data([
            " ".join([str(v) for v in rel.get_embeddable_properties().values()])
            for rel in relationships
        ])
Contributor comment on lines +88 to +91:

⚠️ Potential issue

Handle potential None values in embeddable properties

When generating embeddings, any embeddable property of a relationship that is None is converted by str() to the literal string "None". That string then becomes part of the embedded text, which might negatively impact the embeddings.

Apply this diff to filter out None values:

 data_vectors = await self.embed_data([
-    " ".join([str(v) for v in rel.get_embeddable_properties().values()])
+    " ".join([str(v) for v in rel.get_embeddable_properties().values() if v is not None])
     for rel in relationships
 ])

This ensures that only valid property values are included in the embeddings.
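
To illustrate the concern outside of cognee, here is a minimal standalone sketch. `build_embedding_text` is a hypothetical helper (it does not exist in the PR), and the property dict is an assumed stand-in for what `rel.get_embeddable_properties()` might return.

```python
def build_embedding_text(properties: dict) -> str:
    """Join property values into one string, skipping values that are None."""
    return " ".join(str(value) for value in properties.values() if value is not None)


# Hypothetical property dict standing in for rel.get_embeddable_properties()
props = {"relationship_type": "works_for", "weight": None}

# Unfiltered join, as in the code under review
naive_text = " ".join(str(value) for value in props.values())
# Filtered join, as in the suggested diff
filtered_text = build_embedding_text(props)

print(repr(naive_text))     # 'works_for None'
print(repr(filtered_text))  # 'works_for'
```

The check is `is not None` rather than a truthiness test, so legitimate falsy values such as `0.0` or an empty string are still included.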



        # Dynamic LanceDataPoint class for Relationship
        vector_size = self.embedding_engine.get_vector_size()

        class LanceRelationship(LanceModel):
            id: str
            vector: Vector(vector_size)
            payload: dict

        # Prepare LanceDB-compatible data points
        lance_relationships = [
Contributor comment:

I think we already have the index data points method for this: it indexes a DataPoint pydantic object based on its index_fields property.

            LanceRelationship(
                id=str(rel.id),
                vector=data_vectors[index],
                payload=rel.to_dict()
            )
            for index, rel in enumerate(relationships)
        ]

        # Insert relationships into LanceDB
        await collection.merge_insert("id") \
            .when_matched_update_all() \
            .when_not_matched_insert_all() \
            .execute(lance_relationships)
        print(f"Inserted {len(relationships)} relationships into LanceDB")

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        connection = await self.get_connection()

1 change: 1 addition & 0 deletions cognee/infrastructure/engine/__init__.py
@@ -1 +1,2 @@
from .models.DataPoint import DataPoint
from .models.Relationship import Relationship
70 changes: 70 additions & 0 deletions cognee/infrastructure/engine/models/Relationship.py
@@ -0,0 +1,70 @@
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle

# Define metadata type
class RelationshipMetaData(TypedDict):
    index_fields: list[str]


class Relationship(BaseModel):
    __tablename__ = "relationship"
Contributor comment:

Edge-type embeddings are already implemented here:
#251

That PR collects the distinct edge types and creates a pydantic model for edge types: https://github.com/topoteretes/cognee/pull/251/files#diff-676c14525be801de96c5734e9d56bb784f3aaf40fc60d4f030a16e06f17317f9

It then embeds the relationship name into the edge_type collection.
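
As a rough sketch of the approach described in this comment (this is not the code from #251), collecting distinct edge types before embedding could look like the following. The edge tuple layout `(source_id, target_id, relationship_name)`, the `EdgeType` dataclass, and `collect_edge_types` are all assumptions made for illustration; the referenced PR uses a pydantic model instead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EdgeType:
    """Hypothetical stand-in for the edge-type model: one instance per distinct name."""
    relationship_name: str


def collect_edge_types(edges: list[tuple[str, str, str]]) -> list[EdgeType]:
    """Return one EdgeType per distinct relationship name, in first-seen order."""
    seen: dict[str, EdgeType] = {}
    for _source_id, _target_id, relationship_name in edges:
        if relationship_name not in seen:
            seen[relationship_name] = EdgeType(relationship_name)
    return list(seen.values())


# Assumed edge layout: (source_id, target_id, relationship_name)
edges = [
    ("a", "b", "works_for"),
    ("c", "b", "works_for"),
    ("a", "c", "knows"),
]
edge_types = collect_edge_types(edges)
print([edge_type.relationship_name for edge_type in edge_types])  # ['works_for', 'knows']
```

With this shape, each distinct relationship name is embedded once, so the number of stored vectors follows the number of edge types rather than the number of edges.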

    id: UUID = Field(default_factory=uuid4)
    source_id: UUID  # ID of the source node
    target_id: UUID  # ID of the target node
    relationship_type: str  # Type of relationship
    weight: Optional[float] = None  # Weight of the edge (optional)
    created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
    updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
    version: str = "0.1"
    _metadata: Optional[RelationshipMetaData] = {
        "index_fields": [],
        "type": "Relationship"
    }
Contributor comment on lines +23 to +26:

⚠️ Potential issue

Correct the _metadata field type or default value

The _metadata field is annotated as Optional[RelationshipMetaData] but is assigned a default value that includes the key "type", which is not declared in the RelationshipMetaData TypedDict. Static type checkers will flag the extra key, and the annotation no longer describes the actual shape of the value.

Option 1: Update RelationshipMetaData to include the "type" key.

 class RelationshipMetaData(TypedDict):
     index_fields: list[str]
+    type: str

Option 2: Adjust the _metadata default value to match the specified type.

 _metadata: Optional[RelationshipMetaData] = {
     "index_fields": [],
-    "type": "Relationship"
 }

Ensure that the _metadata field's type annotation matches its default value.
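
The mismatch can be checked in isolation with the standard library alone. A `TypedDict` is not enforced at runtime, so the extra key goes unnoticed unless a type checker or an explicit comparison looks for it. In the sketch below, `undeclared_keys` and `RelationshipMetaDataWithType` are hypothetical names introduced for illustration; the second class corresponds to Option 1.

```python
from typing import TypedDict


class RelationshipMetaData(TypedDict):
    index_fields: list[str]


class RelationshipMetaDataWithType(TypedDict):
    # Option 1: declare the extra key on the TypedDict
    index_fields: list[str]
    type: str


def undeclared_keys(typed_dict_cls, value: dict) -> set[str]:
    """Return keys present in value but not declared on the TypedDict class."""
    return set(value) - set(typed_dict_cls.__annotations__)


# The default value used in the PR
default_metadata = {"index_fields": [], "type": "Relationship"}

print(undeclared_keys(RelationshipMetaData, default_metadata))          # {'type'}
print(undeclared_keys(RelationshipMetaDataWithType, default_metadata))  # set()
```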



    class Config:
        underscore_attrs_are_private = True

    def update_version(self, new_version: str):
        """Update the version and updated_at timestamp."""
        self.version = new_version
        self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)

    def to_json(self) -> str:
        """Serialize the instance to a JSON string."""
        return self.json()

    @classmethod
    def from_json(cls, json_str: str):
        """Deserialize the instance from a JSON string."""
        return cls.model_validate_json(json_str)

    def to_pickle(self) -> bytes:
        """Serialize the instance to pickle-compatible bytes."""
        return pickle.dumps(self.dict())

    @classmethod
    def from_pickle(cls, pickled_data: bytes):
        """Deserialize the instance from pickled bytes."""
        data = pickle.loads(pickled_data)
        return cls(**data)

    def to_dict(self, **kwargs) -> Dict[str, Any]:
        """Serialize model to a dictionary."""
        return self.model_dump(**kwargs)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
        """Deserialize model from a dictionary."""
        return cls.model_validate(data)

    def get_embeddable_properties(self):
        """Retrieve embeddable properties for edge embeddings."""
        return {field: getattr(self, field, None) for field in self._metadata["index_fields"]}
Contributor comment:
⚠️ Potential issue

Ensure _metadata is not None before accessing keys

In the get_embeddable_properties method, accessing self._metadata["index_fields"] without checking if _metadata is not None can lead to an exception if _metadata is None.

Apply this diff to add a check:

 def get_embeddable_properties(self):
+    if self._metadata and "index_fields" in self._metadata:
         return {field: getattr(self, field, None) for field in self._metadata["index_fields"]}
+    return {}

This ensures that you only access index_fields when _metadata is properly initialized.
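
The guard can be exercised in isolation. `RelationshipStub` below is a hypothetical plain-Python stand-in for the pydantic model, used only to show what the guarded method returns when `_metadata` is `None` or lacks `index_fields`.

```python
from typing import Any, Optional


class RelationshipStub:
    """Hypothetical stand-in for Relationship, without pydantic."""

    def __init__(self, relationship_type: str, metadata: Optional[dict]):
        self.relationship_type = relationship_type
        self._metadata = metadata

    def get_embeddable_properties(self) -> dict[str, Any]:
        # Fall back to an empty dict when metadata is missing or incomplete
        if self._metadata and "index_fields" in self._metadata:
            return {
                field: getattr(self, field, None)
                for field in self._metadata["index_fields"]
            }
        return {}


with_fields = RelationshipStub("knows", {"index_fields": ["relationship_type"]})
without_metadata = RelationshipStub("knows", None)
without_index_fields = RelationshipStub("knows", {"type": "Relationship"})

print(with_fields.get_embeddable_properties())           # {'relationship_type': 'knows'}
print(without_metadata.get_embeddable_properties())      # {}
print(without_index_fields.get_embeddable_properties())  # {}
```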



    def get_embeddable_property_names(self):
        """Retrieve names of embeddable properties."""
        return self._metadata["index_fields"]
Contributor comment on lines +69 to +70:

⚠️ Potential issue

Handle potential None in _metadata when accessing index_fields

Similarly, in the get_embeddable_property_names method, ensure that _metadata is not None before accessing index_fields.

Apply this diff to add a check:

 def get_embeddable_property_names(self):
+    if self._metadata and "index_fields" in self._metadata:
         return self._metadata["index_fields"]
+    return []

This prevents potential TypeError or KeyError exceptions.

