Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add versioning to the data point model #378

Merged
merged 19 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/profiling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ jobs:
run: |
poetry install --no-interaction --all-extras
poetry run pip install pyinstrument
poetry run pip install parso
poetry run pip install jedi


# Set environment variables for SHAs
Expand Down
3 changes: 3 additions & 0 deletions cognee/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class BaseConfig(BaseSettings):
monitoring_tool: object = MonitoringTool.LANGFUSE
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")

model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

Expand Down
67 changes: 57 additions & 10 deletions cognee/infrastructure/engine/models/DataPoint.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,39 @@


from datetime import datetime, timezone
from typing import Optional
from typing import Optional, Any, Dict
from uuid import UUID, uuid4

from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Organize imports and consider security implications

The imports should be organized according to the standard convention (stdlib, third-party, local). Also, using pickle for serialization poses security risks as it can execute arbitrary code during deserialization.

-
-
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
+from datetime import datetime, timezone
+from typing import Optional, Any, Dict
+from uuid import UUID, uuid4
+
+import json
+import pickle  # Consider removing in favor of json
+
+from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle
from datetime import datetime, timezone
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
import json
import pickle # Consider removing in favor of json
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
🧰 Tools
🪛 GitHub Actions: ruff format

[warning] File requires formatting using Ruff formatter


# Define metadata type
class MetaData(TypedDict):
index_fields: list[str]


# Updated DataPoint model with versioning and new fields
class DataPoint(BaseModel):
__tablename__ = "data_point"
id: UUID = Field(default_factory = uuid4)
updated_at: Optional[datetime] = datetime.now(timezone.utc)
id: UUID = Field(default_factory=uuid4)
created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between datetime.now(timezone.utc) and this one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created_at is when the initial record was created, updated at is any change that happens

updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Simplify timestamp creation and add validation

The timestamp creation could be simplified and should validate against negative values.

Consider this improvement:

-    created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
-    updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
+    created_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
+        ge=0
+    )
+    updated_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
+        ge=0
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))
created_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
ge=0
)
updated_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000),
ge=0
)

version: str = "0.1" # Default version
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep it as a number, and we can just increase it with each version. (1, 2, 3, 4...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kk

source: Optional[str] = None # Path to file, URL, etc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

source is a Document model related property, doesn't belong to this general DataPoint model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair

type: Optional[str] = "text" # "text", "file", "image", "video"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add version format validation and type enumeration

The version string and type field should have proper validation.

Consider these improvements:

+from enum import Enum
+import re
+
+class DataPointType(str, Enum):
+    TEXT = "text"
+    FILE = "file"
+    IMAGE = "image"
+    VIDEO = "video"

class DataPoint(BaseModel):
    # ... other fields ...
-    version: str = "0.1"  # Default version
-    type: Optional[str] = "text"  # "text", "file", "image", "video"
+    version: str = Field(
+        default="0.1",
+        regex=r"^\d+\.\d+$"
+    )
+    type: Optional[DataPointType] = Field(default=DataPointType.TEXT)

Committable suggestion skipped: line range outside the PR's diff.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for type, doesn't belong here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Laslzo asked me for this one, due to retriever logic. In general I agree

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the only thing is that the type should be the pydantic type. Like "Entity", "TextSummary" etc.

topological_rank: Optional[int] = 0
extra: Optional[str] = "extra" # For additional properties
Vasilije1990 marked this conversation as resolved.
Show resolved Hide resolved
_metadata: Optional[MetaData] = {
"index_fields": [],
"type": "DataPoint"
}

# class Config:
# underscore_attrs_are_private = True
# Override the Pydantic configuration
class Config:
underscore_attrs_are_private = True

@classmethod
Vasilije1990 marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove duplicate @classmethod decorator

The @classmethod decorator is duplicated.

-    @classmethod
-    @classmethod
+    @classmethod
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@classmethod
@classmethod
@classmethod

def get_embeddable_data(self, data_point):
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \
Expand All @@ -30,16 +42,51 @@ def get_embeddable_data(self, data_point):

if isinstance(attribute, str):
return attribute.strip()
else:
return attribute
return attribute

@classmethod
def get_embeddable_properties(self, data_point):
"""Retrieve all embeddable properties."""
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]

return []

@classmethod
def get_embeddable_property_names(self, data_point):
return data_point._metadata["index_fields"] or []
"""Retrieve names of embeddable properties."""
return data_point._metadata["index_fields"] or []

def update_version(self, new_version: str):
"""Update the version and updated_at timestamp."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have a number as a version, we can do +1 here then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

self.version = new_version
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve version update method with validation

The update_version method should validate the version format and reuse timestamp logic.

Consider this improvement:

     def update_version(self, new_version: str):
         """Update the version and updated_at timestamp."""
+        if not re.match(r"^\d+\.\d+$", new_version):
+            raise ValueError("Version must be in format 'X.Y'")
         self.version = new_version
-        self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
+        self.updated_at = Field(
+            default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+        ).default_factory()

Committable suggestion skipped: line range outside the PR's diff.

self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)

# JSON Serialization
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this serialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you can parallelize tasks, since you had issues with that. Pickle or json

def to_json(self) -> str:
"""Serialize the instance to a JSON string."""
return self.json()

@classmethod
def from_json(self, json_str: str):
"""Deserialize the instance from a JSON string."""
return self.model_validate_json(json_str)

# Pickle Serialization
def to_pickle(self) -> bytes:
"""Serialize the instance to pickle-compatible bytes."""
return pickle.dumps(self.dict())

@classmethod
def from_pickle(self, pickled_data: bytes):
"""Deserialize the instance from pickled bytes."""
data = pickle.loads(pickled_data)
Comment on lines +76 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Security concern: Replace pickle with a safer serialization method

Using pickle for serialization poses a security risk as it can execute arbitrary code during deserialization. Consider using a safer alternative like JSON or MessagePack.

-    def to_pickle(self) -> bytes:
-        """Serialize the instance to pickle-compatible bytes."""
-        return pickle.dumps(self.dict())
-
-    @classmethod
-    def from_pickle(self, pickled_data: bytes):
-        """Deserialize the instance from pickled bytes."""
-        data = pickle.loads(pickled_data)
-        return self(**data)
+    def to_bytes(self) -> bytes:
+        """Serialize the instance to bytes using JSON."""
+        return self.json().encode('utf-8')
+
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        """Deserialize the instance from JSON bytes."""
+        return cls.parse_raw(data)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Pickle Serialization
def to_pickle(self) -> bytes:
"""Serialize the instance to pickle-compatible bytes."""
return pickle.dumps(self.dict())
@classmethod
def from_pickle(self, pickled_data: bytes):
"""Deserialize the instance from pickled bytes."""
data = pickle.loads(pickled_data)
return self(**data)
def to_bytes(self) -> bytes:
"""Serialize the instance to bytes using JSON."""
return self.json().encode('utf-8')
@classmethod
def from_bytes(cls, data: bytes):
"""Deserialize the instance from JSON bytes."""
return cls.parse_raw(data)

return self(**data)

def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Serialize model to a dictionary."""
return self.model_dump(**kwargs)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
"""Deserialize model from a dictionary."""
return cls.model_validate(data)
12 changes: 11 additions & 1 deletion cognee/infrastructure/llm/openai/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import litellm
import instructor
from pydantic import BaseModel

from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.base_config import get_base_config

class OpenAIAdapter(LLMInterface):
name = "OpenAI"
Expand All @@ -35,6 +36,15 @@ def __init__(
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
base_config = get_base_config()
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
# set callbacks
# litellm.success_callback = ["langfuse"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove these commented lines.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

# litellm.failure_callback = ["langfuse"]
self.aclient.success_callback = ["langfuse"]
self.aclient.failure_callback = ["langfuse"]
self.client.success_callback = ["langfuse"]
self.client.failure_callback = ["langfuse"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Refactor callback configuration and add error handling

The current implementation has several areas for improvement:

  1. Duplicate callback configuration for both clients
  2. Missing error handling for base_config
  3. Missing documentation for the monitoring feature

Consider refactoring like this:

+    def _configure_langfuse_callbacks(self, client):
+        """Configure Langfuse callbacks for the given client."""
+        client.success_callback = ["langfuse"]
+        client.failure_callback = ["langfuse"]

     def __init__(
         self,
         api_key: str,
         endpoint: str,
         api_version: str,
         model: str,
         transcription_model: str,
         streaming: bool = False,
     ):
+        """Initialize OpenAI adapter with optional Langfuse monitoring.
+        
+        Args:
+            api_key (str): OpenAI API key
+            endpoint (str): API endpoint
+            api_version (str): API version
+            model (str): Model identifier
+            transcription_model (str): Model for transcription
+            streaming (bool, optional): Enable streaming. Defaults to False.
+        """
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)
         self.transcription_model = transcription_model
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_version = api_version
         self.streaming = streaming

-        base_config = get_base_config()
-        if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
-            self.aclient.success_callback = ["langfuse"]
-            self.aclient.failure_callback = ["langfuse"]
-            self.client.success_callback = ["langfuse"]
-            self.client.failure_callback = ["langfuse"]
+        try:
+            base_config = get_base_config()
+            if base_config and base_config.monitoring_tool == MonitoringTool.LANGFUSE:
+                self._configure_langfuse_callbacks(self.aclient)
+                self._configure_langfuse_callbacks(self.client)
+        except Exception as e:
+            # Log the error but don't fail initialization
+            print(f"Warning: Failed to configure monitoring: {str(e)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
base_config = get_base_config()
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
# set callbacks
# litellm.success_callback = ["langfuse"]
# litellm.failure_callback = ["langfuse"]
self.aclient.success_callback = ["langfuse"]
self.aclient.failure_callback = ["langfuse"]
self.client.success_callback = ["langfuse"]
self.client.failure_callback = ["langfuse"]
def _configure_langfuse_callbacks(self, client):
"""Configure Langfuse callbacks for the given client."""
client.success_callback = ["langfuse"]
client.failure_callback = ["langfuse"]
def __init__(
self,
api_key: str,
endpoint: str,
api_version: str,
model: str,
transcription_model: str,
streaming: bool = False,
):
"""Initialize OpenAI adapter with optional Langfuse monitoring.
Args:
api_key (str): OpenAI API key
endpoint (str): API endpoint
api_version (str): API version
model (str): Model identifier
transcription_model (str): Model for transcription
streaming (bool, optional): Enable streaming. Defaults to False.
"""
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
try:
base_config = get_base_config()
if base_config and base_config.monitoring_tool == MonitoringTool.LANGFUSE:
self._configure_langfuse_callbacks(self.aclient)
self._configure_langfuse_callbacks(self.client)
except Exception as e:
# Log the error but don't fail initialization
print(f"Warning: Failed to configure monitoring: {str(e)}")

💡 Codebase verification

Missing Langfuse environment documentation and host configuration

The Langfuse integration requires proper documentation and configuration:

  • The required environment variables (LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY, LANGFUSE_HOST) are not documented in any .md files
  • LANGFUSE_HOST is accessed directly with os.environ[] instead of os.getenv(), which could cause crashes if not set
🔗 Analysis chain

Verify Langfuse configuration requirements

Since Langfuse monitoring is being added, we should ensure all necessary configuration is present.


Let me check the actual Langfuse configuration content to verify if all necessary settings are present.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for Langfuse configuration in base config and environment variables

# Check if BaseConfig contains Langfuse configuration
echo "Checking BaseConfig for Langfuse configuration..."
rg -l "langfuse" --type py

# Check for environment variable documentation
echo "Checking for environment variable documentation..."
find . -type f -name "*.md" -exec grep -l "LANGFUSE" {} \;

Length of output: 456


Script:

#!/bin/bash
# Check Langfuse configuration details

# Check BaseConfig for Langfuse settings
echo "=== BaseConfig Langfuse settings ==="
rg -A 5 -B 5 "langfuse" cognee/base_config.py

# Check environment variables and configuration in data models
echo -e "\n=== Data models Langfuse configuration ==="
rg -A 5 -B 5 "langfuse" cognee/shared/data_models.py

# Check if Langfuse secret/public key settings exist
echo -e "\n=== Environment variables for Langfuse ==="
rg -i "LANGFUSE.*KEY|LANGFUSE.*SECRET" .

Length of output: 1434


async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""
Expand Down
Loading