Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions examples/tools/retrievaltool/vertex_ai_rag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"log"
"os"
"strconv"

"google.golang.org/adk/agent/llmagent"
"google.golang.org/adk/cmd/launcher/adk"
"google.golang.org/adk/cmd/launcher/full"
"google.golang.org/adk/model/gemini"
"google.golang.org/adk/server/restapi/services"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/retrievaltool"
"google.golang.org/genai"
)

func main() {
ctx := context.Background()

modelName := "gemini-2.0-flash-001"
if v := os.Getenv("MODEL_NAME"); v != "" {
modelName = v
}
location := "us-central1"
if v := os.Getenv("LOCATION"); v != "" {
location = v
}
model, err := gemini.NewModel(ctx, modelName,
&genai.ClientConfig{
Backend: genai.BackendVertexAI,
Project: os.Getenv("PROJECT_ID"),
Location: location,
})
if err != nil {
log.Fatalf("Failed to create model: %v", err)
}

// Need to create RAG corpus by https://docs.cloud.google.com/vertex-ai/generative-ai/docs/rag-engine/rag-quickstart#run-rag
ragCorpus := os.Getenv("RAG_CORPUS")
if ragCorpus == "" {
log.Fatalf("RAG_CORPUS environment variable is required")
}

similarityTopKStr := "10"
if v := os.Getenv("SIMILARITY_TOP_K"); v != "" {
similarityTopKStr = v
}
similarityTopKVal, err := strconv.ParseInt(similarityTopKStr, 10, 32)
if err != nil {
log.Fatalf("failed to parse SIMILARITY_TOP_K: %v", err)
}
similarityTopK := int32(similarityTopKVal)

vectorDistanceThresholdStr := "0.6"
if v := os.Getenv("VECTOR_DISTANCE_THRESHOLD"); v != "" {
vectorDistanceThresholdStr = v
}
vectorDistanceThreshold, err := strconv.ParseFloat(vectorDistanceThresholdStr, 64)
if err != nil {
log.Fatalf("failed to parse VECTOR_DISTANCE_THRESHOLD: %v", err)
}

ragStore := &genai.VertexRAGStore{
RAGCorpora: []string{ragCorpus},
SimilarityTopK: &similarityTopK,
VectorDistanceThreshold: &vectorDistanceThreshold,
}

askVertexRetrieval, err := retrievaltool.NewVertexAIRAG(
"retrieve_rag_documentation",
"Use this tool to retrieve documentation and reference materials for the question from the RAG corpus",
ragStore,
)
if err != nil {
log.Fatalf("Failed to create retrievaltool: %v", err)
}

rootAgent, err := llmagent.New(llmagent.Config{
Name: "ask_rag_agent",
Model: model,
Description: "Agent that answers questions using RAG-based document retrieval",
Instruction: `You are a helpful assistant that can retrieve documentation and reference materials to answer questions.

When answering questions:
1. Use the retrieve_rag_documentation tool to search for relevant information from the knowledge base
2. Base your answers on the retrieved documentation
3. If the retrieved information is insufficient, clearly state what additional information might be needed
4. Always cite or reference the source of your information when possible`,
Tools: []tool.Tool{
askVertexRetrieval,
},
})
if err != nil {
log.Fatalf("Failed to create agent: %v", err)
}

config := &adk.Config{
AgentLoader: services.NewSingleAgentLoader(rootAgent),
}

l := full.NewLauncher()
err = l.Execute(ctx, config, os.Args[1:])
if err != nil {
log.Fatalf("run failed: %v\n\n%s", err, l.CommandLineSyntax())
}
}
74 changes: 74 additions & 0 deletions tool/retrievaltool/vertex_ai_rag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package retrievaltool

import (
"google.golang.org/adk/model"
"google.golang.org/adk/tool"
"google.golang.org/genai"
)

// VertexAIRAG is a retrieval tool that uses Vertex AI RAG (Retrieval-Augmented Generation) to retrieve data.
type VertexAIRAG struct {
name string
description string
vertexRAGStore *genai.VertexRAGStore
}

// NewVertexAIRAG creates a new Vertex AI RAG retrieval tool with the given parameters.
func NewVertexAIRAG(name, description string, vertexRAGStore *genai.VertexRAGStore) (tool.Tool, error) {
return &VertexAIRAG{
name: name,
description: description,
vertexRAGStore: vertexRAGStore,
}, nil
}

// Name implements tool.Tool.
func (v *VertexAIRAG) Name() string {
return v.name
}

// Description implements tool.Tool.
func (v *VertexAIRAG) Description() string {
return v.description
}

// IsLongRunning implements tool.Tool.
func (v *VertexAIRAG) IsLongRunning() bool {
return false
}

// ProcessRequest adds the Vertex AI RAG tool to the LLM request.
// Uses the built-in Vertex AI RAG tool for Gemini models.
func (v *VertexAIRAG) ProcessRequest(ctx tool.Context, req *model.LLMRequest) error {
return v.addBuiltInRAGTool(req)
}

// addBuiltInRAGTool adds the built-in Vertex AI RAG tool to the request config.
func (v *VertexAIRAG) addBuiltInRAGTool(req *model.LLMRequest) error {
if req.Config == nil {
req.Config = &genai.GenerateContentConfig{}
}

ragTool := &genai.Tool{
Retrieval: &genai.Retrieval{
VertexRAGStore: v.vertexRAGStore,
},
}

req.Config.Tools = append(req.Config.Tools, ragTool)
return nil
}
126 changes: 126 additions & 0 deletions tool/retrievaltool/vertex_ai_rag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package retrievaltool_test

import (
"context"
"testing"

icontext "google.golang.org/adk/internal/context"
"google.golang.org/adk/internal/sessioninternal"
"google.golang.org/adk/internal/toolinternal"
"google.golang.org/adk/model"
"google.golang.org/adk/session"
toolpkg "google.golang.org/adk/tool"
"google.golang.org/adk/tool/retrievaltool"
"google.golang.org/genai"
)

func TestVertexAIRAG_ProcessRequest(t *testing.T) {
similarityTopK := int32(5)
vectorDistanceThreshold := 0.8
ragStore := &genai.VertexRAGStore{
RAGCorpora: []string{"projects/123456789/locations/us-central1/ragCorpora/1234567890"},
SimilarityTopK: &similarityTopK,
VectorDistanceThreshold: &vectorDistanceThreshold,
}

tool, err := retrievaltool.NewVertexAIRAG("test_rag", "Test RAG tool", ragStore)
if err != nil {
t.Fatalf("NewVertexAIRAG() failed: %v", err)
}

req := &model.LLMRequest{
Model: "gemini-2.0-flash",
}

requestProcessor, ok := tool.(toolinternal.RequestProcessor)
if !ok {
t.Fatal("tool does not implement RequestProcessor")
}

toolCtx := createToolContext(t)
err = requestProcessor.ProcessRequest(toolCtx, req)
if err != nil {
t.Fatalf("ProcessRequest failed: %v", err)
}

if req.Config == nil {
t.Fatal("req.Config is nil")
}

if len(req.Config.Tools) == 0 {
t.Fatal("req.Config.Tools is empty")
}

// Find the retrieval tool
var foundRetrievalTool bool
for _, genaiTool := range req.Config.Tools {
if genaiTool.Retrieval != nil {
foundRetrievalTool = true
store := genaiTool.Retrieval.VertexRAGStore
if store == nil {
t.Fatal("VertexRAGStore is nil")
}

if len(store.RAGCorpora) == 0 {
t.Error("RAGCorpora is empty")
} else {
expectedCorpus := "projects/123456789/locations/us-central1/ragCorpora/1234567890"
if store.RAGCorpora[0] != expectedCorpus {
t.Errorf("Expected corpus %s, got %s", expectedCorpus, store.RAGCorpora[0])
}
}

if store.SimilarityTopK == nil {
t.Error("SimilarityTopK is nil")
} else if *store.SimilarityTopK != similarityTopK {
t.Errorf("Expected SimilarityTopK %d, got %d", similarityTopK, *store.SimilarityTopK)
}

if store.VectorDistanceThreshold == nil {
t.Error("VectorDistanceThreshold is nil")
} else if *store.VectorDistanceThreshold != vectorDistanceThreshold {
t.Errorf("Expected VectorDistanceThreshold %f, got %f", vectorDistanceThreshold, *store.VectorDistanceThreshold)
}
}
}
if !foundRetrievalTool {
t.Error("Retrieval tool not found in request config")
}
}

// createToolContext creates a tool context for testing
func createToolContext(t *testing.T) toolpkg.Context {
t.Helper()

sessionService := session.InMemoryService()
createResponse, err := sessionService.Create(context.Background(), &session.CreateRequest{
AppName: "testApp",
UserID: "testUser",
SessionID: "testSession",
})
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
s := createResponse.Session
sessionImpl := sessioninternal.NewMutableSession(sessionService, s)

ctx := icontext.NewInvocationContext(context.Background(), icontext.InvocationContextParams{
Session: sessionImpl,
})

return toolinternal.NewToolContext(ctx, "", &session.EventActions{})
}