
Commit 8e72335

WIP update cohere
1 parent: 7c22e73

1 file changed: +5 -76 lines

tests/unit/llm/test_cohere_llm.py

Lines changed: 5 additions & 76 deletions
@@ -41,86 +41,17 @@ def test_cohere_llm_happy_path(mock_cohere: Mock) -> None:
     chat_response_mock = MagicMock()
     chat_response_mock.message.content = [MagicMock(text="cohere response text")]
     mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
+    mock_cohere.UserChatMessageV2.return_value = {"role": "user", "content": "test"}
     llm = CohereLLM(model_name="something")
     res = llm.invoke("my text")
     assert isinstance(res, LLMResponse)
     assert res.content == "cohere response text"
-
-
-def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None:
-    chat_response_mock = MagicMock()
-    chat_response_mock.message.content = [MagicMock(text="cohere response text")]
-    mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat
-    mock_cohere_client_chat.return_value = chat_response_mock
-
-    system_instruction = "You are a helpful assistant."
-    llm = CohereLLM(model_name="something")
-    message_history = [
-        {"role": "user", "content": "When does the sun come up in the summer?"},
-        {"role": "assistant", "content": "Usually around 6am."},
-    ]
-    question = "What about next season?"
-
-    res = llm.invoke(question, message_history, system_instruction=system_instruction)  # type: ignore
-    assert isinstance(res, LLMResponse)
-    assert res.content == "cohere response text"
-    messages = [{"role": "system", "content": system_instruction}]
-    messages.extend(message_history)
-    messages.append({"role": "user", "content": question})
-    mock_cohere_client_chat.assert_called_once_with(
-        messages=messages,
+    mock_cohere.ClientV2.return_value.chat.assert_called_once_with(
+        messages=[{"role": "user", "content": "test"}],
         model="something",
     )
 
 
-def test_cohere_llm_invoke_with_message_history_and_system_instruction(
-    mock_cohere: Mock,
-) -> None:
-    chat_response_mock = MagicMock()
-    chat_response_mock.message.content = [MagicMock(text="cohere response text")]
-    mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat
-    mock_cohere_client_chat.return_value = chat_response_mock
-
-    system_instruction = "You are a helpful assistant."
-    llm = CohereLLM(model_name="gpt")
-    message_history = [
-        {"role": "user", "content": "When does the sun come up in the summer?"},
-        {"role": "assistant", "content": "Usually around 6am."},
-    ]
-    question = "What about next season?"
-
-    res = llm.invoke(question, message_history, system_instruction=system_instruction)  # type: ignore
-    assert isinstance(res, LLMResponse)
-    assert res.content == "cohere response text"
-    messages = [{"role": "system", "content": system_instruction}]
-    messages.extend(message_history)
-    messages.append({"role": "user", "content": question})
-    mock_cohere_client_chat.assert_called_once_with(
-        messages=messages,
-        model="gpt",
-    )
-
-
-def test_cohere_llm_invoke_with_message_history_validation_error(
-    mock_cohere: Mock,
-) -> None:
-    chat_response_mock = MagicMock()
-    chat_response_mock.message.content = [MagicMock(text="cohere response text")]
-    mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
-
-    system_instruction = "You are a helpful assistant."
-    llm = CohereLLM(model_name="something", system_instruction=system_instruction)
-    message_history = [
-        {"role": "robot", "content": "When does the sun come up in the summer?"},
-        {"role": "assistant", "content": "Usually around 6am."},
-    ]
-    question = "What about next season?"
-
-    with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)  # type: ignore
-    assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
-
-
 @pytest.mark.asyncio
 async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None:
     chat_response_mock = MagicMock(
@@ -139,16 +70,14 @@ async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None:
 def test_cohere_llm_failed(mock_cohere: Mock) -> None:
     mock_cohere.ClientV2.return_value.chat.side_effect = cohere.core.ApiError
     llm = CohereLLM(model_name="something")
-    with pytest.raises(LLMGenerationError) as excinfo:
+    with pytest.raises(LLMGenerationError, match="ApiError"):
         llm.invoke("my text")
-    assert "ApiError" in str(excinfo)
 
 
 @pytest.mark.asyncio
 async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None:
     mock_cohere.AsyncClientV2.return_value.chat.side_effect = cohere.core.ApiError
     llm = CohereLLM(model_name="something")
 
-    with pytest.raises(LLMGenerationError) as excinfo:
+    with pytest.raises(LLMGenerationError, match="ApiError"):
        await llm.ainvoke("my text")
-    assert "ApiError" in str(excinfo)
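For context, these tests rely on a mock_cohere fixture that is not part of this diff. A minimal sketch of such a fixture, assuming it lives in the test suite's conftest and patches the cohere module where CohereLLM imports it (the patch target shown is an assumption, not taken from the repository):

# Sketch of the `mock_cohere` fixture these tests assume; the real conftest
# (and the patch target below) may differ.
from collections.abc import Generator
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def mock_cohere() -> Generator[MagicMock, None, None]:
    # Patch the `cohere` module as seen by the CohereLLM implementation so that
    # ClientV2, AsyncClientV2 and UserChatMessageV2 become configurable MagicMocks.
    with patch("neo4j_graphrag.llm.cohere_llm.cohere") as mock:
        yield mock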

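The reworked happy-path assertion also pins down the call shape the implementation is expected to produce: one UserChatMessageV2 built from the prompt, passed to ClientV2.chat together with the model name, with ApiError rewrapped as LLMGenerationError. A rough sketch of an invoke consistent with these tests follows; it is inferred from the assertions rather than taken from the library, and the import paths and constructor signature are assumptions.

# Inferred sketch only -- illustrates what the updated tests expect, not the real code.
import cohere

from neo4j_graphrag.exceptions import LLMGenerationError  # assumed import path
from neo4j_graphrag.llm.types import LLMResponse  # assumed import path


class CohereLLM:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self.client = cohere.ClientV2()

    def invoke(self, input: str) -> LLMResponse:
        # Build the single user message the test expects in `messages=[...]`.
        messages = [cohere.UserChatMessageV2(content=input)]
        try:
            response = self.client.chat(messages=messages, model=self.model_name)
        except cohere.core.ApiError as e:
            # Surface SDK failures as LLMGenerationError, as the failure tests assert.
            raise LLMGenerationError(e)
        return LLMResponse(content=response.message.content[0].text)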