diff --git a/go/internal/doc-snippets/rag/main.go b/go/internal/doc-snippets/rag/main.go index 1d3fcd5083..d82f58dff6 100644 --- a/go/internal/doc-snippets/rag/main.go +++ b/go/internal/doc-snippets/rag/main.go @@ -50,7 +50,7 @@ func main() { log.Fatal(err) } - menuPDFIndexer, _, err := localvec.DefineIndexerAndRetriever( + docStore, _, err := localvec.DefineRetriever( g, "menuQA", localvec.Config{ @@ -100,7 +100,7 @@ func main() { } // Add chunks to the index. - err = ai.Index(ctx, menuPDFIndexer, ai.WithDocs(docs...)) + err = localvec.Index(ctx, docs, docStore) return nil, err }, ) @@ -155,7 +155,7 @@ func menuQA() { model := googlegenai.VertexAIModel(g, "gemini-1.5-flash") - _, menuPdfRetriever, err := localvec.DefineIndexerAndRetriever( + _, menuPdfRetriever, err := localvec.DefineRetriever( g, "menuQA", localvec.Config{ @@ -207,7 +207,7 @@ func customret() { log.Fatal(err) } - _, menuPDFRetriever, _ := localvec.DefineIndexerAndRetriever( + _, menuPDFRetriever, _ := localvec.DefineRetriever( g, "menuQA", localvec.Config{ diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index 4a64079511..fe2bee3672 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -49,14 +49,14 @@ type Config struct { // Init initializes the plugin. func Init() error { return nil } -// DefineIndexerAndRetriever defines an Indexer and Retriever that share the same underlying storage. -// The name uniquely identifies the the Indexer and Retriever in the registry. -func DefineIndexerAndRetriever(g *genkit.Genkit, name string, cfg Config) (ai.Indexer, ai.Retriever, error) { +// DefineRetriever defines a Retriever and docStore which is also used by the retriever. +// The name uniquely identifies the Retriever in the registry. +func DefineRetriever(g *genkit.Genkit, name string, cfg Config) (*DocStore, ai.Retriever, error) { ds, err := newDocStore(cfg.Dir, name, cfg.Embedder, cfg.EmbedderOptions) if err != nil { return nil, nil, err } - return genkit.DefineIndexer(g, provider, name, ds.index), + return ds, genkit.DefineRetriever(g, provider, name, ds.retrieve), nil } @@ -82,23 +82,23 @@ func Retriever(g *genkit.Genkit, name string) ai.Retriever { return genkit.LookupRetriever(g, provider, name) } -// docStore implements a local vector database. +// DocStore implements a local vector database. // This is based on js/plugins/dev-local-vectorstore/src/index.ts. -type docStore struct { - filename string - embedder ai.Embedder - embedderOptions any - data map[string]dbValue +type DocStore struct { + Filename string + Embedder ai.Embedder + EmbedderOptions any + Data map[string]DbValue } -// dbValue is the type of a document stored in the database. -type dbValue struct { +// DbValue is the type of a document stored in the database. +type DbValue struct { Doc *ai.Document `json:"doc"` Embedding []float32 `json:"embedding"` } // newDocStore returns a new ai.DocumentStore to register. -func newDocStore(dir, name string, embedder ai.Embedder, embedderOptions any) (*docStore, error) { +func newDocStore(dir, name string, embedder ai.Embedder, embedderOptions any) (*DocStore, error) { if dir == "" { dir = os.TempDir() } @@ -108,7 +108,7 @@ func newDocStore(dir, name string, embedder ai.Embedder, embedderOptions any) (* dbname := "__db_" + name + ".json" filename := filepath.Join(dir, dbname) f, err := os.Open(filename) - var data map[string]dbValue + var data map[string]DbValue if err != nil { if !errors.Is(err, fs.ErrNotExist) { return nil, err @@ -121,67 +121,15 @@ func newDocStore(dir, name string, embedder ai.Embedder, embedderOptions any) (* } } - ds := &docStore{ - filename: filename, - embedder: embedder, - embedderOptions: embedderOptions, - data: data, + ds := &DocStore{ + Filename: filename, + Embedder: embedder, + EmbedderOptions: embedderOptions, + Data: data, } return ds, nil } -// index indexes a document. -func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error { - ereq := &ai.EmbedRequest{ - Input: req.Documents, - Options: ds.embedderOptions, - } - eres, err := ds.embedder.Embed(ctx, ereq) - if err != nil { - return fmt.Errorf("localvec index embedding failed: %v", err) - } - for i, de := range eres.Embeddings { - id, err := docID(req.Documents[i]) - if err != nil { - return err - } - if _, ok := ds.data[id]; ok { - logger.FromContext(ctx).Debug("localvec skipping document because already present", "id", id) - continue - } - - if ds.data == nil { - ds.data = make(map[string]dbValue) - } - - ds.data[id] = dbValue{ - Doc: req.Documents[i], - Embedding: de.Embedding, - } - } - - // Update the file every time we add documents. - // We use a temporary file to avoid losing the original - // file, in case of a crash. - tmpname := ds.filename + ".tmp" - f, err := os.Create(tmpname) - if err != nil { - return err - } - encoder := json.NewEncoder(f) - if err := encoder.Encode(ds.data); err != nil { - return err - } - if err := f.Close(); err != nil { - return err - } - if err := os.Rename(tmpname, ds.filename); err != nil { - return err - } - - return nil -} - // RetrieverOptions may be passed in the Options field // of [ai.RetrieverRequest] to pass options to the retriever. // The Options field should be either nil or a value of type *RetrieverOptions. @@ -190,14 +138,14 @@ type RetrieverOptions struct { } // retrieve retrieves documents close to the argument. -func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { +func (ds *DocStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { // Use the embedder to convert the document we want to // retrieve into a vector. ereq := &ai.EmbedRequest{ Input: []*ai.Document{req.Query}, - Options: ds.embedderOptions, + Options: ds.EmbedderOptions, } - eres, err := ds.embedder.Embed(ctx, ereq) + eres, err := ds.Embedder.Embed(ctx, ereq) if err != nil { return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err) } @@ -207,8 +155,8 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai score float64 doc *ai.Document } - scoredDocs := make([]scoredDoc, 0, len(ds.data)) - for _, dbv := range ds.data { + scoredDocs := make([]scoredDoc, 0, len(ds.Data)) + for _, dbv := range ds.Data { score := similarity(vals, dbv.Embedding) scoredDocs = append(scoredDocs, scoredDoc{ score: score, @@ -279,6 +227,58 @@ func similarity(vals1, vals2 []float32) float64 { return dot / (l1 * l2) } +// Helper function to get started with indexing +func Index(ctx context.Context, docs []*ai.Document, ds *DocStore) error { + ereq := &ai.EmbedRequest{ + Input: docs, + Options: ds.EmbedderOptions, + } + eres, err := ds.Embedder.Embed(ctx, ereq) + if err != nil { + return fmt.Errorf("localvec index embedding failed: %v", err) + } + for i, de := range eres.Embeddings { + id, err := docID(docs[i]) + if err != nil { + return err + } + if _, ok := ds.Data[id]; ok { + logger.FromContext(ctx).Debug("localvec skipping document because already present", "id", id) + continue + } + + if ds.Data == nil { + ds.Data = make(map[string]DbValue) + } + + ds.Data[id] = DbValue{ + Doc: docs[i], + Embedding: de.Embedding, + } + } + + // Update the file every time we add documents. + // We use a temporary file to avoid losing the original + // file, in case of a crash. + tmpname := ds.Filename + ".tmp" + f, err := os.Create(tmpname) + if err != nil { + return err + } + encoder := json.NewEncoder(f) + if err := encoder.Encode(ds.Data); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + if err := os.Rename(tmpname, ds.Filename); err != nil { + return err + } + + return nil +} + // docID returns the ID to use for a Document. // This is intended to be the same as the genkit Typescript computation. func docID(doc *ai.Document) (string, error) { diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 2f38a2c00d..60bf759a88 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -64,10 +64,7 @@ func TestLocalVec(t *testing.T) { t.Fatal(err) } - indexerReq := &ai.IndexerRequest{ - Documents: []*ai.Document{d1, d2, d3}, - } - err = ds.index(ctx, indexerReq) + err = Index(ctx, []*ai.Document{d1, d2, d3}, ds) if err != nil { t.Fatalf("Index operation failed: %v", err) } @@ -132,10 +129,7 @@ func TestPersistentIndexing(t *testing.T) { t.Fatal(err) } - indexerReq := &ai.IndexerRequest{ - Documents: []*ai.Document{d1, d2}, - } - err = ds.index(ctx, indexerReq) + err = Index(ctx, []*ai.Document{d1, d2}, ds) if err != nil { t.Fatalf("Index operation failed: %v", err) } @@ -163,10 +157,7 @@ func TestPersistentIndexing(t *testing.T) { t.Fatal(err) } - indexerReq = &ai.IndexerRequest{ - Documents: []*ai.Document{d3}, - } - err = dsAnother.index(ctx, indexerReq) + err = Index(ctx, []*ai.Document{d3}, dsAnother) if err != nil { t.Fatalf("Index operation failed: %v", err) } @@ -210,14 +201,11 @@ func TestInit(t *testing.T) { t.Fatal(err) } const name = "mystore" - ind, ret, err := DefineIndexerAndRetriever(g, name, Config{Embedder: embedder}) + _, ret, err := DefineRetriever(g, name, Config{Embedder: embedder}) if err != nil { t.Fatal(err) } want := "devLocalVectorStore/" + name - if g := ind.Name(); g != want { - t.Errorf("got %q, want %q", g, want) - } if g := ret.Name(); g != want { t.Errorf("got %q, want %q", g, want) } diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go index b887788826..f8edba1838 100644 --- a/go/samples/menu/main.go +++ b/go/samples/menu/main.go @@ -19,7 +19,6 @@ package main import ( "context" "log" - "os" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" @@ -60,7 +59,7 @@ type textMenuQuestionInput struct { func main() { ctx := context.Background() g, err := genkit.Init(ctx, - genkit.WithPlugins(&googlegenai.VertexAI{Location: os.Getenv("GCLOUD_LOCATION")}), + genkit.WithPlugins(&googlegenai.VertexAI{}), ) if err != nil { log.Fatalf("failed to create Genkit: %v", err) @@ -83,13 +82,13 @@ func main() { if err != nil { log.Fatal(err) } - indexer, retriever, err := localvec.DefineIndexerAndRetriever(g, "go-menu_items", localvec.Config{ + docStore, retriever, err := localvec.DefineRetriever(g, "go-menu_items", localvec.Config{ Embedder: embedder, }) if err != nil { log.Fatal(err) } - if err := setup04(g, indexer, retriever, model); err != nil { + if err := setup04(ctx, g, docStore, retriever, model); err != nil { log.Fatal(err) } diff --git a/go/samples/menu/s04.go b/go/samples/menu/s04.go index 25b7ce21bc..bbc2284f00 100644 --- a/go/samples/menu/s04.go +++ b/go/samples/menu/s04.go @@ -26,7 +26,7 @@ import ( "github.com/firebase/genkit/go/plugins/localvec" ) -func setup04(g *genkit.Genkit, indexer ai.Indexer, retriever ai.Retriever, model ai.Model) error { +func setup04(ctx context.Context, g *genkit.Genkit, docStore *localvec.DocStore, retriever ai.Retriever, model ai.Model) error { ragDataMenuPrompt, err := genkit.DefinePrompt(g, "s04_ragDataMenu", ai.WithPrompt(` You are acting as Walt, a helpful AI assistant here at the restaurant. @@ -68,7 +68,9 @@ Answer this customer's question: } docs = append(docs, ai.DocumentFromText(s, metadata)) } - if err := ai.Index(ctx, indexer, ai.WithDocs(docs...)); err != nil { + + // Index the menu items. + if err := localvec.Index(ctx, docs, docStore); err != nil { return nil, err } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index bc0333d48d..d293ed031a 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -96,7 +96,7 @@ func main() { if err := localvec.Init(); err != nil { log.Fatal(err) } - indexer, retriever, err := localvec.DefineIndexerAndRetriever(g, "simpleQa", localvec.Config{Embedder: embedder}) + docStore, retriever, err := localvec.DefineRetriever(g, "simpleQa", localvec.Config{Embedder: embedder}) if err != nil { log.Fatal(err) } @@ -158,7 +158,7 @@ func main() { d2 := ai.DocumentFromText("USA is the largest importer of coffee", nil) d3 := ai.DocumentFromText("Water exists in 3 states - solid, liquid and gas", nil) - err := ai.Index(ctx, indexer, ai.WithDocs(d1, d2, d3)) + err := localvec.Index(ctx, []*ai.Document{d1, d2, d3}, docStore) if err != nil { return "", err }