Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion polling.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (r *VectorStoreFileService) PollStatus(ctx context.Context, vectorStoreID s
opts = append(opts, mkPollingOptions(pollIntervalMs)...)
opts = append(opts, option.WithResponseInto(&raw))
for {
file, err := r.Get(ctx, fileID, vectorStoreID, opts...)
file, err := r.Get(ctx, vectorStoreID, fileID, opts...)
if err != nil {
return nil, fmt.Errorf("vector store file poll: received %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions vectorstorefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ func (r *VectorStoreFileService) New(ctx context.Context, vectorStoreID string,
//
// Polls the API and blocks until the task is complete.
// Default polling interval is 1 second.
func (r *VectorStoreFileService) NewAndPoll(ctx context.Context, vectorStoreId string, body VectorStoreFileNewParams, pollIntervalMs int, opts ...option.RequestOption) (res *VectorStoreFile, err error) {
file, err := r.New(ctx, vectorStoreId, body, opts...)
func (r *VectorStoreFileService) NewAndPoll(ctx context.Context, vectorStoreID string, body VectorStoreFileNewParams, pollIntervalMs int, opts ...option.RequestOption) (res *VectorStoreFile, err error) {
file, err := r.New(ctx, vectorStoreID, body, opts...)
if err != nil {
return nil, err
}
return r.PollStatus(ctx, vectorStoreId, file.ID, pollIntervalMs, opts...)
return r.PollStatus(ctx, vectorStoreID, file.ID, pollIntervalMs, opts...)
}

// Upload a file to the `files` API and then attach it to the given vector store.
Expand Down
80 changes: 80 additions & 0 deletions vectorstorefile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ package openai_test
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"

"github.com/openai/openai-go/v3"
Expand Down Expand Up @@ -191,3 +195,79 @@ func TestVectorStoreFileContent(t *testing.T) {
t.Fatalf("err should be nil: %s", err.Error())
}
}

func TestVectorStoreFilePollStatus(t *testing.T) {
var capturedURLs []string
callCount := 0

client := openai.NewClient(
option.WithAPIKey("My API Key"),
option.WithHTTPClient(&http.Client{
Transport: &vectorStoreFileClosureTransport{
fn: func(req *http.Request) (*http.Response, error) {
capturedURLs = append(capturedURLs, req.URL.String())
callCount++

var status openai.VectorStoreFileStatus
if callCount < 3 {
status = openai.VectorStoreFileStatusInProgress
} else {
status = openai.VectorStoreFileStatusCompleted
}

responseBody := fmt.Sprintf(`{
"id": "file-abc123",
"object": "vector_store.file",
"status": "%s",
"vector_store_id": "vs_abc123",
"created_at": 1234567890,
"usage_bytes": 1024
}`, status)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(responseBody)),
Header: http.Header{
"Content-Type": []string{"application/json"},
"openai-poll-after-ms": []string{"10"},
},
}, nil
},
},
}),
)

file, err := client.VectorStores.Files.PollStatus(
context.Background(),
"vs_abc123",
"file-abc123",
10,
)

if err != nil {
t.Fatalf("expected no error, got: %v", err)
}

if file.Status != openai.VectorStoreFileStatusCompleted {
t.Errorf("expected status to be completed, got: %s", file.Status)
}

expectedURL := "https://api.openai.com/v1/vector_stores/vs_abc123/files/file-abc123"
for _, url := range capturedURLs {
if url != expectedURL {
t.Errorf("expected URL %s, got: %s", expectedURL, url)
}
}

if callCount != 3 {
t.Errorf("expected 3 calls, got: %d", callCount)
}
}

type vectorStoreFileClosureTransport struct {
fn func(req *http.Request) (*http.Response, error)
}

func (t *vectorStoreFileClosureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.fn(req)
}