Skip to content

Commit 7da02c0

Browse files
committed
Add choices to output state
1 parent 8afe02a commit 7da02c0

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

chatlas/_chat.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import os
66
import sys
7+
import time
78
import traceback
89
import warnings
910
from pathlib import Path
@@ -868,9 +869,10 @@ def my_eval(grader_model: str = "openai/gpt-4o"):
868869
@isolver.solver("chatlas_solver")
869870
def _solver():
870871
async def solve(state: InspectTaskState, generate):
872+
start_time = time.perf_counter()
871873
if not state.messages:
872874
for turn in chat_instance.get_turns():
873-
state.messages.append(*turn_as_messages(turn, model=model))
875+
state.messages.extend(turn_as_messages(turn, turn.role, model))
874876

875877
user_content = state.user_prompt.content
876878
if isinstance(user_content, str):
@@ -885,19 +887,27 @@ async def solve(state: InspectTaskState, generate):
885887
last_turn = chat_instance.get_last_turn(role="assistant")
886888
if last_turn is None:
887889
raise ValueError("No assistant turn found after chat completion")
888-
state.messages.append(*turn_as_messages(last_turn, model=model))
889-
tokens = last_turn.tokens or (0, 0, 0)
890-
state.output = imodel.ModelOutput(
891-
model=model,
892-
# TODO: add choices?
893-
# choices=<choices>,
894-
completion=last_turn.text,
895-
usage=imodel.ModelUsage(
890+
891+
last_turn_message = turn_as_messages(last_turn, "assistant", model)[0]
892+
state.messages.append(last_turn_message)
893+
894+
tokens = last_turn.tokens
895+
if tokens is None:
896+
usage = None
897+
else:
898+
usage = imodel.ModelUsage(
896899
input_tokens=tokens[0],
897900
output_tokens=tokens[1],
898901
total_tokens=tokens[0] + tokens[1],
899902
input_tokens_cache_read=tokens[2],
900-
),
903+
)
904+
905+
state.output = imodel.ModelOutput(
906+
model=model,
907+
choices=[imodel.ChatCompletionChoice(message=last_turn_message)],
908+
completion=last_turn.text,
909+
usage=usage,
910+
time=time.perf_counter() - start_time,
901911
)
902912
return state
903913

chatlas/_inspect.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Literal, overload
44

55
from ._content import (
66
Content,
@@ -19,8 +19,34 @@
1919
import inspect_ai.solver as isolver
2020
import inspect_ai.tool as itool
2121

22+
Role = Literal["system", "user", "assistant"]
2223

23-
def turn_as_messages(turn: Turn, model: str | None = None) -> list:
24+
25+
@overload
26+
def turn_as_messages(
27+
turn: Turn, role: Literal["system"], model: str | None = None
28+
) -> list[imodel.ChatMessageSystem]: ...
29+
30+
31+
@overload
32+
def turn_as_messages(
33+
turn: Turn, role: Literal["user"], model: str | None = None
34+
) -> list[imodel.ChatMessage]: ...
35+
36+
37+
@overload
38+
def turn_as_messages(
39+
turn: Turn, role: Literal["assistant"], model: str | None = None
40+
) -> list[imodel.ChatMessageAssistant]: ...
41+
42+
43+
def turn_as_messages(
44+
turn: Turn, role: Role, model: str | None = None
45+
) -> (
46+
list[imodel.ChatMessageSystem]
47+
| list[imodel.ChatMessage]
48+
| list[imodel.ChatMessageAssistant]
49+
):
2450
"""
2551
Translate a chatlas Turn into InspectAI ChatMessages.
2652

0 commit comments

Comments
 (0)