Skip to content

Commit 07eec6d

Browse files
rylativityRyan Stewartgabe-l-hart
authored
types: enable passing messages with arbitrary role (#462)
--------- Co-authored-by: Ryan Stewart <[email protected]> Co-authored-by: Gabe Goodhart <[email protected]>
1 parent 6b235b2 commit 07eec6d

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

ollama/_types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ class Message(SubscriptableBaseModel):
256256
Chat message.
257257
"""
258258

259-
role: Literal['user', 'assistant', 'system', 'tool']
259+
role: str
260260
"Assumed role of the message. Response messages has role 'assistant' or 'tool'."
261261

262262
content: Optional[str] = None

tests/test_client.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import re
55
import tempfile
66
from pathlib import Path
7+
from typing import Any
78

89
import pytest
10+
from httpx import Response as httpxResponse
911
from pydantic import BaseModel, ValidationError
1012
from pytest_httpserver import HTTPServer, URIPattern
1113
from werkzeug.wrappers import Request, Response
1214

1315
from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
14-
from ollama._types import Image
16+
from ollama._types import Image, Message
1517

1618
PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
1719
PNG_BYTES = base64.b64decode(PNG_BASE64)
@@ -1181,3 +1183,32 @@ async def test_async_client_connection_error():
11811183
with pytest.raises(ConnectionError) as exc_info:
11821184
await client.show('model')
11831185
assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
1186+
1187+
1188+
def test_arbitrary_roles_accepted_in_message():
1189+
_ = Message(role='somerandomrole', content="I'm ok with you adding any role message now!")
1190+
1191+
1192+
def _mock_request(*args: Any, **kwargs: Any) -> Response:
1193+
return httpxResponse(status_code=200, content="{'response': 'Hello world!'}")
1194+
1195+
1196+
def test_arbitrary_roles_accepted_in_message_request(monkeypatch: pytest.MonkeyPatch):
1197+
monkeypatch.setattr(Client, '_request', _mock_request)
1198+
1199+
client = Client()
1200+
1201+
client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])
1202+
1203+
1204+
async def _mock_request_async(*args: Any, **kwargs: Any) -> Response:
1205+
return httpxResponse(status_code=200, content="{'response': 'Hello world!'}")
1206+
1207+
1208+
@pytest.mark.asyncio
1209+
async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: pytest.MonkeyPatch):
1210+
monkeypatch.setattr(AsyncClient, '_request', _mock_request_async)
1211+
1212+
client = AsyncClient()
1213+
1214+
await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])

0 commit comments

Comments
 (0)