Skip to content

Commit 4b68d7a

Browse files
committed
wip first pass at new OpenAI Responses API support
1 parent df9873c commit 4b68d7a

20 files changed

+1400
-577
lines changed

chatlas/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
from ._provider_huggingface import ChatHuggingFace
2828
from ._provider_mistral import ChatMistral
2929
from ._provider_ollama import ChatOllama
30-
from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
30+
from ._provider_openai import ChatOpenAI
31+
from ._provider_openai_azure import ChatAzureOpenAI
32+
from ._provider_openai_responses import ChatOpenAIResponses
3133
from ._provider_openrouter import ChatOpenRouter
3234
from ._provider_perplexity import ChatPerplexity
3335
from ._provider_portkey import ChatPortkey
@@ -59,6 +61,7 @@
5961
"ChatMistral",
6062
"ChatOllama",
6163
"ChatOpenAI",
64+
"ChatOpenAIResponses",
6265
"ChatOpenRouter",
6366
"ChatAzureOpenAI",
6467
"ChatPerplexity",

chatlas/_auto.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from ._provider_huggingface import ChatHuggingFace
1818
from ._provider_mistral import ChatMistral
1919
from ._provider_ollama import ChatOllama
20-
from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
20+
from ._provider_openai import ChatOpenAI
21+
from ._provider_openai_azure import ChatAzureOpenAI
2122
from ._provider_openrouter import ChatOpenRouter
2223
from ._provider_perplexity import ChatPerplexity
2324
from ._provider_portkey import ChatPortkey

chatlas/_chat.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from ._logging import log_tool_error
4747
from ._mcp_manager import MCPSessionManager
4848
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
49-
from ._tokens import compute_cost, get_token_pricing
49+
from ._tokens import compute_cost, get_token_pricing, tokens_log
5050
from ._tools import Tool, ToolRejectError
5151
from ._turn import Turn, user_turn
5252
from ._typing_extensions import TypedDict, TypeGuard
@@ -2210,12 +2210,11 @@ def emit(text: str | Content):
22102210
result,
22112211
has_data_model=data_model is not None,
22122212
)
2213-
22142213
if echo == "all":
22152214
emit_other_contents(turn, emit)
22162215

22172216
else:
2218-
response = self.provider.chat_perform(
2217+
result = self.provider.chat_perform(
22192218
stream=False,
22202219
turns=[*self._turns, user_turn],
22212220
tools=self._tools,
@@ -2224,7 +2223,7 @@ def emit(text: str | Content):
22242223
)
22252224

22262225
turn = self.provider.value_turn(
2227-
response, has_data_model=data_model is not None
2226+
result, has_data_model=data_model is not None
22282227
)
22292228
if turn.text:
22302229
emit(turn.text)
@@ -2233,6 +2232,9 @@ def emit(text: str | Content):
22332232
if echo == "all":
22342233
emit_other_contents(turn, emit)
22352234

2235+
turn.tokens = self.provider.value_tokens(result)
2236+
if turn.tokens is not None:
2237+
tokens_log(self.provider, turn.tokens)
22362238
self._turns.extend([user_turn, turn])
22372239

22382240
async def _submit_turns_async(
@@ -2277,7 +2279,7 @@ def emit(text: str | Content):
22772279
emit_other_contents(turn, emit)
22782280

22792281
else:
2280-
response = await self.provider.chat_perform_async(
2282+
result = await self.provider.chat_perform_async(
22812283
stream=False,
22822284
turns=[*self._turns, user_turn],
22832285
tools=self._tools,
@@ -2286,7 +2288,7 @@ def emit(text: str | Content):
22862288
)
22872289

22882290
turn = self.provider.value_turn(
2289-
response, has_data_model=data_model is not None
2291+
result, has_data_model=data_model is not None
22902292
)
22912293
if turn.text:
22922294
emit(turn.text)
@@ -2295,6 +2297,9 @@ def emit(text: str | Content):
22952297
if echo == "all":
22962298
emit_other_contents(turn, emit)
22972299

2300+
turn.tokens = self.provider.value_tokens(result)
2301+
if turn.tokens is not None:
2302+
tokens_log(self.provider, turn.tokens)
22982303
self._turns.extend([user_turn, turn])
22992304

23002305
def _invoke_tool(self, request: ContentToolRequest):

chatlas/_content.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def from_tool(cls, tool: "Tool") -> "ToolInfo":
125125
"tool_result_resource",
126126
"json",
127127
"pdf",
128+
"thinking",
128129
]
129130
"""
130131
A discriminated union of all content types.
@@ -682,6 +683,40 @@ def __repr__(self, indent: int = 0):
682683
return " " * indent + f"<ContentPDF size={len(self.data)}>"
683684

684685

686+
class ContentThinking(Content):
687+
"""
688+
Thinking/reasoning content
689+
690+
This content type represents reasoning traces from models that support
691+
extended thinking (like OpenAI's o-series models). The thinking content
692+
is not meant to be sent back to the model but is useful for debugging
693+
and understanding the model's reasoning process.
694+
695+
Parameters
696+
----------
697+
thinking
698+
The thinking/reasoning text from the model.
699+
extra
700+
Additional metadata associated with the thinking content (e.g.,
701+
encrypted content, status information).
702+
"""
703+
704+
thinking: str
705+
extra: Optional[dict[str, Any]] = None
706+
707+
content_type: ContentTypeEnum = "thinking"
708+
709+
def __str__(self):
710+
return f"<thinking>\n{self.thinking}\n</thinking>\n"
711+
712+
def _repr_markdown_(self):
713+
return self.__str__()
714+
715+
def __repr__(self, indent: int = 0):
716+
preview = self.thinking[:50] + "..." if len(self.thinking) > 50 else self.thinking
717+
return " " * indent + f"<ContentThinking thinking='{preview}'>"
718+
719+
685720
ContentUnion = Union[
686721
ContentText,
687722
ContentImageRemote,
@@ -692,6 +727,7 @@ def __repr__(self, indent: int = 0):
692727
ContentToolResultResource,
693728
ContentJson,
694729
ContentPDF,
730+
ContentThinking,
695731
]
696732

697733

@@ -724,6 +760,8 @@ def create_content(data: dict[str, Any]) -> ContentUnion:
724760
return ContentJson.model_validate(data)
725761
elif ct == "pdf":
726762
return ContentPDF.model_validate(data)
763+
elif ct == "thinking":
764+
return ContentThinking.model_validate(data)
727765
else:
728766
raise ValueError(f"Unknown content type: {ct}")
729767

chatlas/_provider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,12 @@ def value_turn(
249249
has_data_model: bool,
250250
) -> Turn: ...
251251

252+
@abstractmethod
253+
def value_tokens(
254+
self,
255+
completion: ChatCompletionDictT,
256+
) -> tuple[int, int, int] | None: ...
257+
252258
@abstractmethod
253259
def token_count(
254260
self,

chatlas/_provider_anthropic.py

Lines changed: 12 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
StandardModelParamNames,
3131
StandardModelParams,
3232
)
33-
from ._tokens import get_token_pricing, tokens_log
33+
from ._tokens import get_token_pricing
3434
from ._tools import Tool, basemodel_to_param_schema
3535
from ._turn import Turn, user_turn
3636
from ._utils import split_http_client_kwargs
@@ -242,28 +242,6 @@ def list_models(self):
242242

243243
return res
244244

245-
@overload
246-
def chat_perform(
247-
self,
248-
*,
249-
stream: Literal[False],
250-
turns: list[Turn],
251-
tools: dict[str, Tool],
252-
data_model: Optional[type[BaseModel]] = None,
253-
kwargs: Optional["SubmitInputArgs"] = None,
254-
): ...
255-
256-
@overload
257-
def chat_perform(
258-
self,
259-
*,
260-
stream: Literal[True],
261-
turns: list[Turn],
262-
tools: dict[str, Tool],
263-
data_model: Optional[type[BaseModel]] = None,
264-
kwargs: Optional["SubmitInputArgs"] = None,
265-
): ...
266-
267245
def chat_perform(
268246
self,
269247
*,
@@ -276,28 +254,6 @@ def chat_perform(
276254
kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
277255
return self._client.messages.create(**kwargs) # type: ignore
278256

279-
@overload
280-
async def chat_perform_async(
281-
self,
282-
*,
283-
stream: Literal[False],
284-
turns: list[Turn],
285-
tools: dict[str, Tool],
286-
data_model: Optional[type[BaseModel]] = None,
287-
kwargs: Optional["SubmitInputArgs"] = None,
288-
): ...
289-
290-
@overload
291-
async def chat_perform_async(
292-
self,
293-
*,
294-
stream: Literal[True],
295-
turns: list[Turn],
296-
tools: dict[str, Tool],
297-
data_model: Optional[type[BaseModel]] = None,
298-
kwargs: Optional["SubmitInputArgs"] = None,
299-
): ...
300-
301257
async def chat_perform_async(
302258
self,
303259
*,
@@ -411,6 +367,17 @@ def stream_turn(self, completion, has_data_model) -> Turn:
411367
def value_turn(self, completion, has_data_model) -> Turn:
412368
return self._as_turn(completion, has_data_model)
413369

370+
def value_tokens(self, completion):
371+
usage = completion.usage
372+
# N.B. Currently, Anthropic doesn't cache by default and we currently do not support
373+
# manual caching in chatlas. Note also that this only tracks reads, NOT writes, which
374+
# have their own cost. To track that properly, we would need another caching category and per-token cost.
375+
return (
376+
completion.usage.input_tokens,
377+
completion.usage.output_tokens,
378+
usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0,
379+
)
380+
414381
def token_count(
415382
self,
416383
*args: Content | str,
@@ -619,23 +586,9 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn:
619586
)
620587
)
621588

622-
usage = completion.usage
623-
# N.B. Currently, Anthropic doesn't cache by default and we currently do not support
624-
# manual caching in chatlas. Note also that this only tracks reads, NOT writes, which
625-
# have their own cost. To track that properly, we would need another caching category and per-token cost.
626-
627-
tokens = (
628-
completion.usage.input_tokens,
629-
completion.usage.output_tokens,
630-
usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0,
631-
)
632-
633-
tokens_log(self, tokens)
634-
635589
return Turn(
636590
"assistant",
637591
contents,
638-
tokens=tokens,
639592
finish_reason=completion.stop_reason,
640593
completion=completion,
641594
)

chatlas/_provider_github.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from ._chat import Chat
99
from ._logging import log_model_default
10-
from ._provider_openai import ModelInfo, OpenAIProvider
10+
from ._provider import ModelInfo
11+
from ._provider_openai import OpenAIProvider
1112
from ._utils import MISSING, MISSING_TYPE, is_testing
1213

1314
if TYPE_CHECKING:

chatlas/_provider_google.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ._logging import log_model_default
2323
from ._merge import merge_dicts
2424
from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams
25-
from ._tokens import get_token_pricing, tokens_log
25+
from ._tokens import get_token_pricing
2626
from ._tools import Tool
2727
from ._turn import Turn, user_turn
2828

@@ -228,6 +228,7 @@ def chat_perform(
228228

229229
def chat_perform(
230230
self,
231+
*,
231232
stream: bool,
232233
turns: list[Turn],
233234
tools: dict[str, Tool],
@@ -264,6 +265,7 @@ async def chat_perform_async(
264265

265266
async def chat_perform_async(
266267
self,
268+
*,
267269
stream: bool,
268270
turns: list[Turn],
269271
tools: dict[str, Tool],
@@ -349,6 +351,17 @@ def value_turn(self, completion, has_data_model) -> Turn:
349351
completion = cast("GenerateContentResponseDict", completion.model_dump())
350352
return self._as_turn(completion, has_data_model)
351353

354+
def value_tokens(self, completion):
355+
usage = completion.get("usage_metadata")
356+
if usage is None:
357+
return None
358+
cached = usage.get("cached_content_token_count") or 0
359+
return (
360+
(usage.get("prompt_token_count") or 0) - cached,
361+
usage.get("candidates_token_count") or 0,
362+
usage.get("cached_content_token_count") or 0,
363+
)
364+
352365
def token_count(
353366
self,
354367
*args: Content | str,
@@ -528,25 +541,12 @@ def _as_turn(
528541
)
529542
)
530543

531-
usage = message.get("usage_metadata")
532-
tokens = (0, 0, 0)
533-
if usage:
534-
cached = usage.get("cached_content_token_count") or 0
535-
tokens = (
536-
(usage.get("prompt_token_count") or 0) - cached,
537-
usage.get("candidates_token_count") or 0,
538-
usage.get("cached_content_token_count") or 0,
539-
)
540-
541-
tokens_log(self, tokens)
542-
543544
if isinstance(finish_reason, FinishReason):
544545
finish_reason = finish_reason.name
545546

546547
return Turn(
547548
"assistant",
548549
contents,
549-
tokens=tokens,
550550
finish_reason=finish_reason,
551551
completion=message,
552552
)

chatlas/_provider_ollama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import orjson
88

99
from ._chat import Chat
10-
from ._provider_openai import ModelInfo, OpenAIProvider
10+
from ._provider import ModelInfo
11+
from ._provider_openai import OpenAIProvider
1112
from ._utils import MISSING_TYPE, is_testing
1213

1314
if TYPE_CHECKING:

0 commit comments

Comments
 (0)