diff --git a/cognee/base_config.py b/cognee/base_config.py index 6b1b8811..085ede2c 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -16,6 +16,7 @@ class BaseConfig(BaseSettings): langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST") model_config = SettingsConfigDict(env_file=".env", extra="allow") + def to_dict(self) -> dict: return { "data_root_directory": self.data_root_directory, diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index db0d9308..60ba8515 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,24 +1,33 @@ + + from datetime import datetime, timezone -from typing import Optional +from typing import Optional, Any, Dict from uuid import UUID, uuid4 from pydantic import BaseModel, Field from typing_extensions import TypedDict +import pickle - +# Define metadata type class MetaData(TypedDict): index_fields: list[str] + +# Updated DataPoint model with versioning and new fields class DataPoint(BaseModel): __tablename__ = "data_point" id: UUID = Field(default_factory=uuid4) - updated_at: Optional[datetime] = datetime.now(timezone.utc) + created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) + updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) + version: int = 1 # Default version + type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 _metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"} - # class Config: - # underscore_attrs_are_private = True + # Override the Pydantic configuration + class Config: + underscore_attrs_are_private = True @classmethod def get_embeddable_data(self, data_point): @@ -31,12 +40,13 @@ def get_embeddable_data(self, data_point): if isinstance(attribute, str): return attribute.strip() - else: - return attribute + return attribute @classmethod def get_embeddable_properties(self, data_point): + """Retrieve all embeddable properties.""" if data_point._metadata and len(data_point._metadata["index_fields"]) > 0: + return [ getattr(data_point, field, None) for field in data_point._metadata["index_fields"] ] @@ -45,4 +55,42 @@ def get_embeddable_properties(self, data_point): @classmethod def get_embeddable_property_names(self, data_point): + + """Retrieve names of embeddable properties.""" return data_point._metadata["index_fields"] or [] + + def update_version(self): + """Update the version and updated_at timestamp.""" + self.version += 1 + self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000) + + # JSON Serialization + def to_json(self) -> str: + """Serialize the instance to a JSON string.""" + return self.json() + + @classmethod + def from_json(self, json_str: str): + """Deserialize the instance from a JSON string.""" + return self.model_validate_json(json_str) + + # Pickle Serialization + def to_pickle(self) -> bytes: + """Serialize the instance to pickle-compatible bytes.""" + return pickle.dumps(self.dict()) + + @classmethod + def from_pickle(self, pickled_data: bytes): + """Deserialize the instance from pickled bytes.""" + data = pickle.loads(pickled_data) + return self(**data) + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Serialize model to a dictionary.""" + return self.model_dump(**kwargs) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DataPoint": + """Deserialize model from a dictionary.""" + return cls.model_validate(data) + diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index d4566238..6ed5f3c4 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -12,6 +12,7 @@ from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.base_config import get_base_config + monitoring = get_base_config().monitoring_tool if monitoring == MonitoringTool.LANGFUSE: from langfuse.decorators import observe @@ -43,10 +44,20 @@ def __init__( self.api_version = api_version self.streaming = streaming - @observe(as_type="generation") - async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: + base_config = get_base_config() + + if base_config.monitoring_tool == MonitoringTool.LANGFUSE: + self.aclient.success_callback = ["langfuse"] + self.aclient.failure_callback = ["langfuse"] + self.client.success_callback = ["langfuse"] + self.client.failure_callback = ["langfuse"] + + + + @observe(as_type='generation') + async def acreate_structured_output(self, text_input: str, system_prompt: str, + response_model: Type[BaseModel]) -> BaseModel: + """Generate a response from a user query.""" return await self.aclient.chat.completions.create(