Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add versioning to the data point model #378

Open
wants to merge 12 commits into
base: dev
Choose a base branch
from
2 changes: 2 additions & 0 deletions .github/workflows/profiling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ jobs:
run: |
poetry install --no-interaction --all-extras
poetry run pip install pyinstrument
poetry run pip install parso
poetry run pip install jedi


# Set environment variables for SHAs
Expand Down
3 changes: 3 additions & 0 deletions cognee/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class BaseConfig(BaseSettings):
monitoring_tool: object = MonitoringTool.LANGFUSE
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")

model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

Expand Down
64 changes: 54 additions & 10 deletions cognee/infrastructure/engine/models/DataPoint.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@


from datetime import datetime, timezone
from typing import Optional
from typing import Optional, Any, Dict
from uuid import UUID, uuid4

from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle


# Define metadata type
class MetaData(TypedDict):
index_fields: list[str]


# Updated DataPoint model with versioning and new fields
class DataPoint(BaseModel):
__tablename__ = "data_point"
id: UUID = Field(default_factory = uuid4)
updated_at: Optional[datetime] = datetime.now(timezone.utc)
id: UUID = Field(default_factory=uuid4)
created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between datetime.now(timezone.utc) and this one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created_at is when the initial record was created; updated_at is refreshed on any change that happens afterwards.

updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
Comment on lines +20 to +21
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Simplify timestamp creation and add validation

The timestamp creation could be simplified and should validate against negative values.

Consider this improvement:

-    created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
-    updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
+    created_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
+        ge=0
+    )
+    updated_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
+        ge=0
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
created_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
ge=0
)
updated_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
ge=0
)

version: int = 1 # Default version
type: Optional[str] = "text" # "text", "file", "image", "video"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for type, doesn't belong here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Laszlo asked me for this one, due to retriever logic. In general I agree.

topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = {
"index_fields": [],
"type": "DataPoint"
}

# class Config:
# underscore_attrs_are_private = True
# Override the Pydantic configuration
class Config:
underscore_attrs_are_private = True

@classmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double @classmethod

def get_embeddable_data(self, data_point):
Expand All @@ -30,16 +39,51 @@ def get_embeddable_data(self, data_point):

if isinstance(attribute, str):
return attribute.strip()
else:
return attribute
return attribute

@classmethod
def get_embeddable_properties(cls, data_point):
    """Return the values of the data point's embeddable fields.

    Looks up ``data_point._metadata["index_fields"]`` and returns the
    corresponding attribute values; attributes missing on the instance
    are returned as None. Returns an empty list when no index fields
    are configured.

    Note: first parameter renamed ``self`` -> ``cls`` — this is a
    classmethod, so the old name was misleading; callers are unaffected.
    """
    if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
        return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]

    return []

@classmethod
def get_embeddable_property_names(cls, data_point):
    """Retrieve the names of the embeddable properties.

    Returns ``data_point._metadata["index_fields"]``, or an empty list
    when that entry is falsy (e.g. an empty list).
    """
    # The scraped diff showed a duplicated return line here (old + new
    # hunk lines); only one return is kept.
    return data_point._metadata["index_fields"] or []

def update_version(self):
    """Bump this record's version and refresh its updated_at stamp.

    ``updated_at`` is stored as integer milliseconds since the epoch, UTC.
    """
    now_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
    self.version = self.version + 1
    self.updated_at = now_ms

# JSON Serialization
def to_json(self) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this serialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you can parallelize tasks, since you had issues with that — via pickle or JSON.

"""Serialize the instance to a JSON string."""
return self.json()

@classmethod
def from_json(cls, json_str: str):
    """Deserialize an instance from a JSON string (inverse of ``to_json``).

    Note: first parameter renamed ``self`` -> ``cls`` — this is a
    classmethod; callers are unaffected.
    """
    return cls.model_validate_json(json_str)

# Pickle Serialization
def to_pickle(self) -> bytes:
    """Serialize the instance to pickle-compatible bytes.

    Pickles the field dict rather than the model object itself, so the
    payload does not depend on the pydantic class being importable with
    the exact same module path on the unpickling side.
    """
    # model_dump() for consistency with to_dict(); .dict() is the
    # deprecated pydantic v1 spelling.
    return pickle.dumps(self.model_dump())

@classmethod
def from_pickle(cls, pickled_data: bytes):
    """Deserialize an instance from pickled bytes (inverse of ``to_pickle``).

    WARNING: ``pickle.loads`` can execute arbitrary code during
    deserialization — only call this on trusted, internally produced
    data, never on external input.

    Note: first parameter renamed ``self`` -> ``cls`` — this is a
    classmethod; callers are unaffected.
    """
    data = pickle.loads(pickled_data)
    return cls(**data)
Comment on lines +71 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Security concern: Replace pickle with a safer serialization method

Using pickle for serialization poses a security risk as it can execute arbitrary code during deserialization. Consider using a safer alternative like JSON or MessagePack.

-    def to_pickle(self) -> bytes:
-        """Serialize the instance to pickle-compatible bytes."""
-        return pickle.dumps(self.dict())
-
-    @classmethod
-    def from_pickle(self, pickled_data: bytes):
-        """Deserialize the instance from pickled bytes."""
-        data = pickle.loads(pickled_data)
-        return self(**data)
+    def to_bytes(self) -> bytes:
+        """Serialize the instance to bytes using JSON."""
+        return self.json().encode('utf-8')
+
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        """Deserialize the instance from JSON bytes."""
+        return cls.parse_raw(data)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Pickle Serialization
def to_pickle(self) -> bytes:
"""Serialize the instance to pickle-compatible bytes."""
return pickle.dumps(self.dict())
@classmethod
def from_pickle(self, pickled_data: bytes):
"""Deserialize the instance from pickled bytes."""
data = pickle.loads(pickled_data)
return self(**data)
def to_bytes(self) -> bytes:
"""Serialize the instance to bytes using JSON."""
return self.json().encode('utf-8')
@classmethod
def from_bytes(cls, data: bytes):
"""Deserialize the instance from JSON bytes."""
return cls.parse_raw(data)


def to_dict(self, **kwargs) -> Dict[str, Any]:
    """Dump the model's fields as a plain dictionary.

    Extra keyword arguments are forwarded to pydantic's ``model_dump``.
    """
    dumped: Dict[str, Any] = self.model_dump(**kwargs)
    return dumped

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
    """Construct and validate an instance from a plain dictionary."""
    instance = cls.model_validate(data)
    return instance
9 changes: 8 additions & 1 deletion cognee/infrastructure/llm/openai/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import litellm
import instructor
from pydantic import BaseModel

from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.base_config import get_base_config

class OpenAIAdapter(LLMInterface):
name = "OpenAI"
Expand All @@ -35,6 +36,12 @@ def __init__(
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
base_config = get_base_config()
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
self.aclient.success_callback = ["langfuse"]
self.aclient.failure_callback = ["langfuse"]
self.client.success_callback = ["langfuse"]
self.client.failure_callback = ["langfuse"]

async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""
Expand Down