Skip to content

Commit 05afc1b

Browse files
committed
test(cli): add comprehensive CLI test suite and reorganize files (#1339)
1 parent dd1c64f commit 05afc1b

File tree

7 files changed

+1869
-316
lines changed

7 files changed

+1869
-316
lines changed

tests/cli/test_chat.py

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
import sys
18+
from unittest.mock import AsyncMock, MagicMock, patch
19+
20+
import pytest
21+
22+
from nemoguardrails.cli.chat import (
23+
ChatState,
24+
extract_scene_text_content,
25+
parse_events_inputs,
26+
run_chat,
27+
)
28+
29+
chat_module = sys.modules["nemoguardrails.cli.chat"]
30+
31+
32+
class TestParseEventsInputs:
33+
def test_parse_simple_event(self):
34+
result = parse_events_inputs("UserAction")
35+
assert result == {"type": "UserAction"}
36+
37+
def test_parse_event_with_params(self):
38+
result = parse_events_inputs('UserAction(name="test", value=123)')
39+
assert result == {"type": "UserAction", "name": "test", "value": 123}
40+
41+
def test_parse_event_with_string_params(self):
42+
result = parse_events_inputs('UserAction(message="hello world")')
43+
assert result == {"type": "UserAction", "message": "hello world"}
44+
45+
def test_parse_nested_event(self):
46+
result = parse_events_inputs("bot.UtteranceAction")
47+
assert result == {"type": "botUtteranceAction"}
48+
49+
def test_parse_event_with_nested_params(self):
50+
result = parse_events_inputs('UserAction(data={"key": "value"})')
51+
assert result == {"type": "UserAction", "data": {"key": "value"}}
52+
53+
def test_parse_event_with_list_params(self):
54+
result = parse_events_inputs("UserAction(items=[1, 2, 3])")
55+
assert result == {"type": "UserAction", "items": [1, 2, 3]}
56+
57+
def test_parse_invalid_event(self):
58+
result = parse_events_inputs("Invalid.Event.Format.TooMany")
59+
assert result is None
60+
61+
def test_parse_event_missing_equals(self):
62+
result = parse_events_inputs("UserAction(invalid_param)")
63+
assert result is None
64+
65+
66+
class TestExtractSceneTextContent:
67+
def test_extract_empty_list(self):
68+
result = extract_scene_text_content([])
69+
assert result == ""
70+
71+
def test_extract_single_text(self):
72+
content = [{"text": "Hello World"}]
73+
result = extract_scene_text_content(content)
74+
assert result == "\nHello World"
75+
76+
def test_extract_multiple_texts(self):
77+
content = [{"text": "Line 1"}, {"text": "Line 2"}, {"text": "Line 3"}]
78+
result = extract_scene_text_content(content)
79+
assert result == "\nLine 1\nLine 2\nLine 3"
80+
81+
def test_extract_mixed_content(self):
82+
content = [
83+
{"text": "Text 1"},
84+
{"image": "image.png"},
85+
{"text": "Text 2"},
86+
{"button": "Click Me"},
87+
]
88+
result = extract_scene_text_content(content)
89+
assert result == "\nText 1\nText 2"
90+
91+
def test_extract_no_text_content(self):
92+
content = [{"image": "image.png"}, {"button": "Click Me"}]
93+
result = extract_scene_text_content(content)
94+
assert result == ""
95+
96+
97+
class TestChatState:
98+
def test_initial_state(self):
99+
chat_state = ChatState()
100+
assert chat_state.state is None
101+
assert chat_state.waiting_user_input is False
102+
assert chat_state.paused is False
103+
assert chat_state.running_timer_tasks == {}
104+
assert chat_state.input_events == []
105+
assert chat_state.output_events == []
106+
assert chat_state.output_state is None
107+
assert chat_state.events_counter == 0
108+
assert chat_state.first_time is False
109+
110+
111+
class TestRunChat:
112+
def test_run_chat_v1_0(self):
113+
with patch.object(
114+
chat_module, "RailsConfig"
115+
) as mock_rails_config, patch.object(
116+
chat_module, "LLMRails"
117+
) as mock_llm_rails, patch(
118+
"asyncio.run"
119+
) as mock_asyncio_run:
120+
mock_config = MagicMock()
121+
mock_config.colang_version = "1.0"
122+
mock_rails_config.from_path.return_value = mock_config
123+
124+
run_chat(config_path="test_config")
125+
126+
mock_rails_config.from_path.assert_called_once_with("test_config")
127+
mock_asyncio_run.assert_called_once()
128+
129+
def test_run_chat_v2_x(self):
130+
with patch.object(
131+
chat_module, "RailsConfig"
132+
) as mock_rails_config, patch.object(
133+
chat_module, "LLMRails"
134+
) as mock_llm_rails, patch.object(
135+
chat_module, "get_or_create_event_loop"
136+
) as mock_get_loop:
137+
mock_config = MagicMock()
138+
mock_config.colang_version = "2.x"
139+
mock_rails_config.from_path.return_value = mock_config
140+
141+
mock_loop = MagicMock()
142+
mock_get_loop.return_value = mock_loop
143+
144+
run_chat(config_path="test_config")
145+
146+
mock_rails_config.from_path.assert_called_once_with("test_config")
147+
mock_llm_rails.assert_called_once_with(mock_config, verbose=False)
148+
mock_loop.run_until_complete.assert_called_once()
149+
150+
def test_run_chat_invalid_version(self):
151+
with patch.object(chat_module, "RailsConfig") as mock_rails_config:
152+
mock_config = MagicMock()
153+
mock_config.colang_version = "3.0"
154+
mock_rails_config.from_path.return_value = mock_config
155+
156+
with pytest.raises(Exception, match="Invalid colang version"):
157+
run_chat(config_path="test_config")
158+
159+
def test_run_chat_verbose_with_llm_calls(self):
160+
with patch.object(chat_module, "RailsConfig") as mock_rails_config, patch(
161+
"asyncio.run"
162+
) as mock_asyncio_run, patch.object(chat_module, "console") as mock_console:
163+
mock_config = MagicMock()
164+
mock_config.colang_version = "1.0"
165+
mock_rails_config.from_path.return_value = mock_config
166+
167+
run_chat(config_path="test_config", verbose=True, verbose_llm_calls=True)
168+
169+
mock_console.print.assert_any_call(
170+
"NOTE: use the `--verbose-no-llm` option to exclude the LLM prompts "
171+
"and completions from the log.\n"
172+
)
173+
174+
175+
class TestRunChatV1Async:
176+
@pytest.mark.asyncio
177+
async def test_run_chat_v1_no_config_no_server(self):
178+
from nemoguardrails.cli.chat import _run_chat_v1_0
179+
180+
with pytest.raises(RuntimeError, match="At least one of"):
181+
await _run_chat_v1_0(config_path=None, server_url=None)
182+
183+
@pytest.mark.asyncio
184+
@patch("builtins.input")
185+
@patch.object(chat_module, "LLMRails")
186+
@patch.object(chat_module, "RailsConfig")
187+
async def test_run_chat_v1_local_config(
188+
self, mock_rails_config, mock_llm_rails, mock_input
189+
):
190+
from nemoguardrails.cli.chat import _run_chat_v1_0
191+
192+
mock_config = MagicMock()
193+
mock_config.streaming_supported = False
194+
mock_rails_config.from_path.return_value = mock_config
195+
196+
mock_rails = AsyncMock()
197+
mock_rails.generate_async = AsyncMock(
198+
return_value={"role": "assistant", "content": "Hello!"}
199+
)
200+
mock_rails.main_llm_supports_streaming = False
201+
mock_llm_rails.return_value = mock_rails
202+
203+
mock_input.side_effect = ["test message", KeyboardInterrupt()]
204+
205+
try:
206+
await _run_chat_v1_0(config_path="test_config")
207+
except KeyboardInterrupt:
208+
pass
209+
210+
mock_rails.generate_async.assert_called_once()
211+
212+
@pytest.mark.asyncio
213+
@patch("builtins.input")
214+
@patch.object(chat_module, "console")
215+
@patch.object(chat_module, "LLMRails")
216+
@patch.object(chat_module, "RailsConfig")
217+
async def test_run_chat_v1_streaming_not_supported(
218+
self, mock_rails_config, mock_llm_rails, mock_console, mock_input
219+
):
220+
from nemoguardrails.cli.chat import _run_chat_v1_0
221+
222+
mock_config = MagicMock()
223+
mock_config.streaming_supported = False
224+
mock_rails_config.from_path.return_value = mock_config
225+
226+
mock_rails = AsyncMock()
227+
mock_llm_rails.return_value = mock_rails
228+
229+
mock_input.side_effect = [KeyboardInterrupt()]
230+
231+
try:
232+
await _run_chat_v1_0(config_path="test_config", streaming=True)
233+
except KeyboardInterrupt:
234+
pass
235+
236+
mock_console.print.assert_any_call(
237+
"WARNING: The config `test_config` does not support streaming. "
238+
"Falling back to normal mode."
239+
)
240+
241+
@pytest.mark.asyncio
242+
@patch("aiohttp.ClientSession")
243+
@patch("builtins.input")
244+
async def test_run_chat_v1_server_mode(self, mock_input, mock_client_session):
245+
from nemoguardrails.cli.chat import _run_chat_v1_0
246+
247+
mock_session = AsyncMock()
248+
mock_response = AsyncMock()
249+
mock_response.headers = {}
250+
mock_response.json = AsyncMock(
251+
return_value={
252+
"messages": [{"role": "assistant", "content": "Server response"}]
253+
}
254+
)
255+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
256+
mock_response.__aexit__ = AsyncMock()
257+
258+
mock_post_context = AsyncMock()
259+
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
260+
mock_post_context.__aexit__ = AsyncMock()
261+
mock_session.post = MagicMock(return_value=mock_post_context)
262+
263+
mock_client_session.return_value.__aenter__ = AsyncMock(
264+
return_value=mock_session
265+
)
266+
mock_client_session.return_value.__aexit__ = AsyncMock()
267+
268+
mock_input.side_effect = ["test message", KeyboardInterrupt()]
269+
270+
try:
271+
await _run_chat_v1_0(
272+
server_url="http://localhost:8000", config_id="test_id"
273+
)
274+
except KeyboardInterrupt:
275+
pass
276+
277+
assert mock_session.post.called
278+
call_args = mock_session.post.call_args
279+
assert call_args[0][0] == "http://localhost:8000/v1/chat/completions"
280+
assert "config_id" in call_args[1]["json"]
281+
assert call_args[1]["json"]["config_id"] == "test_id"
282+
assert call_args[1]["json"]["stream"] is False
283+
284+
@pytest.mark.asyncio
285+
@patch("aiohttp.ClientSession")
286+
@patch("builtins.input")
287+
async def test_run_chat_v1_server_streaming(self, mock_input, mock_client_session):
288+
from nemoguardrails.cli.chat import _run_chat_v1_0
289+
290+
mock_session = AsyncMock()
291+
mock_response = AsyncMock()
292+
mock_response.headers = {"Transfer-Encoding": "chunked"}
293+
294+
async def mock_iter_any():
295+
yield b"Stream "
296+
yield b"response"
297+
298+
mock_response.content.iter_any = mock_iter_any
299+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
300+
mock_response.__aexit__ = AsyncMock()
301+
302+
mock_post_context = AsyncMock()
303+
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
304+
mock_post_context.__aexit__ = AsyncMock()
305+
mock_session.post = MagicMock(return_value=mock_post_context)
306+
307+
mock_client_session.return_value.__aenter__ = AsyncMock(
308+
return_value=mock_session
309+
)
310+
mock_client_session.return_value.__aexit__ = AsyncMock()
311+
312+
mock_input.side_effect = ["test message", KeyboardInterrupt()]
313+
314+
try:
315+
await _run_chat_v1_0(
316+
server_url="http://localhost:8000", config_id="test_id", streaming=True
317+
)
318+
except KeyboardInterrupt:
319+
pass
320+
321+
assert mock_session.post.called

0 commit comments

Comments
 (0)