|
4 | 4 | import re
|
5 | 5 | import tempfile
|
6 | 6 | from pathlib import Path
|
| 7 | +from typing import Any |
7 | 8 |
|
8 | 9 | import pytest
|
| 10 | +from httpx import Response as httpxResponse |
9 | 11 | from pydantic import BaseModel, ValidationError
|
10 | 12 | from pytest_httpserver import HTTPServer, URIPattern
|
11 | 13 | from werkzeug.wrappers import Request, Response
|
12 | 14 |
|
13 | 15 | from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
|
14 |
| -from ollama._types import Image |
| 16 | +from ollama._types import Image, Message |
15 | 17 |
|
16 | 18 | PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
|
17 | 19 | PNG_BYTES = base64.b64decode(PNG_BASE64)
|
@@ -1181,3 +1183,32 @@ async def test_async_client_connection_error():
|
1181 | 1183 | with pytest.raises(ConnectionError) as exc_info:
|
1182 | 1184 | await client.show('model')
|
1183 | 1185 | assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
|
| 1186 | + |
| 1187 | + |
| 1188 | +def test_arbitrary_roles_accepted_in_message(): |
| 1189 | + _ = Message(role='somerandomrole', content="I'm ok with you adding any role message now!") |
| 1190 | + |
| 1191 | + |
| 1192 | +def _mock_request(*args: Any, **kwargs: Any) -> Response: |
| 1193 | + return httpxResponse(status_code=200, content="{'response': 'Hello world!'}") |
| 1194 | + |
| 1195 | + |
| 1196 | +def test_arbitrary_roles_accepted_in_message_request(monkeypatch: pytest.MonkeyPatch): |
| 1197 | + monkeypatch.setattr(Client, '_request', _mock_request) |
| 1198 | + |
| 1199 | + client = Client() |
| 1200 | + |
| 1201 | + client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}]) |
| 1202 | + |
| 1203 | + |
| 1204 | +async def _mock_request_async(*args: Any, **kwargs: Any) -> Response: |
| 1205 | + return httpxResponse(status_code=200, content="{'response': 'Hello world!'}") |
| 1206 | + |
| 1207 | + |
| 1208 | +@pytest.mark.asyncio |
| 1209 | +async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: pytest.MonkeyPatch): |
| 1210 | + monkeypatch.setattr(AsyncClient, '_request', _mock_request_async) |
| 1211 | + |
| 1212 | + client = AsyncClient() |
| 1213 | + |
| 1214 | + await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}]) |
0 commit comments