Skip to content

Commit fc271fc

Browse files
committed
Add unit tests.
1 parent 48647ad commit fc271fc

5 files changed

Lines changed: 142 additions & 2 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
*.sh
1+
*.sh
2+
.windsurfrules
3+
scratchpad.md

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@ python-multipart
55
numpy
66
websockets
77
scipy
8-
google-generativeai
8+
google-generativeai
9+
pytest
10+
pytest-asyncio
11+
pytest-mock

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import pytest
2+
import asyncio
3+
import json
4+
import websockets
5+
from unittest.mock import AsyncMock, patch, MagicMock
6+
from openai_realtime_client import OpenAIRealtimeAudioTextClient
7+
8+
@pytest.fixture
9+
def api_key():
10+
return "test_api_key"
11+
12+
@pytest.fixture
13+
def client(api_key):
14+
return OpenAIRealtimeAudioTextClient(api_key)
15+
16+
@pytest.mark.asyncio
17+
async def test_connect_success(client):
18+
mock_ws = AsyncMock()
19+
mock_ws.recv.return_value = json.dumps({
20+
"type": "session.created",
21+
"session": {"id": "test_session_id"}
22+
})
23+
24+
with patch('websockets.connect', AsyncMock(return_value=mock_ws)):
25+
await client.connect()
26+
27+
assert client.session_id == "test_session_id"
28+
assert client.ws == mock_ws
29+
assert len(client.handlers) > 0
30+
assert "default" in client.handlers
31+
32+
# Verify the session update message was sent
33+
expected_update = {
34+
"type": "session.update",
35+
"session": {
36+
"modalities": ["text"],
37+
"input_audio_format": "pcm16",
38+
"input_audio_transcription": None,
39+
"turn_detection": None,
40+
}
41+
}
42+
mock_ws.send.assert_awaited_with(json.dumps(expected_update))
43+
44+
@pytest.mark.asyncio
45+
async def test_send_audio(client):
46+
mock_ws = AsyncMock()
47+
mock_ws.open = True
48+
client.ws = mock_ws
49+
50+
test_audio = b"test_audio_data"
51+
await client.send_audio(test_audio)
52+
53+
expected_message = {
54+
"type": "input_audio_buffer.append",
55+
"audio": "dGVzdF9hdWRpb19kYXRh" # base64 encoded test_audio_data
56+
}
57+
mock_ws.send.assert_awaited_with(json.dumps(expected_message))
58+
59+
@pytest.mark.asyncio
60+
async def test_commit_audio(client):
61+
mock_ws = AsyncMock()
62+
mock_ws.open = True
63+
client.ws = mock_ws
64+
65+
await client.commit_audio()
66+
67+
expected_message = {"type": "input_audio_buffer.commit"}
68+
mock_ws.send.assert_awaited_with(json.dumps(expected_message))
69+
70+
@pytest.mark.asyncio
71+
async def test_clear_audio_buffer(client):
72+
mock_ws = AsyncMock()
73+
mock_ws.open = True
74+
client.ws = mock_ws
75+
76+
await client.clear_audio_buffer()
77+
78+
expected_message = {"type": "input_audio_buffer.clear"}
79+
mock_ws.send.assert_awaited_with(json.dumps(expected_message))
80+
81+
@pytest.mark.asyncio
82+
async def test_start_response(client):
83+
mock_ws = AsyncMock()
84+
mock_ws.open = True
85+
client.ws = mock_ws
86+
87+
test_instructions = "test instructions"
88+
await client.start_response(test_instructions)
89+
90+
expected_message = {
91+
"type": "response.create",
92+
"response": {
93+
"modalities": ["text"],
94+
"instructions": test_instructions
95+
}
96+
}
97+
mock_ws.send.assert_awaited_with(json.dumps(expected_message))
98+
99+
@pytest.mark.asyncio
100+
async def test_close(client):
101+
mock_ws = AsyncMock()
102+
client.ws = mock_ws
103+
client.receive_task = asyncio.create_task(asyncio.sleep(0))
104+
105+
await client.close()
106+
107+
mock_ws.close.assert_awaited_once()
108+
assert client.receive_task.cancelled()
109+
110+
@pytest.mark.asyncio
111+
async def test_receive_messages(client):
112+
mock_ws = AsyncMock()
113+
test_message = {"type": "test_type", "data": "test_data"}
114+
mock_ws.__aiter__.return_value = [json.dumps(test_message)]
115+
client.ws = mock_ws
116+
117+
# Create a mock handler and register it
118+
mock_handler = AsyncMock()
119+
client.register_handler("test_type", mock_handler)
120+
121+
# Start receive_messages
122+
receive_task = asyncio.create_task(client.receive_messages())
123+
await asyncio.sleep(0.1) # Give some time for the message to be processed
124+
125+
# Verify the handler was called with the correct message
126+
mock_handler.assert_awaited_once_with(test_message)
127+
128+
# Clean up
129+
receive_task.cancel()
130+
try:
131+
await receive_task
132+
except asyncio.CancelledError:
133+
pass

0 commit comments

Comments
 (0)