Skip to content

Commit e3ad4b4

Browse files
authored
feat(go): use GenkitError instead of regular error (#2643)
1 parent 0fc448e commit e3ad4b4

File tree

13 files changed

+329
-123
lines changed

13 files changed

+329
-123
lines changed

go/ai/gen.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,23 +181,6 @@ type GenerationUsage struct {
181181
TotalTokens int `json:"totalTokens,omitempty"`
182182
}
183183

184-
type GenkitError struct {
185-
Data *GenkitErrorData `json:"data,omitempty"`
186-
Details any `json:"details,omitempty"`
187-
Message string `json:"message,omitempty"`
188-
Stack string `json:"stack,omitempty"`
189-
}
190-
191-
type GenkitErrorData struct {
192-
GenkitErrorDetails *GenkitErrorDetails `json:"genkitErrorDetails,omitempty"`
193-
GenkitErrorMessage string `json:"genkitErrorMessage,omitempty"`
194-
}
195-
196-
type GenkitErrorDetails struct {
197-
Stack string `json:"stack,omitempty"`
198-
TraceID string `json:"traceId,omitempty"`
199-
}
200-
201184
type Media struct {
202185
ContentType string `json:"contentType,omitempty"`
203186
Url string `json:"url,omitempty"`

go/ai/generate.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func LookupModel(r *registry.Registry, provider, name string) Model {
153153
// It returns an error if the model was not defined.
154154
func LookupModelByName(r *registry.Registry, modelName string) (Model, error) {
155155
if modelName == "" {
156-
return nil, errors.New("ai.LookupModelByName: model not specified")
156+
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.LookupModelByName: model not specified")
157157
}
158158

159159
provider, name, found := strings.Cut(modelName, "/")
@@ -165,9 +165,9 @@ func LookupModelByName(r *registry.Registry, modelName string) (Model, error) {
165165
model := LookupModel(r, provider, name)
166166
if model == nil {
167167
if provider == "" {
168-
return nil, fmt.Errorf("ai.LookupModelByName: no model named %q", name)
168+
return nil, core.NewError(core.NOT_FOUND, "ai.LookupModelByName: model %q not found", name)
169169
}
170-
return nil, fmt.Errorf("ai.LookupModelByName: no model named %q for provider %q", name, provider)
170+
return nil, core.NewError(core.NOT_FOUND, "ai.LookupModelByName: model %q by provider %q not found", name, provider)
171171
}
172172

173173
return model, nil
@@ -180,7 +180,7 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
180180
opts.Model = defaultModel
181181
}
182182
if opts.Model == "" {
183-
return nil, errors.New("ai.GenerateWithRequest: model is required")
183+
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: model is required")
184184
}
185185
}
186186

@@ -193,12 +193,12 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
193193
toolDefMap := make(map[string]*ToolDefinition)
194194
for _, t := range opts.Tools {
195195
if _, ok := toolDefMap[t]; ok {
196-
return nil, fmt.Errorf("ai.GenerateWithRequest: duplicate tool found: %q", t)
196+
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: duplicate tool %q", t)
197197
}
198198

199199
tool := LookupTool(r, t)
200200
if tool == nil {
201-
return nil, fmt.Errorf("ai.GenerateWithRequest: tool not found: %q", t)
201+
return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: tool %q not found", t)
202202
}
203203

204204
toolDefMap[t] = tool.Definition()
@@ -210,7 +210,7 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
210210

211211
maxTurns := opts.MaxTurns
212212
if maxTurns < 0 {
213-
return nil, fmt.Errorf("ai.GenerateWithRequest: max turns must be greater than 0, got %d", maxTurns)
213+
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: max turns must be greater than 0, got %d", maxTurns)
214214
}
215215
if maxTurns == 0 {
216216
maxTurns = 5 // Default max turns.
@@ -276,7 +276,8 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
276276
resp.Message, err = formatHandler.ParseMessage(resp.Message)
277277
if err != nil {
278278
logger.FromContext(ctx).Debug("model failed to generate output matching expected schema", "error", err.Error())
279-
return nil, fmt.Errorf("model failed to generate output matching expected schema: %w", err)
279+
return nil, core.NewError(core.INTERNAL, "model failed to generate output matching expected schema: %v", err)
280+
280281
}
281282
}
282283

@@ -291,7 +292,7 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
291292
}
292293

293294
if currentTurn+1 > maxTurns {
294-
return nil, fmt.Errorf("exceeded maximum tool call iterations (%d)", maxTurns)
295+
return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns)
295296
}
296297

297298
newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, cb)
@@ -318,7 +319,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
318319
genOpts := &generateOptions{}
319320
for _, opt := range opts {
320321
if err := opt.applyGenerate(genOpts); err != nil {
321-
return nil, fmt.Errorf("ai.Generate: error applying options: %w", err)
322+
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: error applying options: %v", err)
322323
}
323324
}
324325

@@ -421,7 +422,7 @@ func (m *model) Name() string {
421422
// Generate applies the [Action] to provided request.
422423
func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
423424
if m == nil {
424-
return nil, errors.New("Model.Generate: generate called on a nil model; check that all models are defined")
425+
return nil, core.NewError(core.INVALID_ARGUMENT, "Model.Generate: generate called on a nil model; check that all models are defined")
425426
}
426427

427428
return (*core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk])(m).Run(ctx, req, cb)
@@ -478,12 +479,12 @@ func cloneMessage(m *Message) *Message {
478479
panic(fmt.Sprintf("failed to marshal message: %v", err))
479480
}
480481

481-
var copy Message
482-
if err := json.Unmarshal(bytes, &copy); err != nil {
482+
var msgCopy Message
483+
if err := json.Unmarshal(bytes, &msgCopy); err != nil {
483484
panic(fmt.Sprintf("failed to unmarshal message: %v", err))
484485
}
485486

486-
return &copy
487+
return &msgCopy
487488
}
488489

489490
// handleToolRequests processes any tool requests in the response, returning
@@ -520,7 +521,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq
520521
toolReq := p.ToolRequest
521522
tool := LookupTool(r, toolReq.Name)
522523
if tool == nil {
523-
resultChan <- toolResult{idx, nil, fmt.Errorf("tool %q not found", toolReq.Name)}
524+
resultChan <- toolResult{idx, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
524525
return
525526
}
526527

@@ -538,7 +539,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq
538539
resultChan <- toolResult{idx, nil, interruptErr}
539540
return
540541
}
541-
resultChan <- toolResult{idx, nil, fmt.Errorf("tool %q failed: %w", toolReq.Name, err)}
542+
resultChan <- toolResult{idx, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
542543
return
543544
}
544545

go/ai/model_middleware.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"strconv"
2828
"strings"
2929

30+
"github.com/firebase/genkit/go/core"
3031
"github.com/firebase/genkit/go/core/logger"
3132
)
3233

@@ -101,28 +102,28 @@ func validateSupport(model string, info *ModelInfo) ModelMiddleware {
101102
for _, msg := range input.Messages {
102103
for _, part := range msg.Content {
103104
if part.IsMedia() {
104-
return nil, fmt.Errorf("model %q does not support media, but media was provided. Request: %+v", model, input)
105+
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support media, but media was provided. Request: %+v", model, input)
105106
}
106107
}
107108
}
108109
}
109110

110111
if !info.Supports.Tools && len(input.Tools) > 0 {
111-
return nil, fmt.Errorf("model %q does not support tool use, but tools were provided. Request: %+v", model, input)
112+
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support tool use, but tools were provided. Request: %+v", model, input)
112113
}
113114

114115
if !info.Supports.Multiturn && len(input.Messages) > 1 {
115-
return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", model, len(input.Messages), input)
116+
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support multiple messages, but %d were provided. Request: %+v", model, len(input.Messages), input)
116117
}
117118

118119
if !info.Supports.ToolChoice && input.ToolChoice != "" && input.ToolChoice != ToolChoiceAuto {
119-
return nil, fmt.Errorf("model %q does not support tool choice, but tool choice was provided. Request: %+v", model, input)
120+
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support tool choice, but tool choice was provided. Request: %+v", model, input)
120121
}
121122

122123
if !info.Supports.SystemRole {
123124
for _, msg := range input.Messages {
124125
if msg.Role == RoleSystem {
125-
return nil, fmt.Errorf("model %q does not support system role, but system role was provided. Request: %+v", model, input)
126+
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support system role, but system role was provided. Request: %+v", model, input)
126127
}
127128
}
128129
}
@@ -140,7 +141,7 @@ func validateSupport(model string, info *ModelInfo) ModelMiddleware {
140141
info.Supports.Constrained == ConstrainedSupportNone ||
141142
(info.Supports.Constrained == ConstrainedSupportNoTools && len(input.Tools) > 0)) &&
142143
input.Output != nil && input.Output.Constrained {
143-
return nil, fmt.Errorf("model %q does not support native constrained output, but constrained output was requested. Request: %+v", model, input)
144+
return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support native constrained output, but constrained output was requested. Request: %+v", model, input)
144145
}
145146

146147
if err := validateVersion(model, info.Versions, input.Config); err != nil {
@@ -176,14 +177,14 @@ func validateVersion(model string, versions []string, config any) error {
176177

177178
version, ok := versionVal.(string)
178179
if !ok {
179-
return fmt.Errorf("version must be a string, got %T", versionVal)
180+
return core.NewError(core.INVALID_ARGUMENT, "version must be a string, got %T", versionVal)
180181
}
181182

182183
if slices.Contains(versions, version) {
183184
return nil
184185
}
185186

186-
return fmt.Errorf("model %q does not support version %q, supported versions: %v", model, version, versions)
187+
return core.NewError(core.INVALID_ARGUMENT, "model %q does not support version %q, supported versions: %v", model, version, versions)
187188
}
188189

189190
// ContextItemTemplate is the default item template for context augmentation.
@@ -302,13 +303,13 @@ func DownloadRequestMedia(options *DownloadMediaOptions) ModelMiddleware {
302303

303304
resp, err := client.Get(mediaUrl)
304305
if err != nil {
305-
return nil, fmt.Errorf("HTTP error downloading media %q: %w", mediaUrl, err)
306+
return nil, core.NewError(core.INVALID_ARGUMENT, "HTTP error downloading media %q: %v", mediaUrl, err)
306307
}
307308
defer resp.Body.Close()
308309

309310
if resp.StatusCode != http.StatusOK {
310311
body, _ := io.ReadAll(resp.Body)
311-
return nil, fmt.Errorf("HTTP error downloading media %q: %s", mediaUrl, string(body))
312+
return nil, core.NewError(core.UNKNOWN, "HTTP error downloading media %q: %s", mediaUrl, string(body))
312313
}
313314

314315
contentType := part.ContentType
@@ -324,7 +325,7 @@ func DownloadRequestMedia(options *DownloadMediaOptions) ModelMiddleware {
324325
data, err = io.ReadAll(resp.Body)
325326
}
326327
if err != nil {
327-
return nil, fmt.Errorf("error reading media %q: %v", mediaUrl, err)
328+
return nil, core.NewError(core.UNKNOWN, "error reading media %q: %v", mediaUrl, err)
328329
}
329330

330331
message.Content[j] = NewMediaPart(contentType, fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(data)))

go/core/action.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23-
"net/http"
2423
"reflect"
2524
"time"
2625

@@ -213,7 +212,7 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea
213212
func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) {
214213
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
215214
if err := base.ValidateJSON(input, a.inputSchema); err != nil {
216-
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
215+
return nil, NewError(INVALID_ARGUMENT, err.Error())
217216
}
218217
var in In
219218
if input != nil {

go/core/error.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
// Package core provides base error types and utilities for Genkit.
18+
package core
19+
20+
import (
21+
"fmt"
22+
"runtime/debug"
23+
)
24+
25+
type ReflectionErrorDetails struct {
26+
Stack *string `json:"stack,omitempty"` // Use pointer for optional
27+
TraceID *string `json:"traceId,omitempty"`
28+
}
29+
30+
// ReflectionError is the wire format for HTTP errors for Reflection API responses.
31+
type ReflectionError struct {
32+
Details *ReflectionErrorDetails `json:"details,omitempty"`
33+
Message string `json:"message"`
34+
Code int `json:"code"`
35+
}
36+
37+
// GenkitError is the base error type for Genkit errors.
38+
type GenkitError struct {
39+
Message string `json:"message"` // Exclude from default JSON if embedded elsewhere
40+
Status StatusName `json:"status"`
41+
HTTPCode int `json:"-"` // Exclude from default JSON
42+
Details map[string]any `json:"details"` // Use map for arbitrary details
43+
Source *string `json:"source,omitempty"` // Pointer for optional
44+
}
45+
46+
// UserFacingError is the base error type for user facing errors.
47+
type UserFacingError struct {
48+
Message string `json:"message"` // Exclude from default JSON if embedded elsewhere
49+
Status StatusName `json:"status"`
50+
Details map[string]any `json:"details"` // Use map for arbitrary details
51+
}
52+
53+
// NewPublicError allows a web framework handler to know it
54+
// is safe to return the message in a request. Other kinds of errors will
55+
// result in a generic 500 message to avoid the possibility of internal
56+
// exceptions being leaked to attackers.
57+
func NewPublicError(status StatusName, message string, details map[string]any) *UserFacingError {
58+
return &UserFacingError{
59+
Status: status,
60+
Details: details,
61+
Message: message,
62+
}
63+
}
64+
65+
// Error implements the standard error interface for UserFacingError.
66+
func (e *UserFacingError) Error() string {
67+
return fmt.Sprintf("%s: %s", e.Status, e.Message)
68+
}
69+
70+
// NewError creates a new GenkitError with a stack trace.
71+
func NewError(status StatusName, message string, args ...any) *GenkitError {
72+
// Prevents a compile-time warning about non-constant message.
73+
msg := message
74+
75+
ge := &GenkitError{
76+
Status: status,
77+
Message: fmt.Sprintf(msg, args...),
78+
}
79+
80+
errStack := string(debug.Stack())
81+
if errStack != "" {
82+
ge.Details = make(map[string]any)
83+
ge.Details["stack"] = errStack
84+
}
85+
return ge
86+
}
87+
88+
// Error implements the standard error interface.
89+
func (e *GenkitError) Error() string {
90+
return e.Message
91+
}
92+
93+
// ToReflectionError returns a JSON-serializable representation for reflection API responses.
94+
func (e *GenkitError) ToReflectionError() ReflectionError {
95+
errDetails := &ReflectionErrorDetails{}
96+
if stackVal, ok := e.Details["stack"].(string); ok {
97+
errDetails.Stack = &stackVal
98+
}
99+
if traceVal, ok := e.Details["traceId"].(string); ok {
100+
errDetails.TraceID = &traceVal
101+
}
102+
return ReflectionError{
103+
Details: errDetails,
104+
Code: HTTPStatusCode(e.Status),
105+
Message: e.Message,
106+
}
107+
}
108+
109+
// ToReflectionError gets the JSON representation for reflection API Error responses.
110+
func ToReflectionError(err error) ReflectionError {
111+
if ge, ok := err.(*GenkitError); ok {
112+
return ge.ToReflectionError()
113+
}
114+
115+
return ReflectionError{
116+
Message: err.Error(),
117+
Code: HTTPStatusCode(INTERNAL),
118+
Details: &ReflectionErrorDetails{},
119+
}
120+
}

go/core/schemas.config

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,4 +272,6 @@ Score omit
272272

273273
Embedding.embedding type []float32
274274

275-
GenkitErrorDataGenkitErrorDetails name GenkitErrorDetails
275+
GenkitError omit
276+
GenkitErrorData omit
277+
GenkitErrorDataGenkitErrorDetails omit

0 commit comments

Comments
 (0)