Commit 605fde0

Introduce .export_eval() for more delightful eval collection

1 parent 7da02c0 commit 605fde0
File tree

5 files changed (+391, -148 lines)

.gitignore

Lines changed: 1 addition & 0 deletions

````diff
@@ -7,6 +7,7 @@ Untitled*.ipynb
 uv.lock
 
 sandbox/
+logs/
 
 /.luarc.json
 
````

chatlas/_chat.py

Lines changed: 229 additions & 13 deletions

````diff
@@ -23,6 +23,7 @@
     Optional,
     Sequence,
     TypeVar,
+    cast,
     overload,
 )
 
@@ -54,6 +55,8 @@
 from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
 
 if TYPE_CHECKING:
+    from inspect_ai.model import ChatMessage as InspectChatMessage
+    from inspect_ai.model import ChatMessageAssistant as InspectChatMessageAssistant
     from inspect_ai.solver import TaskState as InspectTaskState
 
     from ._content import ToolAnnotations
````

````diff
@@ -810,7 +813,12 @@ def console(
         self.chat(user_input, echo=echo, stream=stream, kwargs=kwargs)
         print("")
 
-    def to_solver(self):
+    def to_solver(
+        self,
+        *,
+        include_system_prompt: bool = False,
+        include_turns: bool = False,
+    ):
         """
         Create an InspectAI solver from this chat.
````

````diff
@@ -819,6 +827,25 @@ def to_solver(self):
         (and translate) important state from the chat, including the model,
         system prompt, previous turns, registered tools, model parameters, etc.
 
+        Parameters
+        ----------
+        include_system_prompt
+            Whether to include the system prompt in the solver's starting
+            messages.
+        include_turns
+            Whether to include the chat's existing turns in the solver's
+            starting messages.
+
+        Note
+        ----
+        Both `include_system_prompt` and `include_turns` default to `False`,
+        since `.export_eval()` already captures this information; including it
+        here as well would duplicate context in the evaluation. You may still
+        want to include it in some cases, for example if you are manually
+        constructing an evaluation dataset that lacks this information, or if
+        you want the same starting context regardless of the evaluation
+        dataset.
+
         Returns
         -------
         An [InspectAI solver](https://inspect.ai-safety-institute.org.uk/solvers.html)
````

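For instance, a dataset assembled by hand (rather than collected via `.export_eval()`) may carry no system prompt or prior turns, in which case the solver can supply them. A minimal sketch of that case, assuming a hand-written `handmade_eval.jsonl` (the filename is illustrative, not part of this commit):

```python
from chatlas import ChatOpenAI
from inspect_ai import Task, task
from inspect_ai.dataset import json_dataset
from inspect_ai.scorer import model_graded_qa

# The system prompt lives on the chat, not in the dataset, so opt back in
chat = ChatOpenAI(system_prompt="You are a helpful assistant.")


@task
def handmade_eval():
    return Task(
        dataset=json_dataset("handmade_eval.jsonl"),
        solver=chat.to_solver(include_system_prompt=True),
        scorer=model_graded_qa(model="openai/gpt-4o-mini"),
    )
```
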
````diff
@@ -857,38 +884,91 @@ def my_eval(grader_model: str = "openai/gpt-4o"):
         in the [Chatlas documentation](https://posit-dev.github.io/chatlas/misc/evals.html).
         """
 
-        from ._inspect import content_to_chatlas, try_import_inspect, turn_as_messages
+        from ._inspect import (
+            inspect_content_as_chatlas,
+            inspect_messages_as_turns,
+            try_import_inspect,
+        )
 
         (imodel, isolver, _) = try_import_inspect()
 
         # Create a copy of the chat to avoid modifying its state
-        # when inspect uses the solver
+        # when inspect runs the solver
         chat_instance = copy.deepcopy(self)
         model = self.provider.model
 
-        @isolver.solver("chatlas_solver")
+        # Remove existing turns if requested
+        if not include_turns:
+            chat_instance.set_turns([])
+
+        # Prepare the starting messages from the chat instance
+        starting_turns = chat_instance.get_turns(
+            include_system_prompt=include_system_prompt
+        )
+
+        # Translate starting turns to Inspect messages
+        starting_messages: list["InspectChatMessage"] = []
+        for turn in starting_turns:
+            starting_messages.extend(turn.to_inspect_messages(model))
+
+        # Since Inspect preserves state across solves, prepend starting messages only once
+        has_starting_messages = False
+
+        @isolver.solver(f"chatlas_{self.provider.name}_{model}")
         def _solver():
             async def solve(state: InspectTaskState, generate):
+                nonlocal has_starting_messages
                 start_time = time.perf_counter()
-                if not state.messages:
-                    for turn in chat_instance.get_turns():
-                        state.messages.extend(turn_as_messages(turn, turn.role, model))
 
+                if not has_starting_messages:
+                    state.messages = starting_messages + state.messages
+                    has_starting_messages = True
+
+                # Now that we've translated the starting messages to Inspect,
+                # we translate the message state back to the chat instance.
+                # N.B., state.messages can include non-trivial sample input
+                # from the dataset (e.g., `Sample(input=[ChatMessage, ...])`)
+                system_prompts: list["InspectChatMessage"] = []
+                other_prompts: list["InspectChatMessage"] = []
+                for x in state.messages:
+                    if x.role == "system":
+                        system_prompts.append(x)
+                    else:
+                        other_prompts.append(x)
+
+                # Set the system prompt on the chat instance
+                if len(system_prompts) == 1:
+                    chat_instance.system_prompt = str(system_prompts[0])
+                elif len(system_prompts) > 1:
+                    raise ValueError(
+                        "Multiple system prompts detected in `.to_solver()`, but chatlas only "
+                        "supports a single system prompt. This usually indicates that the system "
+                        "prompt is mistakenly included in both the eval dataset (via `.export_eval()`) "
+                        "and on the chat instance. Consider dropping the system prompt from "
+                        "the chat instance by setting `.to_solver(include_system_prompt=False)`."
+                    )
+
+                # Now, set the other messages as turns on the chat instance
+                chat_instance.set_turns(inspect_messages_as_turns(other_prompts))
+
+                # TODO: inspect docs mention this is always the _first_? user message??
                 user_content = state.user_prompt.content
                 if isinstance(user_content, str):
                     input_content = [user_content]
                 else:
-                    input_content = [content_to_chatlas(x) for x in user_content]
+                    input_content = [
+                        inspect_content_as_chatlas(x) for x in user_content
+                    ]
 
-                await chat_instance.chat_async(
-                    *input_content,
-                    echo="none",
-                )
+                await chat_instance.chat_async(*input_content, echo="none")
                 last_turn = chat_instance.get_last_turn(role="assistant")
                 if last_turn is None:
                     raise ValueError("No assistant turn found after chat completion")
 
-                last_turn_message = turn_as_messages(last_turn, "assistant", model)[0]
+                last_turn_message = cast(
+                    "InspectChatMessageAssistant",
+                    last_turn.to_inspect_messages(model)[0],
+                )
                 state.messages.append(last_turn_message)
 
                 tokens = last_turn.tokens
````

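Two behavioral changes stand out above: the solver is now registered under a name derived from the provider and model (`f"chatlas_{provider}_{model}"` instead of the fixed `"chatlas_solver"`), and the `has_starting_messages` flag ensures the chat's starting messages are prepended to `state.messages` only once, since Inspect reuses the solver's state across samples. A hedged sketch of running a task built on this solver (assuming a `my_eval` task like the one in the `.export_eval()` docstring below; `eval` is Inspect's standard entry point):

```python
from inspect_ai import eval

# Each Sample flows through solve(); starting messages are prepended on the
# first call only, then the sample's messages become chatlas turns.
logs = eval(my_eval(), model="openai/gpt-4o")
```
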
````diff
@@ -2142,6 +2222,142 @@ def _html_template(contents: str) -> str:
     </html>
     """
 
+    def export_eval(
+        self,
+        filename: str | Path,
+        *,
+        target: Optional[str] = None,
+        include_system_prompt: bool = True,
+        turns: Optional[list[Turn]] = None,
+        overwrite: Literal["append", True, False] = "append",
+        **kwargs: Any,
+    ):
+        """
+        Create an Inspect AI eval dataset sample from the current chat.
+
+        Creates an Inspect AI eval
+        [Sample](https://inspect.aisi.org.uk/reference/inspect_ai.dataset.html#sample)
+        from the current chat and appends it to a JSONL file. In Inspect, an
+        eval dataset is a collection of Samples, where each Sample represents
+        a single `input` (i.e., user prompt) and the expected `target` (i.e.,
+        the target answer and/or grading guidance for it). Note that the
+        `input` of a particular sample can contain a series of messages (from
+        both the user and assistant).
+
+        Note
+        ----
+        Each call to this method appends a single Sample as a new line in the
+        specified JSONL file. If the file does not exist, it will be created.
+
+        Parameters
+        ----------
+        filename
+            The filename to export the chat to. Currently this must
+            be a `.jsonl` file.
+        target
+            The target output for the eval sample. By default, this is
+            taken to be the content of the last assistant turn.
+        include_system_prompt
+            Whether to include the system prompt (if any) as the
+            first turn in the eval sample.
+        turns
+            The input turns for the eval sample. By default, this is
+            taken to be all turns except the last (assistant) turn.
+            Note that system prompts are not allowed here, but controlled
+            separately via the `include_system_prompt` parameter.
+        overwrite
+            Behavior when the file already exists:
+            - `"append"` (default): Append to the existing file.
+            - `True`: Overwrite the existing file.
+            - `False`: Raise an error if the file already exists.
+        kwargs
+            Additional keyword arguments to pass to the `Sample()` constructor.
+            This is primarily useful for setting an ID or metadata on the sample.
+
+        Examples
+        --------
+
+        Step 1: export the chat to an eval JSONL file
+
+        ```python
+        from chatlas import ChatOpenAI
+
+        chat = ChatOpenAI(system_prompt="You are a helpful assistant.")
+        chat.chat("Hello, how are you?")
+
+        chat.export_eval("my_eval.jsonl")
+        ```
+
+        Step 2: load the eval JSONL file into an Inspect AI eval task
+
+        ```python
+        from chatlas import ChatOpenAI
+        from inspect_ai import Task, task
+        from inspect_ai.dataset import json_dataset
+        from inspect_ai.scorer import model_graded_qa
+
+        # No need to load in system prompt -- it's included in the eval JSONL file by default
+        chat = ChatOpenAI()
+
+
+        @task
+        def my_eval():
+            return Task(
+                dataset=json_dataset("my_eval.jsonl"),
+                solver=chat.to_solver(),
+                scorer=model_graded_qa(model="openai/gpt-4o-mini"),
+            )
+        ```
+        """
+
+        if isinstance(filename, str):
+            filename = Path(filename)
+
+        filename = filename.resolve()
+        if filename.exists() and overwrite is False:
+            raise ValueError(
+                f"File {filename} already exists. Set `overwrite=True` to overwrite or `overwrite='append'` to append."
+            )
+
+        if filename.suffix not in {".jsonl"}:
+            raise ValueError("The filename must have a `.jsonl` extension.")
+
+        if turns is None:
+            turns = self.get_turns(include_system_prompt=False)
+
+        if any(x.role == "system" for x in turns):
+            raise ValueError("System prompts are not allowed in eval input turns.")
+
+        if not any(x.role == "user" for x in turns):
+            raise ValueError("At least one user turn is required in eval input turns.")
+
+        if include_system_prompt:
+            system_turn = self.get_last_turn(role="system")
+            if system_turn:
+                turns = [system_turn] + turns
+
+        input_turns, target_turn = turns[:-1], turns[-1]
+        if target_turn.role != "assistant":
+            raise ValueError("The last turn must be an assistant turn.")
+
+        if target is None:
+            target = str(target_turn)
+
+        input_messages = []
+        for x in input_turns:
+            input_messages.extend(x.to_inspect_messages())
+
+        from inspect_ai.dataset import Sample
+
+        sample = Sample(input=input_messages, target=target, **kwargs)
+        sample_json = sample.model_dump_json(exclude_none=True)
+
+        mode = "a" if overwrite == "append" and filename.exists() else "w"
+        with open(filename, mode) as f:
+            f.write(sample_json + "\n")
+
+        return filename
+
     @overload
     def _chat_impl(
         self,
````

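Because `overwrite` defaults to `"append"`, repeated calls grow the dataset one Sample per line, and `**kwargs` can attach an ID or metadata to each one. A small sketch of collecting several samples into a single file (the prompts and IDs are illustrative):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI(system_prompt="You are a helpful assistant.")

for i, prompt in enumerate(["What is 2 + 2?", "Name a prime number."]):
    chat.set_turns([])  # reset so each sample is a single user/assistant pair
    chat.chat(prompt, echo="none")
    # Appends one JSONL line per call; `id` is forwarded to Sample()
    chat.export_eval("my_eval.jsonl", id=f"sample-{i}")
```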
