11package openai
22
33import (
4+ "context"
45 "fmt"
6+ "net/http"
7+ "os"
58
9+ "github.com/baalimago/clai/internal/models"
610 "github.com/baalimago/clai/internal/text/generic"
711 pub_models "github.com/baalimago/clai/pkg/text/models"
12+ "github.com/baalimago/go_away_boilerplate/pkg/misc"
813)
914
1015var GptDefault = ChatGPT {
@@ -15,31 +20,114 @@ var GptDefault = ChatGPT{
1520}
1621
// ChatGPT holds the user-facing configuration for an OpenAI text model
// (the exported, JSON-tagged fields) plus the runtime state populated by
// Setup and StreamCompletions (the unexported fields).
type ChatGPT struct {
	Model            string  `json:"model"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
	PresencePenalty  float64 `json:"presence_penalty"`
	Temperature      float64 `json:"temperature"`
	TopP             float64 `json:"top_p"`
	URL              string  `json:"url"`

	// apiKey is read from OPENAI_API_KEY in Setup; never serialized.
	apiKey string
	// debug is true when DEBUG_OPENAI is truthy.
	debug bool

	// tools registered via RegisterTool, forwarded to whichever
	// backend handles the stream.
	tools []pub_models.LLMTool
	// usage holds token usage reported by the responses API path;
	// nil until a responses stream reports it.
	usage *pub_models.Usage

	// streamCompleter is lazily created per StreamCompletions call on
	// the chat-completions path; nil until that path has run.
	streamCompleter *generic.StreamCompleter
	// useResponses selects the responses API over chat completions,
	// as decided by selectOpenAIURL.
	useResponses bool
}
2740
2841func (g * ChatGPT ) Setup () error {
29- err := g . StreamCompleter . Setup ("OPENAI_API_KEY" , ChatURL , "DEBUG_OPENAI " )
30- if err != nil {
31- return fmt .Errorf ("failed to setup stream completer: %w" , err )
42+ apiKey := os . Getenv ("OPENAI_API_KEY" )
43+ if apiKey == "" {
44+ return fmt .Errorf ("openai: missing OPENAI_API_KEY" )
3245 }
33- g .StreamCompleter .Model = g .Model
34- g .StreamCompleter .FrequencyPenalty = & g .FrequencyPenalty
35- g .StreamCompleter .MaxTokens = g .MaxTokens
36- g .StreamCompleter .Temperature = & g .Temperature
37- g .StreamCompleter .TopP = & g .TopP
38- toolChoice := "auto"
39- g .ToolChoice = & toolChoice
46+ g .apiKey = apiKey
47+ g .debug = misc .Truthy (os .Getenv ("DEBUG_OPENAI" ))
48+ if g .Model == "" {
49+ g .Model = GptDefault .Model
50+ }
51+ url , useResponses := selectOpenAIURL (g .Model , g .URL )
52+ g .URL = url
53+ g .useResponses = useResponses
4054 return nil
4155}
4256
4357func (g * ChatGPT ) RegisterTool (tool pub_models.LLMTool ) {
44- g .InternalRegisterTool (tool )
58+ g .tools = append (g .tools , tool )
59+ }
60+
61+ func (g * ChatGPT ) TokenUsage () * pub_models.Usage {
62+ if g .useResponses {
63+ return g .usage
64+ }
65+ if g .streamCompleter == nil {
66+ return nil
67+ }
68+ return g .streamCompleter .TokenUsage ()
69+ }
70+
71+ func (g * ChatGPT ) setUsage (usage * pub_models.Usage ) error {
72+ g .usage = usage
73+ return nil
74+ }
75+
76+ func (g * ChatGPT ) StreamCompletions (ctx context.Context , chat pub_models.Chat ) (chan models.CompletionEvent , error ) {
77+ g .usage = nil
78+ url , useResponses := selectOpenAIURL (g .Model , g .URL )
79+ g .URL = url
80+ g .useResponses = useResponses
81+
82+ if g .useResponses {
83+ toolsMapped := make ([]responsesTool , 0 , len (g .tools ))
84+ for _ , t := range g .tools {
85+ spec := t .Specification ()
86+ toolsMapped = append (toolsMapped , responsesTool {
87+ Type : "function" ,
88+ Name : spec .Name ,
89+ Description : spec .Description ,
90+ Parameters : spec .Inputs ,
91+ })
92+ }
93+
94+ s := & responsesStreamer {
95+ apiKey : g .apiKey ,
96+ url : g .URL ,
97+ model : g .Model ,
98+ debug : g .debug ,
99+ client : http .DefaultClient ,
100+ tools : toolsMapped ,
101+ usageSetter : g .setUsage ,
102+ }
103+
104+ out , err := s .stream (ctx , chat )
105+ if err != nil {
106+ return nil , fmt .Errorf ("openai responses: stream: %w" , err )
107+ }
108+ return out , nil
109+ }
110+
111+ sc := & generic.StreamCompleter {}
112+ if err := sc .Setup ("OPENAI_API_KEY" , g .URL , "DEBUG_OPENAI" ); err != nil {
113+ return nil , fmt .Errorf ("openai chat: setup stream completer: %w" , err )
114+ }
115+ g .streamCompleter = sc
116+ g .streamCompleter .Model = g .Model
117+ g .streamCompleter .MaxTokens = g .MaxTokens
118+ g .streamCompleter .FrequencyPenalty = & g .FrequencyPenalty
119+ g .streamCompleter .PresencePenalty = & g .PresencePenalty
120+ g .streamCompleter .Temperature = & g .Temperature
121+ g .streamCompleter .TopP = & g .TopP
122+ toolChoice := "auto"
123+ g .streamCompleter .ToolChoice = & toolChoice
124+ for _ , tool := range g .tools {
125+ g .streamCompleter .InternalRegisterTool (tool )
126+ }
127+
128+ out , err := g .streamCompleter .StreamCompletions (ctx , chat )
129+ if err != nil {
130+ return nil , fmt .Errorf ("openai chat: stream: %w" , err )
131+ }
132+ return out , nil
45133}
0 commit comments