-
Notifications
You must be signed in to change notification settings - Fork 789
Expand file tree
/
Copy pathmem_reader.py
More file actions
111 lines (87 loc) · 3.94 KB
/
mem_reader.py
File metadata and controls
111 lines (87 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from datetime import datetime
from typing import Any, ClassVar
from pydantic import ConfigDict, Field, field_validator, model_validator
from memos.configs.base import BaseConfig
from memos.configs.chunker import ChunkerConfigFactory
from memos.configs.embedder import EmbedderConfigFactory
from memos.configs.llm import LLMConfigFactory
class BaseMemReaderConfig(BaseConfig):
    """Base configuration class for MemReader.

    Declares the settings shared by every MemReader backend: a creation
    timestamp, the LLM / embedder / chunker sub-configurations, and a few
    optional switches. Fields declared with ``...`` are required; the
    others fall back to the defaults given here.
    """

    # When omitted, filled by ``datetime.now()`` (called without a timezone).
    created_at: datetime = Field(
        default_factory=datetime.now, description="Creation timestamp for the MemReader"
    )
    llm: LLMConfigFactory = Field(
        ..., description="LLM configuration for chat/doc memory extraction (fine-tuned model)"
    )
    general_llm: LLMConfigFactory | None = Field(
        default=None,
        description="General LLM for non-chat/doc tasks: hallucination filter, memory rewrite, "
        "memory merge, tool trajectory, skill memory. Falls back to main llm if not set.",
    )
    image_parser_llm: LLMConfigFactory | None = Field(
        default=None,
        description="Vision LLM for image parsing. Falls back to general_llm if not set.",
    )
    embedder: EmbedderConfigFactory = Field(
        ..., description="Embedder configuration for the MemReader"
    )
    chunker: ChunkerConfigFactory = Field(
        ..., description="Chunker configuration for the MemReader"
    )
    remove_prompt_example: bool = Field(
        default=False,
        description="whether remove example in memory extraction prompt to save token",
    )
    # The default is None, so the annotation has to admit None as well:
    # pydantic does not validate defaults, but it does validate an explicitly
    # supplied ``chat_chunker=None`` and would reject it against a plain
    # ``dict[str, Any]`` annotation.
    chat_chunker: dict[str, Any] | None = Field(
        default=None, description="Configuration for the MemReader chat chunk strategy"
    )

    @field_validator("created_at", mode="before")
    @classmethod
    def parse_datetime(cls, value: Any) -> Any:
        """Parse datetime from string if needed.

        Runs before pydantic's own validation of ``created_at``.

        Args:
            value: The raw input supplied for ``created_at``.

        Returns:
            A ``datetime`` parsed from ``value`` when it is a string;
            otherwise ``value`` unchanged.

        Raises:
            ValueError: If ``value`` is a string that
                ``datetime.fromisoformat`` cannot parse.
        """
        if isinstance(value, str):
            # Rewrite "Z" as an explicit UTC offset, which
            # ``fromisoformat`` accepts on every supported Python version.
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        return value
class SimpleStructMemReaderConfig(BaseMemReaderConfig):
    """Configuration for the SimpleStruct MemReader backend."""

    # extra="allow" lets callers pass keys that are not declared as fields
    # without triggering a validation error.
    model_config = ConfigDict(strict=True, extra="allow")
class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
    """Configuration for the MultiModalStruct MemReader backend.

    Adds three optional settings on top of the base configuration; each of
    them defaults to ``None``.
    """

    direct_markdown_hostnames: list[str] | None = Field(
        None,
        description=(
            "List of hostnames that should return markdown directly without parsing. "
            "If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES environment variable."
        ),
    )
    oss_config: dict[str, Any] | None = Field(
        None, description="OSS configuration for the MemReader"
    )
    skills_dir_config: dict[str, Any] | None = Field(
        None, description="Skills directory for the MemReader"
    )
class StrategyStructMemReaderConfig(BaseMemReaderConfig):
    """Configuration for the StrategyStruct MemReader backend."""

    # extra="allow" lets callers pass keys that are not declared as fields
    # without triggering a validation error.
    model_config = ConfigDict(strict=True, extra="allow")
class MemReaderConfigFactory(BaseConfig):
    """Build the MemReader configuration that matches a backend name.

    ``backend`` must be one of the keys of ``backend_to_class``. After field
    validation, ``config`` is replaced by an instance of the class
    registered for that backend.
    """

    backend: str = Field(..., description="Backend for MemReader")
    config: dict[str, Any] = Field(..., description="Configuration for the MemReader backend")

    # Registry of supported backend names and the config class each one uses.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "simple_struct": SimpleStructMemReaderConfig,
        "multimodal_struct": MultiModalStructMemReaderConfig,
        "strategy_struct": StrategyStructMemReaderConfig,
    }

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, backend: str) -> str:
        """Return ``backend`` unchanged if it is registered.

        Raises:
            ValueError: If ``backend`` is not a key of ``backend_to_class``.
        """
        if backend in cls.backend_to_class:
            return backend
        raise ValueError(f"Invalid backend: {backend}")

    @model_validator(mode="after")
    def create_config(self) -> "MemReaderConfigFactory":
        """Instantiate the backend's config class from the raw ``config`` dict."""
        # The dict entries are passed as keyword arguments, and the resulting
        # config object is stored back on the same attribute.
        self.config = self.backend_to_class[self.backend](**self.config)
        return self