Skip to content

Commit

Permalink
Return ModelError when Run fails (#76)
Browse files Browse the repository at this point in the history
* Rename apierror.go to error.go

* Fix capitalization of error string

* Define ModelError type

* Return ModelError when Run fails to produce output

* Add test coverage for Run method
  • Loading branch information
mattt authored Aug 22, 2024
1 parent 972c92e commit 3c5fd6b
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 2 deletions.
90 changes: 90 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,96 @@ func TestWaitAsync(t *testing.T) {
assert.Equal(t, replicate.Succeeded, lastStatus)
}

func TestRun(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/predictions":
assert.Equal(t, http.MethodPost, r.Method)
prediction := replicate.Prediction{
ID: "gtsllfynndufawqhdngldkdrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: replicate.Starting,
}
json.NewEncoder(w).Encode(prediction)
case "/predictions/gtsllfynndufawqhdngldkdrkq":
assert.Equal(t, http.MethodGet, r.Method)
prediction := replicate.Prediction{
ID: "gtsllfynndufawqhdngldkdrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: replicate.Succeeded,
Output: "Hello, world!",
}
json.NewEncoder(w).Encode(prediction)
default:
t.Fatalf("Unexpected request to %s", r.URL.Path)
}
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx := context.Background()
input := replicate.PredictionInput{"prompt": "Hello"}
output, err := client.Run(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil)

require.NoError(t, err)
assert.NotNil(t, output)
assert.Equal(t, "Hello, world!", output)
}

func TestRunReturningModelError(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/predictions":
assert.Equal(t, http.MethodPost, r.Method)
prediction := replicate.Prediction{
ID: "fynndufawqhdngldkgtslldrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: replicate.Starting,
}
json.NewEncoder(w).Encode(prediction)
case "/predictions/fynndufawqhdngldkgtslldrkq":
assert.Equal(t, http.MethodGet, r.Method)

logs := "Could not say hello"
prediction := replicate.Prediction{
ID: "fynndufawqhdngldkgtslldrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: replicate.Failed,
Logs: &logs,
Error: "Model execution failed",
}
json.NewEncoder(w).Encode(prediction)
default:
t.Fatalf("Unexpected request to %s", r.URL.Path)
}
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx := context.Background()
input := replicate.PredictionInput{"prompt": "Hello"}
_, err = client.Run(ctx, "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil)

require.Error(t, err)
modelErr, ok := err.(*replicate.ModelError)
require.True(t, ok, "Expected error to be of type *replicate.ModelError")
assert.Equal(t, "model error: Model execution failed", modelErr.Error())
assert.Equal(t, "fynndufawqhdngldkgtslldrkq", modelErr.Prediction.ID)
assert.Equal(t, replicate.Failed, modelErr.Prediction.Status)
assert.Equal(t, "Model execution failed", modelErr.Prediction.Error)
assert.Equal(t, "Could not say hello", *modelErr.Prediction.Logs)
}

func TestCreateTraining(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
Expand Down
15 changes: 14 additions & 1 deletion apierror.go → error.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (e APIError) Error() string {

output := strings.Join(components, ": ")
if output == "" {
output = "Unknown error"
output = "unknown error"
}

if e.Instance != "" {
Expand All @@ -78,3 +78,16 @@ func (e *APIError) WriteHTTPResponse(w http.ResponseWriter) {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

// ModelError represents an error returned by a model for a failed prediction.
type ModelError struct {
Prediction *Prediction `json:"prediction"`
}

func (e *ModelError) Error() string {
if e.Prediction == nil {
return "unknown model error"
}

return fmt.Sprintf("model error: %s", e.Prediction.Error)
}
9 changes: 8 additions & 1 deletion run.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp
}

err = r.Wait(ctx, prediction)
if err != nil {
return nil, err
}

if prediction.Error != nil {
return nil, &ModelError{Prediction: prediction}
}

return prediction.Output, err
return prediction.Output, nil
}

0 comments on commit 3c5fd6b

Please sign in to comment.