diff --git a/examples/tools/retrievaltool/vertex_ai_rag.go b/examples/tools/retrievaltool/vertex_ai_rag.go new file mode 100644 index 00000000..626e3db0 --- /dev/null +++ b/examples/tools/retrievaltool/vertex_ai_rag.go @@ -0,0 +1,122 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "log" + "os" + "strconv" + + "google.golang.org/adk/agent/llmagent" + "google.golang.org/adk/cmd/launcher/adk" + "google.golang.org/adk/cmd/launcher/full" + "google.golang.org/adk/model/gemini" + "google.golang.org/adk/server/restapi/services" + "google.golang.org/adk/tool" + "google.golang.org/adk/tool/retrievaltool" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + + modelName := "gemini-2.0-flash-001" + if v := os.Getenv("MODEL_NAME"); v != "" { + modelName = v + } + location := "us-central1" + if v := os.Getenv("LOCATION"); v != "" { + location = v + } + model, err := gemini.NewModel(ctx, modelName, + &genai.ClientConfig{ + Backend: genai.BackendVertexAI, + Project: os.Getenv("PROJECT_ID"), + Location: location, + }) + if err != nil { + log.Fatalf("Failed to create model: %v", err) + } + + // Need to create RAG corpus by https://docs.cloud.google.com/vertex-ai/generative-ai/docs/rag-engine/rag-quickstart#run-rag + ragCorpus := os.Getenv("RAG_CORPUS") + if ragCorpus == "" { + log.Fatalf("RAG_CORPUS environment variable is required") + } + + similarityTopKStr := "10" + if v := os.Getenv("SIMILARITY_TOP_K"); v != "" { + similarityTopKStr = v + } + similarityTopKVal, err := strconv.ParseInt(similarityTopKStr, 10, 32) + if err != nil { + log.Fatalf("failed to parse SIMILARITY_TOP_K: %v", err) + } + similarityTopK := int32(similarityTopKVal) + + vectorDistanceThresholdStr := "0.6" + if v := os.Getenv("VECTOR_DISTANCE_THRESHOLD"); v != "" { + vectorDistanceThresholdStr = v + } + vectorDistanceThreshold, err := strconv.ParseFloat(vectorDistanceThresholdStr, 64) + if err != nil { + log.Fatalf("failed to parse VECTOR_DISTANCE_THRESHOLD: %v", err) + } + + ragStore := &genai.VertexRAGStore{ + RAGCorpora: []string{ragCorpus}, + SimilarityTopK: &similarityTopK, + VectorDistanceThreshold: &vectorDistanceThreshold, + } + + askVertexRetrieval, err := retrievaltool.NewVertexAIRAG( + "retrieve_rag_documentation", + "Use this tool to retrieve documentation and reference materials for the question from the RAG corpus", + ragStore, + ) + if err != nil { + log.Fatalf("Failed to create retrievaltool: %v", err) + } + + rootAgent, err := llmagent.New(llmagent.Config{ + Name: "ask_rag_agent", + Model: model, + Description: "Agent that answers questions using RAG-based document retrieval", + Instruction: `You are a helpful assistant that can retrieve documentation and reference materials to answer questions. + +When answering questions: +1. Use the retrieve_rag_documentation tool to search for relevant information from the knowledge base +2. Base your answers on the retrieved documentation +3. If the retrieved information is insufficient, clearly state what additional information might be needed +4. Always cite or reference the source of your information when possible`, + Tools: []tool.Tool{ + askVertexRetrieval, + }, + }) + if err != nil { + log.Fatalf("Failed to create agent: %v", err) + } + + config := &adk.Config{ + AgentLoader: services.NewSingleAgentLoader(rootAgent), + } + + l := full.NewLauncher() + err = l.Execute(ctx, config, os.Args[1:]) + if err != nil { + log.Fatalf("run failed: %v\n\n%s", err, l.CommandLineSyntax()) + } +} diff --git a/tool/retrievaltool/vertex_ai_rag.go b/tool/retrievaltool/vertex_ai_rag.go new file mode 100644 index 00000000..06f81d42 --- /dev/null +++ b/tool/retrievaltool/vertex_ai_rag.go @@ -0,0 +1,74 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package retrievaltool + +import ( + "google.golang.org/adk/model" + "google.golang.org/adk/tool" + "google.golang.org/genai" +) + +// VertexAIRAG is a retrieval tool that uses Vertex AI RAG (Retrieval-Augmented Generation) to retrieve data. +type VertexAIRAG struct { + name string + description string + vertexRAGStore *genai.VertexRAGStore +} + +// NewVertexAIRAG creates a new Vertex AI RAG retrieval tool with the given parameters. +func NewVertexAIRAG(name, description string, vertexRAGStore *genai.VertexRAGStore) (tool.Tool, error) { + return &VertexAIRAG{ + name: name, + description: description, + vertexRAGStore: vertexRAGStore, + }, nil +} + +// Name implements tool.Tool. +func (v *VertexAIRAG) Name() string { + return v.name +} + +// Description implements tool.Tool. +func (v *VertexAIRAG) Description() string { + return v.description +} + +// IsLongRunning implements tool.Tool. +func (v *VertexAIRAG) IsLongRunning() bool { + return false +} + +// ProcessRequest adds the Vertex AI RAG tool to the LLM request. +// Uses the built-in Vertex AI RAG tool for Gemini models. +func (v *VertexAIRAG) ProcessRequest(ctx tool.Context, req *model.LLMRequest) error { + return v.addBuiltInRAGTool(req) +} + +// addBuiltInRAGTool adds the built-in Vertex AI RAG tool to the request config. +func (v *VertexAIRAG) addBuiltInRAGTool(req *model.LLMRequest) error { + if req.Config == nil { + req.Config = &genai.GenerateContentConfig{} + } + + ragTool := &genai.Tool{ + Retrieval: &genai.Retrieval{ + VertexRAGStore: v.vertexRAGStore, + }, + } + + req.Config.Tools = append(req.Config.Tools, ragTool) + return nil +} diff --git a/tool/retrievaltool/vertex_ai_rag_test.go b/tool/retrievaltool/vertex_ai_rag_test.go new file mode 100644 index 00000000..634f40a7 --- /dev/null +++ b/tool/retrievaltool/vertex_ai_rag_test.go @@ -0,0 +1,126 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package retrievaltool_test + +import ( + "context" + "testing" + + icontext "google.golang.org/adk/internal/context" + "google.golang.org/adk/internal/sessioninternal" + "google.golang.org/adk/internal/toolinternal" + "google.golang.org/adk/model" + "google.golang.org/adk/session" + toolpkg "google.golang.org/adk/tool" + "google.golang.org/adk/tool/retrievaltool" + "google.golang.org/genai" +) + +func TestVertexAIRAG_ProcessRequest(t *testing.T) { + similarityTopK := int32(5) + vectorDistanceThreshold := 0.8 + ragStore := &genai.VertexRAGStore{ + RAGCorpora: []string{"projects/123456789/locations/us-central1/ragCorpora/1234567890"}, + SimilarityTopK: &similarityTopK, + VectorDistanceThreshold: &vectorDistanceThreshold, + } + + tool, err := retrievaltool.NewVertexAIRAG("test_rag", "Test RAG tool", ragStore) + if err != nil { + t.Fatalf("NewVertexAIRAG() failed: %v", err) + } + + req := &model.LLMRequest{ + Model: "gemini-2.0-flash", + } + + requestProcessor, ok := tool.(toolinternal.RequestProcessor) + if !ok { + t.Fatal("tool does not implement RequestProcessor") + } + + toolCtx := createToolContext(t) + err = requestProcessor.ProcessRequest(toolCtx, req) + if err != nil { + t.Fatalf("ProcessRequest failed: %v", err) + } + + if req.Config == nil { + t.Fatal("req.Config is nil") + } + + if len(req.Config.Tools) == 0 { + t.Fatal("req.Config.Tools is empty") + } + + // Find the retrieval tool + var foundRetrievalTool bool + for _, genaiTool := range req.Config.Tools { + if genaiTool.Retrieval != nil { + foundRetrievalTool = true + store := genaiTool.Retrieval.VertexRAGStore + if store == nil { + t.Fatal("VertexRAGStore is nil") + } + + if len(store.RAGCorpora) == 0 { + t.Error("RAGCorpora is empty") + } else { + expectedCorpus := "projects/123456789/locations/us-central1/ragCorpora/1234567890" + if store.RAGCorpora[0] != expectedCorpus { + t.Errorf("Expected corpus %s, got %s", expectedCorpus, store.RAGCorpora[0]) + } + } + + if store.SimilarityTopK == nil { + t.Error("SimilarityTopK is nil") + } else if *store.SimilarityTopK != similarityTopK { + t.Errorf("Expected SimilarityTopK %d, got %d", similarityTopK, *store.SimilarityTopK) + } + + if store.VectorDistanceThreshold == nil { + t.Error("VectorDistanceThreshold is nil") + } else if *store.VectorDistanceThreshold != vectorDistanceThreshold { + t.Errorf("Expected VectorDistanceThreshold %f, got %f", vectorDistanceThreshold, *store.VectorDistanceThreshold) + } + } + } + if !foundRetrievalTool { + t.Error("Retrieval tool not found in request config") + } +} + +// createToolContext creates a tool context for testing +func createToolContext(t *testing.T) toolpkg.Context { + t.Helper() + + sessionService := session.InMemoryService() + createResponse, err := sessionService.Create(context.Background(), &session.CreateRequest{ + AppName: "testApp", + UserID: "testUser", + SessionID: "testSession", + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + s := createResponse.Session + sessionImpl := sessioninternal.NewMutableSession(sessionService, s) + + ctx := icontext.NewInvocationContext(context.Background(), icontext.InvocationContextParams{ + Session: sessionImpl, + }) + + return toolinternal.NewToolContext(ctx, "", &session.EventActions{}) +}