
Commit f03ff04

Merge pull request #109 from baalimago/feat/openai-reponses-api
It seems OpenAI no longer releases new GPT models, only Codex models. This PR adds support for their Responses API, which the Codex models require. Recommendation: do _not_ use `gpt-5.3-codex`, as it is, for some reason, abysmal at calling tools (in my limited testing). `gpt-5.2-codex` works pretty well, though.
2 parents d85b7f3 + fbe8f63 commit f03ff04

22 files changed · 2180 additions & 33 deletions
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
package generic

import (
	"testing"

	"github.com/baalimago/clai/internal/models"
	"github.com/baalimago/clai/internal/tools"
	pub_models "github.com/baalimago/clai/pkg/text/models"
)

type fakeTool struct{ spec pub_models.Specification }

func (f fakeTool) Call(input pub_models.Input) (string, error) { return "", nil }
func (f fakeTool) Specification() pub_models.Specification     { return f.spec }

// Validates that streamed OpenAI-style tool_calls are assembled, parsed as JSON,
// and emitted as a pub_models.Call.
func TestStreamCompleter_EmitsToolCall(t *testing.T) {
	s := &StreamCompleter{}

	orig := tools.Registry
	tools.Registry = tools.NewRegistry()
	defer func() { tools.Registry = orig }()

	tools.Registry.Set("mcp_everything_get-annotated-message", fakeTool{spec: pub_models.Specification{Name: "mcp_everything_get-annotated-message"}})

	// Stream args in multiple chunks, like the real API does.
	first := Choice{Delta: Delta{ToolCalls: []ToolsCall{{
		ID:    "id1",
		Index: 0,
		Type:  "function",
		Function: Func{
			Name:      "mcp_everything_get-annotated-message",
			Arguments: `{"messageType":"`,
		},
	}}}}
	ev := s.handleChoice(first)
	if _, ok := ev.(models.NoopEvent); !ok {
		t.Fatalf("expected NoopEvent, got %T: %#v", ev, ev)
	}

	second := Choice{Delta: Delta{ToolCalls: []ToolsCall{{Function: Func{Arguments: `debug"`}}}}}
	ev = s.handleChoice(second)
	if _, ok := ev.(models.NoopEvent); !ok {
		t.Fatalf("expected NoopEvent, got %T: %#v", ev, ev)
	}

	third := Choice{Delta: Delta{ToolCalls: []ToolsCall{{Function: Func{Arguments: `}`}}}}}
	ev = s.handleChoice(third)
	call, ok := ev.(pub_models.Call)
	if !ok {
		t.Fatalf("expected Call, got %T: %#v", ev, ev)
	}
	if call.Name != "mcp_everything_get-annotated-message" {
		t.Fatalf("unexpected tool name: %q", call.Name)
	}
	if call.Inputs == nil || (*call.Inputs)["messageType"] != "debug" {
		t.Fatalf("unexpected inputs: %#v", call.Inputs)
	}
}
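For reference, the three streamed `Arguments` fragments above concatenate to `{"messageType":"debug"}`: the completer buffers fragments and can only emit a `pub_models.Call` once the buffer parses as JSON, which is why the first two chunks yield `NoopEvent`. A standalone sketch of that buffer-until-valid-JSON pattern (illustrative only, not code from this PR):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Fragments as they arrive on the stream, mirroring the test above.
	chunks := []string{`{"messageType":"`, `debug"`, `}`}
	var buf string
	for _, c := range chunks {
		buf += c
		var inputs map[string]any
		if err := json.Unmarshal([]byte(buf), &inputs); err != nil {
			continue // incomplete JSON: keep buffering (the completer emits NoopEvent here)
		}
		fmt.Printf("assembled inputs: %v\n", inputs) // map[messageType:debug]
	}
}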

internal/tools/handler.go

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ func Init() {
 	Registry.Set(tools.RipGrep.Specification().Name, tools.RipGrep)
 	Registry.Set(tools.Go.Specification().Name, tools.Go)
 	Registry.Set(tools.WriteFile.Specification().Name, tools.WriteFile)
+	Registry.Set(tools.ApplyPatch.Specification().Name, tools.ApplyPatch)
 	Registry.Set(tools.FreetextCmd.Specification().Name, tools.FreetextCmd)
 	Registry.Set(tools.Sed.Specification().Name, tools.Sed)
 	Registry.Set(tools.RowsBetween.Specification().Name, tools.RowsBetween)

internal/tools/registry_test.go

Lines changed: 14 additions & 0 deletions
@@ -81,6 +81,20 @@ func TestRegistry_Get(t *testing.T) {
 	}
 }

+func TestInitRegistersApplyPatch(t *testing.T) {
+	origRegistry := Registry
+	Registry = NewRegistry()
+	Registry.hasBeenInit = false
+	t.Cleanup(func() {
+		Registry = origRegistry
+	})
+
+	Init()
+	if _, ok := Registry.Get("apply_patch"); !ok {
+		t.Fatalf("expected apply_patch to be registered")
+	}
+}
+
 func TestRegistry_All(t *testing.T) {
 	r := NewRegistry()
 	tool1 := newMockTool("tool1")
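The test resets `Registry.hasBeenInit` before calling `Init`, which implies `Init` guards on that flag to stay idempotent. A plausible sketch of that guard, inferred from the test (the real `Init` lives in handler.go and is only partially shown above):

// Hypothetical: Init's idempotency guard as suggested by the test's reset of
// hasBeenInit; not verbatim from this PR.
func Init() {
	if Registry.hasBeenInit {
		return
	}
	Registry.hasBeenInit = true
	Registry.Set(tools.ApplyPatch.Specification().Name, tools.ApplyPatch)
	// ...remaining registrations as in handler.go above...
}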
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
package tools

import (
	"sync"
	"testing"
)

var testRegistryMu sync.Mutex

// WithTestRegistry replaces the global Registry for the duration of the test callback.
func WithTestRegistry(t *testing.T, fn func()) {
	t.Helper()
	testRegistryMu.Lock()
	t.Cleanup(testRegistryMu.Unlock)

	orig := Registry
	Registry = NewRegistry()
	t.Cleanup(func() { Registry = orig })
	fn()
}
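A hypothetical caller (not part of this diff) wraps registry-mutating tests in the helper. Note the cleanup ordering: `t.Cleanup` functions run last-in-first-out, so the original `Registry` is restored before the mutex unlocks:

// Hypothetical usage of WithTestRegistry; newMockTool is the helper already
// used by registry_test.go above.
func TestWithIsolatedRegistry(t *testing.T) {
	WithTestRegistry(t, func() {
		Registry.Set("mock", newMockTool("mock"))
		if _, ok := Registry.Get("mock"); !ok {
			t.Fatal("expected mock to be registered")
		}
	})
}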
Lines changed: 5 additions & 4 deletions
@@ -1,8 +1,9 @@
 package openai

 const (
-	ChatURL  = "https://api.openai.com/v1/chat/completions"
-	PhotoURL = "https://api.openai.com/v1/images/generations"
-	VideoURL = "https://api.openai.com/v1/videos"
-	FilesURL = "https://api.openai.com/v1/files"
+	ChatURL      = "https://api.openai.com/v1/chat/completions"
+	ResponsesURL = "https://api.openai.com/v1/responses"
+	PhotoURL     = "https://api.openai.com/v1/images/generations"
+	VideoURL     = "https://api.openai.com/v1/videos"
+	FilesURL     = "https://api.openai.com/v1/files"
 )

internal/vendors/openai/doc.go

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
// Package openai implements OpenAI vendor integrations (chat, photo, video).
//
// Streaming paths:
//   - Non-Codex models use Chat Completions (/v1/chat/completions).
//   - Codex-named models (model contains "codex", case-insensitive) use Responses (/v1/responses).
//
// Tool calling:
//   - Chat Completions are normalized via the generic stream completer.
//   - Responses streaming is parsed directly and emits the same normalized events.
//
// Usage accounting:
//   - Chat Completions uses the generic stream completer token usage.
//   - Responses sets usage from the responses stream metadata.
package openai
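`selectOpenAIURL`, which implements the routing rule documented here, is called from gpt.go below but defined elsewhere in this PR. A minimal sketch consistent with doc.go, assuming the signature implied by the call sites; the real implementation may differ, e.g. in how it treats an explicitly configured URL:

package openai

import "strings"

// Hypothetical sketch of selectOpenAIURL: route codex-named models to the
// Responses endpoint, everything else to Chat Completions.
func selectOpenAIURL(model, configuredURL string) (string, bool) {
	if strings.Contains(strings.ToLower(model), "codex") {
		return ResponsesURL, true
	}
	if configuredURL != "" && configuredURL != ResponsesURL {
		return configuredURL, false // keep a user-configured chat endpoint
	}
	return ChatURL, false
}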

internal/vendors/openai/gpt.go

Lines changed: 100 additions & 12 deletions
@@ -1,10 +1,15 @@
 package openai

 import (
+	"context"
 	"fmt"
+	"net/http"
+	"os"

+	"github.com/baalimago/clai/internal/models"
 	"github.com/baalimago/clai/internal/text/generic"
 	pub_models "github.com/baalimago/clai/pkg/text/models"
+	"github.com/baalimago/go_away_boilerplate/pkg/misc"
 )

 var GptDefault = ChatGPT{
@@ -15,31 +20,114 @@ var GptDefault = ChatGPT{
 }

 type ChatGPT struct {
-	generic.StreamCompleter
 	Model            string  `json:"model"`
 	FrequencyPenalty float64 `json:"frequency_penalty"`
 	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
 	PresencePenalty  float64 `json:"presence_penalty"`
 	Temperature      float64 `json:"temperature"`
 	TopP             float64 `json:"top_p"`
 	URL              string  `json:"url"`
+
+	apiKey string
+	debug  bool
+
+	tools []pub_models.LLMTool
+	usage *pub_models.Usage
+
+	streamCompleter *generic.StreamCompleter
+	useResponses    bool
 }

 func (g *ChatGPT) Setup() error {
-	err := g.StreamCompleter.Setup("OPENAI_API_KEY", ChatURL, "DEBUG_OPENAI")
-	if err != nil {
-		return fmt.Errorf("failed to setup stream completer: %w", err)
+	apiKey := os.Getenv("OPENAI_API_KEY")
+	if apiKey == "" {
+		return fmt.Errorf("openai: missing OPENAI_API_KEY")
 	}
-	g.StreamCompleter.Model = g.Model
-	g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty
-	g.StreamCompleter.MaxTokens = g.MaxTokens
-	g.StreamCompleter.Temperature = &g.Temperature
-	g.StreamCompleter.TopP = &g.TopP
-	toolChoice := "auto"
-	g.ToolChoice = &toolChoice
+	g.apiKey = apiKey
+	g.debug = misc.Truthy(os.Getenv("DEBUG_OPENAI"))
+	if g.Model == "" {
+		g.Model = GptDefault.Model
+	}
+	url, useResponses := selectOpenAIURL(g.Model, g.URL)
+	g.URL = url
+	g.useResponses = useResponses
 	return nil
 }

 func (g *ChatGPT) RegisterTool(tool pub_models.LLMTool) {
-	g.InternalRegisterTool(tool)
+	g.tools = append(g.tools, tool)
+}
+
+func (g *ChatGPT) TokenUsage() *pub_models.Usage {
+	if g.useResponses {
+		return g.usage
+	}
+	if g.streamCompleter == nil {
+		return nil
+	}
+	return g.streamCompleter.TokenUsage()
+}
+
+func (g *ChatGPT) setUsage(usage *pub_models.Usage) error {
+	g.usage = usage
+	return nil
+}
+
+func (g *ChatGPT) StreamCompletions(ctx context.Context, chat pub_models.Chat) (chan models.CompletionEvent, error) {
+	g.usage = nil
+	url, useResponses := selectOpenAIURL(g.Model, g.URL)
+	g.URL = url
+	g.useResponses = useResponses
+
+	if g.useResponses {
+		toolsMapped := make([]responsesTool, 0, len(g.tools))
+		for _, t := range g.tools {
+			spec := t.Specification()
+			toolsMapped = append(toolsMapped, responsesTool{
+				Type:        "function",
+				Name:        spec.Name,
+				Description: spec.Description,
+				Parameters:  spec.Inputs,
+			})
+		}
+
+		s := &responsesStreamer{
+			apiKey:      g.apiKey,
+			url:         g.URL,
+			model:       g.Model,
+			debug:       g.debug,
+			client:      http.DefaultClient,
+			tools:       toolsMapped,
+			usageSetter: g.setUsage,
+		}
+
+		out, err := s.stream(ctx, chat)
+		if err != nil {
+			return nil, fmt.Errorf("openai responses: stream: %w", err)
+		}
+		return out, nil
+	}
+
+	sc := &generic.StreamCompleter{}
+	if err := sc.Setup("OPENAI_API_KEY", g.URL, "DEBUG_OPENAI"); err != nil {
+		return nil, fmt.Errorf("openai chat: setup stream completer: %w", err)
+	}
+	g.streamCompleter = sc
+	g.streamCompleter.Model = g.Model
+	g.streamCompleter.MaxTokens = g.MaxTokens
+	g.streamCompleter.FrequencyPenalty = &g.FrequencyPenalty
+	g.streamCompleter.PresencePenalty = &g.PresencePenalty
+	g.streamCompleter.Temperature = &g.Temperature
+	g.streamCompleter.TopP = &g.TopP
+	toolChoice := "auto"
+	g.streamCompleter.ToolChoice = &toolChoice
+	for _, tool := range g.tools {
+		g.streamCompleter.InternalRegisterTool(tool)
+	}
+
+	out, err := g.streamCompleter.StreamCompletions(ctx, chat)
+	if err != nil {
+		return nil, fmt.Errorf("openai chat: stream: %w", err)
+	}
+	return out, nil
 }
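`responsesTool` and `responsesStreamer` are defined in other files of this PR. From how `StreamCompletions` populates it above, the tool payload plausibly mirrors the Responses API function-tool format:

// Hypothetical shape of responsesTool, inferred from the field usage in
// StreamCompletions; the real definition and its JSON tags live elsewhere
// in this PR.
type responsesTool struct {
	Type        string `json:"type"`        // always "function" here
	Name        string `json:"name"`        // from spec.Name
	Description string `json:"description"` // from spec.Description
	Parameters  any    `json:"parameters"`  // JSON schema taken from spec.Inputs
}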
