Commit 74c6323

Improve/simplify tests; avoid user prompt duplication
1 parent 605fde0 · commit 74c6323

5 files changed: +343 -546 lines changed

chatlas/_chat.py

Lines changed: 23 additions & 24 deletions
@@ -23,7 +23,6 @@
     Optional,
     Sequence,
     TypeVar,
-    cast,
     overload,
 )
 
@@ -56,7 +55,6 @@
 
 if TYPE_CHECKING:
     from inspect_ai.model import ChatMessage as InspectChatMessage
-    from inspect_ai.model import ChatMessageAssistant as InspectChatMessageAssistant
     from inspect_ai.solver import TaskState as InspectTaskState
 
     from ._content import ToolAnnotations
@@ -888,6 +886,7 @@ def my_eval(grader_model: str = "openai/gpt-4o"):
             inspect_content_as_chatlas,
             inspect_messages_as_turns,
             try_import_inspect,
+            turn_as_inspect_messages,
         )
 
         (imodel, isolver, _) = try_import_inspect()
@@ -928,18 +927,23 @@ async def solve(state: InspectTaskState, generate):
             # we translate the message state back to the chat instance.
             # N.B., state.message can include non-trivial dataset of sample input
             # (e.g., `Sample(input=[ChatMessage, ...])`)
-            system_prompts: list["InspectChatMessage"] = []
-            other_prompts: list["InspectChatMessage"] = []
-            for x in state.messages:
-                if x.role == "system":
-                    system_prompts.append(x)
+            system_messages: list["InspectChatMessage"] = []
+            other_messages: list["InspectChatMessage"] = []
+            user_prompt: "InspectChatMessage | None" = None
+            for x in reversed(state.messages):
+                if x.role == "user" and user_prompt is None:
+                    user_prompt = x
+                elif x.role == "system":
+                    system_messages.append(x)
                 else:
-                    other_prompts.append(x)
+                    other_messages.append(x)
+
+            other_messages.reverse()
 
             # Set the system prompt on the chat instance
-            if len(system_prompts) == 1:
-                chat_instance.system_prompt = str(system_prompts[0])
-            elif len(system_prompts) > 1:
+            if len(system_messages) == 1:
+                chat_instance.system_prompt = str(system_messages[0])
+            elif len(system_messages) > 1:
                 raise ValueError(
                     "Multiple system prompts detected in `.to_solver()`, but chatlas only "
                     "supports a single system prompt. This usually indicates that the system "
@@ -949,26 +953,21 @@ async def solve(state: InspectTaskState, generate):
                 )
 
             # Now, set the other messages as turns on the chat instance
-            chat_instance.set_turns(inspect_messages_as_turns(other_prompts))
+            chat_instance.set_turns(inspect_messages_as_turns(other_messages))
 
-            # TODO: inspect docs mention this is always the _first_? user message??
-            user_content = state.user_prompt.content
-            if isinstance(user_content, str):
-                input_content = [user_content]
-            else:
-                input_content = [
-                    inspect_content_as_chatlas(x) for x in user_content
-                ]
+            if user_prompt is None:
+                raise ValueError("No user prompt found in InspectAI state messages")
+
+            input_content = [inspect_content_as_chatlas(x) for x in user_prompt.content]
 
             await chat_instance.chat_async(*input_content, echo="none")
             last_turn = chat_instance.get_last_turn(role="assistant")
             if last_turn is None:
                 raise ValueError("No assistant turn found after chat completion")
 
-            last_turn_message = cast(
-                "InspectChatMessageAssistant",
-                last_turn.to_inspect_messages(model)[0],
-            )
+            last_turn_message = turn_as_inspect_messages(
+                last_turn, "assistant", model
+            )[0]
             state.messages.append(last_turn_message)
 
             tokens = last_turn.tokens
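
For reference, the new solver logic walks `state.messages` in reverse so that only the most recent user message is treated as the prompt (avoiding the duplicated user prompt mentioned in the commit title), while system messages and the remaining history are collected separately. Below is a minimal, self-contained sketch of that traversal; the `Msg` dataclass and the sample messages are hypothetical stand-ins for Inspect's `ChatMessage` objects, not part of chatlas or inspect_ai.

# Sketch only: `Msg` is a hypothetical stand-in for inspect_ai's ChatMessage.
from dataclasses import dataclass


@dataclass
class Msg:
    role: str  # "system" | "user" | "assistant" | "tool"
    content: str


def partition_messages(
    messages: list[Msg],
) -> tuple[list[Msg], list[Msg], Msg | None]:
    system_messages: list[Msg] = []
    other_messages: list[Msg] = []
    user_prompt: Msg | None = None

    # Walk newest-to-oldest so the *last* user message becomes the prompt;
    # every other message stays in the replayed history.
    for x in reversed(messages):
        if x.role == "user" and user_prompt is None:
            user_prompt = x
        elif x.role == "system":
            system_messages.append(x)
        else:
            other_messages.append(x)

    # Restore chronological order for the history that gets set as turns.
    other_messages.reverse()
    return system_messages, other_messages, user_prompt


# Hypothetical example: the final user message is the prompt; earlier
# user/assistant turns become history; the system message is kept separate.
msgs = [
    Msg("system", "You are terse."),
    Msg("user", "What is 2 + 2?"),
    Msg("assistant", "4"),
    Msg("user", "And 3 + 3?"),
]
system, history, prompt = partition_messages(msgs)
assert prompt is not None and prompt.content == "And 3 + 3?"
assert [m.content for m in history] == ["What is 2 + 2?", "4"]
assert [m.content for m in system] == ["You are terse."]

Because the loop runs newest-to-oldest, `other_messages` is reversed at the end so the history handed to `chat_instance.set_turns()` stays in chronological order.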

tests/integration/__init__.py

Whitespace-only changes.

tests/integration/test_inspect_integration.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

tests/test_chat.py

Lines changed: 2 additions & 152 deletions
@@ -2,6 +2,8 @@
 import tempfile
 
 import pytest
+from pydantic import BaseModel
+
 from chatlas import (
     ChatOpenAI,
     ContentToolRequest,
@@ -10,7 +12,6 @@
     Turn,
 )
 from chatlas._chat import ToolFailureWarning
-from pydantic import BaseModel
 
 
 def test_simple_batch_chat():
@@ -19,157 +20,6 @@ def test_simple_batch_chat():
     assert str(response) == "2"
 
 
-def test_chat_to_solver_creates_solver():
-    pytest.importorskip("inspect_ai")
-
-    chat = ChatOpenAI()
-    solver = chat.to_solver()
-
-    assert callable(solver)
-
-
-def test_chat_to_solver_with_history():
-    pytest.importorskip("inspect_ai")
-
-    chat = ChatOpenAI(system_prompt="You are a helpful assistant.")
-    chat.set_turns(
-        [
-            Turn("user", "What is 2 + 2?"),
-            Turn("assistant", "4"),
-        ]
-    )
-
-    solver = chat.to_solver()
-
-    assert callable(solver)
-    assert len(chat.get_turns(include_system_prompt=False)) == 2
-    assert chat.system_prompt == "You are a helpful assistant."
-
-
-def test_chat_to_solver_with_rich_content():
-    pytest.importorskip("inspect_ai")
-    from chatlas import content_image_url
-
-    chat = ChatOpenAI(system_prompt="You analyze images.")
-
-    # Create a turn with mixed content (text + image)
-    image_content = content_image_url("https://example.com/image.jpg")
-    chat.set_turns(
-        [
-            Turn("user", ["Describe this image", image_content]),
-            Turn("assistant", "This is a test image."),
-        ]
-    )
-
-    solver = chat.to_solver()
-
-    assert callable(solver)
-    turns = chat.get_turns(include_system_prompt=False)
-    assert len(turns) == 2
-    assert len(turns[0].contents) == 2  # text + image
-    assert chat.system_prompt == "You analyze images."
-
-
-def test_chat_to_solver_with_tool_calls():
-    pytest.importorskip("inspect_ai")
-    from chatlas import ContentToolRequest, ContentToolResult
-
-    chat = ChatOpenAI()
-
-    # Create a turn with tool calls
-    tool_request = ContentToolRequest(
-        id="call_123", name="get_weather", arguments={"city": "NYC"}
-    )
-    tool_result = ContentToolResult(value="Sunny, 75°F", request=tool_request)
-
-    chat.set_turns(
-        [
-            Turn("user", "What's the weather in NYC?"),
-            Turn("assistant", [tool_request]),
-            Turn("user", [tool_result]),
-            Turn("assistant", "The weather in NYC is sunny and 75°F."),
-        ]
-    )
-
-    solver = chat.to_solver()
-
-    assert callable(solver)
-    turns = chat.get_turns(include_system_prompt=False)
-    assert len(turns) == 4
-    assert isinstance(turns[1].contents[0], ContentToolRequest)
-    assert isinstance(turns[2].contents[0], ContentToolResult)
-
-
-def test_chat_to_solver_without_inspect_ai():
-    import sys
-    from unittest.mock import patch
-
-    chat = ChatOpenAI()
-
-    # Mock inspect_ai as not installed in sys.modules
-    with patch.dict(sys.modules, {"inspect_ai": None, "inspect_ai.model": None}):
-        with pytest.raises(ImportError, match="pip install inspect-a"):
-            chat.to_solver()
-
-
-@pytest.mark.asyncio
-async def test_chat_to_solver_with_pdf_content():
-    pytest.importorskip("inspect_ai")
-    from chatlas._content import ContentPDF
-
-    chat = ChatOpenAI(system_prompt="You analyze documents.")
-
-    pdf_content = ContentPDF(data=b"Mock PDF data")
-    chat.set_turns(
-        [
-            Turn("user", ["Analyze this document", pdf_content]),
-            Turn("assistant", "This appears to be a PDF document."),
-        ]
-    )
-
-    solver = chat.to_solver()
-    assert callable(solver)
-
-    turns = chat.get_turns(include_system_prompt=False)
-    assert len(turns) == 2
-    assert any(isinstance(c, ContentPDF) for c in turns[0].contents)
-
-
-@pytest.mark.asyncio
-async def test_chat_to_solver_with_json_content():
-    pytest.importorskip("inspect_ai")
-    from chatlas._content import ContentJson
-
-    chat = ChatOpenAI()
-
-    json_content = ContentJson(value={"key": "value", "number": 42})
-    chat.set_turns(
-        [
-            Turn("user", "Process this data"),
-            Turn("assistant", [json_content]),
-        ]
-    )
-
-    solver = chat.to_solver()
-    assert callable(solver)
-
-    turns = chat.get_turns(include_system_prompt=False)
-    assert len(turns) == 2
-    assert any(isinstance(c, ContentJson) for c in turns[1].contents)
-
-
-def test_chat_to_solver_deepcopy_isolation():
-    pytest.importorskip("inspect_ai")
-
-    chat = ChatOpenAI()
-    initial_turns_count = len(chat.get_turns())
-
-    solver = chat.to_solver()
-
-    assert len(chat.get_turns()) == initial_turns_count
-    assert callable(solver)
-
-
 @pytest.mark.asyncio
 async def test_simple_async_batch_chat():
     chat = ChatOpenAI()
