diff --git a/polling.go b/polling.go index f6b4a2e8..7ca22103 100644 --- a/polling.go +++ b/polling.go @@ -32,7 +32,7 @@ func (r *VectorStoreFileService) PollStatus(ctx context.Context, vectorStoreID s opts = append(opts, mkPollingOptions(pollIntervalMs)...) opts = append(opts, option.WithResponseInto(&raw)) for { - file, err := r.Get(ctx, fileID, vectorStoreID, opts...) + file, err := r.Get(ctx, vectorStoreID, fileID, opts...) if err != nil { return nil, fmt.Errorf("vector store file poll: received %w", err) } diff --git a/vectorstorefile.go b/vectorstorefile.go index 8a389a99..77f089b6 100644 --- a/vectorstorefile.go +++ b/vectorstorefile.go @@ -61,12 +61,12 @@ func (r *VectorStoreFileService) New(ctx context.Context, vectorStoreID string, // // Polls the API and blocks until the task is complete. // Default polling interval is 1 second. -func (r *VectorStoreFileService) NewAndPoll(ctx context.Context, vectorStoreId string, body VectorStoreFileNewParams, pollIntervalMs int, opts ...option.RequestOption) (res *VectorStoreFile, err error) { - file, err := r.New(ctx, vectorStoreId, body, opts...) +func (r *VectorStoreFileService) NewAndPoll(ctx context.Context, vectorStoreID string, body VectorStoreFileNewParams, pollIntervalMs int, opts ...option.RequestOption) (res *VectorStoreFile, err error) { + file, err := r.New(ctx, vectorStoreID, body, opts...) if err != nil { return nil, err } - return r.PollStatus(ctx, vectorStoreId, file.ID, pollIntervalMs, opts...) + return r.PollStatus(ctx, vectorStoreID, file.ID, pollIntervalMs, opts...) } // Upload a file to the `files` API and then attach it to the given vector store. diff --git a/vectorstorefile_test.go b/vectorstorefile_test.go index 17d7626b..e5558736 100644 --- a/vectorstorefile_test.go +++ b/vectorstorefile_test.go @@ -5,7 +5,11 @@ package openai_test import ( "context" "errors" + "fmt" + "io" + "net/http" "os" + "strings" "testing" "github.com/openai/openai-go/v3" @@ -191,3 +195,79 @@ func TestVectorStoreFileContent(t *testing.T) { t.Fatalf("err should be nil: %s", err.Error()) } } + +func TestVectorStoreFilePollStatus(t *testing.T) { + var capturedURLs []string + callCount := 0 + + client := openai.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &vectorStoreFileClosureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + capturedURLs = append(capturedURLs, req.URL.String()) + callCount++ + + var status openai.VectorStoreFileStatus + if callCount < 3 { + status = openai.VectorStoreFileStatusInProgress + } else { + status = openai.VectorStoreFileStatusCompleted + } + + responseBody := fmt.Sprintf(`{ + "id": "file-abc123", + "object": "vector_store.file", + "status": "%s", + "vector_store_id": "vs_abc123", + "created_at": 1234567890, + "usage_bytes": 1024 + }`, status) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(responseBody)), + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "openai-poll-after-ms": []string{"10"}, + }, + }, nil + }, + }, + }), + ) + + file, err := client.VectorStores.Files.PollStatus( + context.Background(), + "vs_abc123", + "file-abc123", + 10, + ) + + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + if file.Status != openai.VectorStoreFileStatusCompleted { + t.Errorf("expected status to be completed, got: %s", file.Status) + } + + expectedURL := "https://api.openai.com/v1/vector_stores/vs_abc123/files/file-abc123" + for _, url := range capturedURLs { + if url != expectedURL { + t.Errorf("expected URL %s, got: %s", expectedURL, url) + } + } + + if callCount != 3 { + t.Errorf("expected 3 calls, got: %d", callCount) + } +} + +type vectorStoreFileClosureTransport struct { + fn func(req *http.Request) (*http.Response, error) +} + +func (t *vectorStoreFileClosureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +}