Skip to content

Commit

Permalink
Merge pull request #403 from webcoderz/main
Browse files Browse the repository at this point in the history
fix pyproject bug
  • Loading branch information
zzstoatzz authored Feb 5, 2025
2 parents daf0883 + 8554771 commit ca4adb0
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 319 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ tests = [
"langchain_community",
"langchain_google_genai",
"langchain_groq",
"langchain-ollama',
"langchain-ollama",
"pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0",
"pytest-env>=0.8,<2.0",
"pytest-rerunfailures>=10,<14",
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/memory/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class PostgresMemory(MemoryProvider):
description="Dimension of the embedding vectors. Match your model's output.",
)

embedding_fn: Callable = Field(
embedding_fn: OpenAIEmbeddings = Field(
default_factory=lambda: OpenAIEmbeddings(
model="text-embedding-ada-002",
),
Expand Down Expand Up @@ -263,7 +263,7 @@ class AsyncPostgresMemory(AsyncMemoryProvider):
description="Dimension of the embedding vectors. Must match your model output size.",
)

embedding_fn: Callable = Field(
embedding_fn: OpenAIEmbeddings = Field(
default_factory=lambda: OpenAIEmbeddings(model="text-embedding-ada-002"),
description="Function that turns a string into a numeric vector.",
)
Expand Down
2 changes: 2 additions & 0 deletions tests/llm/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ def test_get_groq_model(monkeypatch):
assert isinstance(model, ChatGroq)
assert model.model_name == "mixtral-8x7b-32768"


def test_get_ollama_model(monkeypatch):
model = get_model("ollama/qwen2.5")
assert isinstance(model, ChatOllama)
assert model.model == "qwen2.5"


def test_get_model_with_invalid_format():
with pytest.raises(ValueError, match="The model `gpt-4o` is not valid."):
get_model("gpt-4o")
Expand Down
Loading

0 comments on commit ca4adb0

Please sign in to comment.