diff --git a/.github/labeler.yml b/.github/labeler.yml index a6134f4d46..9477d66ee4 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -54,6 +54,13 @@ js: - "**/package.json" - "js/**" +java: + - changed-files: + - any-glob-to-any-file: + - "**/*.java" + - "**/pom.xml" + - "java/**" + tooling: - changed-files: - any-glob-to-any-file: diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml new file mode 100644 index 0000000000..9d396419f4 --- /dev/null +++ b/.github/workflows/java.yml @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +name: Java tests and checks + +on: + pull_request: + paths: + - "java/**" + - "genkit-tools/**" + - ".github/workflows/java.yml" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + java-version: ['17', '21'] + steps: + - name: Checkout Repo + uses: actions/checkout@v4 + + - name: Set up JDK ${{ matrix.java-version }} + uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.java-version }} + distribution: 'temurin' + cache: 'maven' + + - name: Build with Maven + run: mvn -B clean compile -q + working-directory: ./java + + - name: Run tests + run: mvn -B test + working-directory: ./java + + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout Repo + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Check code formatting with Spotless + run: mvn -B spotless:check + working-directory: ./java + + - name: Run Checkstyle + run: mvn -B checkstyle:check + working-directory: ./java diff --git a/.gitignore b/.gitignore index 0bd0d33355..eaec6f2cb0 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,17 @@ next-env.d.ts # Code Coverage js/plugins/compat-oai/coverage/ +# Java +java/**/target/ +java/**/*.class +java/**/*.jar +java/**/*.war +java/**/*.ear +java/**/.classpath +java/**/.project +java/**/.settings/ +java/**/*.iml +java/**/.genkit +*.log +hs_err_pid* +java/samples/google-genai/generated_media/ diff --git a/genkit-tools/common/src/utils/utils.ts b/genkit-tools/common/src/utils/utils.ts index d2a404a834..f11703410d 100644 --- a/genkit-tools/common/src/utils/utils.ts +++ b/genkit-tools/common/src/utils/utils.ts @@ -37,6 +37,8 @@ export async function findProjectRoot(): Promise { const goModPath = path.join(currentDir, 'go.mod'); const pyprojectPath = path.join(currentDir, 'pyproject.toml'); const pyproject2Path = path.join(currentDir, 'requirements.txt'); + const javaPomPath = path.join(currentDir, 'pom.xml'); + const javaBuildGradlePath = path.join(currentDir, 'build.gradle'); try { const [ @@ -44,6 +46,8 @@ export async function findProjectRoot(): Promise { goModExists, pyprojectExists, pyproject2Exists, + javaPomExists, + javaBuildGradleExists, ] = await Promise.all([ fs .access(packageJsonPath) @@ -61,12 +65,22 @@ export async function findProjectRoot(): Promise { .access(pyproject2Path) .then(() => true) .catch(() => false), + fs + .access(javaPomPath) + .then(() => true) + .catch(() => false), + fs + .access(javaBuildGradlePath) + .then(() => true) + .catch(() => false), ]); if ( packageJsonExists || goModExists || pyprojectExists || - pyproject2Exists + pyproject2Exists || + javaPomExists || + javaBuildGradleExists ) { return currentDir; } diff --git a/java/README.md b/java/README.md new file mode 100644 index 0000000000..b54d0350dc --- /dev/null +++ b/java/README.md @@ -0,0 +1,497 @@ +# Genkit for Java + +Genkit for Java is the Java implementation of the Genkit framework for building AI-powered applications. + +See: https://firebase.google.com/docs/genkit + +> **Status**: Currently in active development (1.0.0-SNAPSHOT). Requires Java 17+. +> +> **Note**: The Java SDK supports OpenAI and Google GenAI (Gemini) models. Additional plugins (Vertex AI, Anthropic, Ollama, Firebase, etc.) are planned for future releases. See [Plugin Availability](#plugin-availability) for details. + +## Installation + +Add the following dependencies to your Maven `pom.xml`: + +```xml + + + com.google.genkit + genkit + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit-plugin-openai + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit-plugin-google-genai + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit-plugin-jetty + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit-plugin-localvec + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit-plugin-mcp + 1.0.0-SNAPSHOT + +``` + +## Quick Start + +```java +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; + +public class Main { + public static void main(String[] args) { + // Create Genkit with plugins + Genkit genkit = Genkit.builder() + .options(GenkitOptions.builder() + .devMode(true) + .reflectionPort(3100) + .build()) + .plugin(OpenAIPlugin.create()) + .plugin(new JettyPlugin(JettyPluginOptions.builder() + .port(8080) + .build())) + .build(); + + // Generate text + ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Tell me a fun fact!") + .config(GenerationConfig.builder() + .temperature(0.9) + .maxOutputTokens(200) + .build()) + .build()); + + System.out.println(response.getText()); + } +} +``` + +## Defining Flows + +Flows are observable, traceable AI workflows that can be exposed as HTTP endpoints: + +```java +// Simple flow with typed input/output +Flow greetFlow = genkit.defineFlow( + "greeting", + String.class, + String.class, + name -> "Hello, " + name + "!"); + +// AI-powered flow with context access +Flow jokeFlow = genkit.defineFlow( + "tellJoke", + String.class, + String.class, + (ctx, topic) -> { + ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Tell me a short, funny joke about: " + topic) + .build()); + return response.getText(); + }); + +// Run a flow programmatically +String result = genkit.runFlow("greeting", "World"); +``` + +## Using Tools + +Define tools that models can call during generation: + +```java +@SuppressWarnings("unchecked") +Tool, Map> weatherTool = genkit.defineTool( + "getWeather", + "Gets the current weather for a location", + Map.of( + "type", "object", + "properties", Map.of( + "location", Map.of("type", "string", "description", "The city name") + ), + "required", new String[]{"location"} + ), + (Class>) (Class) Map.class, + (ctx, input) -> { + String location = (String) input.get("location"); + return Map.of( + "location", location, + "temperature", "72°F", + "conditions", "sunny" + ); + }); + +// Use tool in generation - tool execution is handled automatically +ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o") + .prompt("What's the weather in Paris?") + .tools(List.of(weatherTool)) + .build()); +``` + +## DotPrompt Support + +Load and use `.prompt` files with Handlebars templating: + +```java +// Load a prompt from resources/prompts/recipe.prompt +ExecutablePrompt recipePrompt = genkit.prompt("recipe", RecipeInput.class); + +// Execute with typed input +ModelResponse response = recipePrompt.generate(new RecipeInput("pasta carbonara")); + +// Prompts support variants (e.g., recipe.robot.prompt) +ExecutablePrompt robotPrompt = genkit.prompt("recipe", RecipeInput.class, "robot"); +``` + +## RAG (Retrieval Augmented Generation) + +Build RAG applications with retrievers and indexers: + +```java +// Define a retriever +Retriever myRetriever = genkit.defineRetriever("myStore/docs", (ctx, request) -> { + List docs = findSimilarDocs(request.getQuery()); + return new RetrieverResponse(docs); +}); + +// Define an indexer +Indexer myIndexer = genkit.defineIndexer("myStore/docs", (ctx, request) -> { + indexDocuments(request.getDocuments()); + return new IndexerResponse(); +}); + +// Index documents +List docs = List.of( + Document.fromText("Paris is the capital of France."), + Document.fromText("Berlin is the capital of Germany.") +); +genkit.index("myStore/docs", docs); + +// Retrieve and generate +List relevantDocs = genkit.retrieve("myStore/docs", "What is the capital of France?"); +ModelResponse response = genkit.generate(GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Answer based on context: What is the capital of France?") + .docs(relevantDocs) + .build()); +``` + +## Evaluations + +Define custom evaluators to assess AI output quality: + +```java +genkit.defineEvaluator("accuracyCheck", "Accuracy Check", "Checks factual accuracy", + (dataPoint, options) -> { + double score = calculateAccuracyScore(dataPoint.getOutput()); + return EvalResponse.builder() + .testCaseId(dataPoint.getTestCaseId()) + .evaluation(Score.builder().score(score).build()) + .build(); + }); + +// Run evaluation +EvalRunKey result = genkit.evaluate(RunEvaluationRequest.builder() + .datasetId("my-dataset") + .evaluators(List.of("accuracyCheck")) + .actionRef("/flow/myFlow") + .build()); +``` + +## Streaming + +Generate responses with streaming for real-time output: + +```java +StringBuilder result = new StringBuilder(); +ModelResponse response = genkit.generateStream( + GenerateOptions.builder() + .model("openai/gpt-4o") + .prompt("Tell me a story") + .build(), + chunk -> { + System.out.print(chunk.getText()); + result.append(chunk.getText()); + }); +``` + +## Embeddings + +Generate vector embeddings for semantic search: + +```java +List documents = List.of( + Document.fromText("Hello world"), + Document.fromText("Goodbye world") +); +EmbedResponse response = genkit.embed("openai/text-embedding-3-small", documents); +``` + +## Modules + +| Module | Description | +|--------|-------------| +| **genkit-core** | Core framework: actions, flows, registry, tracing (OpenTelemetry) | +| **genkit-ai** | AI abstractions: models, embedders, tools, prompts, retrievers, indexers, evaluators | +| **genkit** | Main entry point combining core and AI with reflection server | +| **plugins/openai** | OpenAI models (GPT-4o, GPT-4o-mini, etc.) and embeddings | +| **plugins/google-genai** | Google Gemini models and Imagen image generation | +| **plugins/jetty** | HTTP server plugin using Jetty 12 | +| **plugins/localvec** | Local file-based vector store for development | +| **plugins/mcp** | Model Context Protocol (MCP) client integration | + + +## Observability + +Genkit Java SDK provides comprehensive observability features through OpenTelemetry integration: + +### Tracing + +All actions (models, tools, flows) are automatically traced with rich metadata: + +- **Span types**: `action`, `flow`, `flowStep`, `util` +- **Subtypes**: `model`, `tool`, `flow`, `embedder`, etc. +- **Session tracking**: `sessionId` and `threadName` for multi-turn conversations +- **Input/Output capture**: Full request/response data in span attributes + +Example span attributes: +``` +genkit:name = "openai/gpt-4o-mini" +genkit:type = "action" +genkit:metadata:subtype = "model" +genkit:path = "/{myFlow,t:flow}/{openai/gpt-4o-mini,t:action,s:model}" +genkit:input = {...} +genkit:output = {...} +genkit:sessionId = "user-123" +``` + +### Metrics + +The SDK exposes OpenTelemetry metrics for monitoring: + +| Metric | Description | +|--------|-------------| +| `genkit/ai/generate/requests` | Model generation request count | +| `genkit/ai/generate/latency` | Model generation latency (ms) | +| `genkit/ai/generate/input/tokens` | Input token count | +| `genkit/ai/generate/output/tokens` | Output token count | +| `genkit/ai/generate/input/characters` | Input character count | +| `genkit/ai/generate/output/characters` | Output character count | +| `genkit/ai/generate/input/images` | Input image count | +| `genkit/ai/generate/output/images` | Output image count | +| `genkit/ai/generate/thinking/tokens` | Thinking/reasoning token count | +| `genkit/tool/requests` | Tool execution request count | +| `genkit/tool/latency` | Tool execution latency (ms) | +| `genkit/feature/requests` | Feature (flow) request count | +| `genkit/feature/latency` | Feature (flow) latency (ms) | +| `genkit/action/requests` | General action request count | +| `genkit/action/latency` | General action latency (ms) | + +### Usage Tracking + +Model responses include detailed usage statistics: + +```java +ModelResponse response = genkit.generate(options); +Usage usage = response.getUsage(); + +System.out.println("Input tokens: " + usage.getInputTokens()); +System.out.println("Output tokens: " + usage.getOutputTokens()); +System.out.println("Latency: " + response.getLatencyMs() + "ms"); +``` + +### Session Context + +Track multi-turn conversations with session and thread context: + +```java +ActionContext ctx = ActionContext.builder() + .registry(genkit.getRegistry()) + .sessionId("user-123") + .threadName("support-chat") + .build(); +``` + +## Samples + +The following samples are available in `java/samples/`. See the [samples README](./samples/README.md) for detailed instructions on running each sample. + +| Sample | Description | +|--------|-------------| +| **openai** | Basic OpenAI integration with flows and tools | +| **google-genai** | Google Gemini integration with image generation | +| **dotprompt** | DotPrompt files with complex inputs/outputs, variants, and partials | +| **rag** | RAG application with local vector store | +| **chat-session** | Multi-turn chat with session persistence | +| **evaluations** | Custom evaluators and evaluation workflows | +| **complex-io** | Complex nested types, arrays, maps in flow inputs/outputs | +| **middleware** | Middleware patterns for logging, caching, rate limiting | +| **multi-agent** | Multi-agent orchestration patterns | +| **interrupts** | Flow interrupts and human-in-the-loop patterns | +| **mcp** | Model Context Protocol (MCP) integration | + +### Running Samples + +```bash +# Set your API key +export OPENAI_API_KEY=your-api-key +# Or: export GOOGLE_GENAI_API_KEY=your-api-key + +# Navigate to a sample and run +cd java/samples/openai +./run.sh + +# Or with Genkit Dev UI +genkit start -- ./run.sh +``` + +## Development + +### Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key or Google GenAI API key (for samples) +- Genkit CLI (optional, for Dev UI) + +### Installing Genkit CLI + +```bash +npm install -g genkit +``` + +### Building + +```bash +cd java +mvn clean install +``` + +### Running Tests + +```bash +mvn test +``` + +### Running Samples + +See the [samples README](./samples/README.md) for detailed instructions. + +```bash +# Set your API key +export OPENAI_API_KEY=your-api-key +# Or: export GOOGLE_GENAI_API_KEY=your-api-key + +# Run a sample +cd java/samples/openai +./run.sh +# Or: mvn compile exec:java + +# Run with Genkit Dev UI (recommended) +genkit start -- ./run.sh +``` + +## CLI Integration + +The Java implementation works with the Genkit CLI. Start your application with: + +```bash +genkit start -- ./run.sh +# Or: genkit start -- mvn exec:java +``` + +The reflection server starts automatically in dev mode (`devMode(true)`). + +## Dev UI + +When running in dev mode, Genkit starts a reflection server on port 3100 (configurable via `reflectionPort()`). +The Dev UI connects to this server to: + +- List all registered actions (flows, models, tools, prompts, retrievers, evaluators) +- Run actions with test inputs +- View traces and execution logs +- Manage datasets and run evaluations + +## Architecture + +``` +com.google.genkit +├── core/ # Core framework +│ ├── Action # Base action interface +│ ├── ActionDef # Action implementation +│ ├── ActionContext # Execution context with registry access +│ ├── Flow # Flow definition +│ ├── Registry # Action registry +│ ├── Plugin # Plugin interface +│ └── tracing/ # OpenTelemetry integration +│ ├── Tracer # Span management +│ └── TelemetryClient # Telemetry export +├── ai/ # AI features +│ ├── Model # Model interface +│ ├── ModelRequest/Response# Model I/O types +│ ├── Tool # Tool definition +│ ├── Embedder # Embedder interface +│ ├── Retriever # Retriever interface +│ ├── Indexer # Indexer interface +│ ├── Prompt # Prompt templates +│ ├── telemetry/ # AI-specific metrics +│ │ ├── GenerateTelemetry# Model generation metrics +│ │ ├── ToolTelemetry # Tool execution metrics +│ │ ├── ActionTelemetry # Action execution metrics +│ │ ├── FeatureTelemetry # Flow/feature metrics +│ │ └── ModelTelemetryHelper # Telemetry helper +│ └── evaluation/ # Evaluation framework +│ ├── Evaluator # Evaluator definition +│ ├── EvaluationManager# Run evaluations +│ └── DatasetStore # Dataset management +├── genkit/ # Main module +│ ├── Genkit # Main entry point & builder +│ ├── GenkitOptions # Configuration options +│ ├── ReflectionServer # Dev UI integration +│ └── prompt/ # DotPrompt support +│ ├── DotPrompt # Prompt file parser +│ └── ExecutablePrompt # Prompt execution +└── plugins/ # Plugin implementations + ├── openai/ # OpenAI models & embeddings + ├── jetty/ # Jetty HTTP server + └── localvec/ # Local vector store +``` + +## License + +Apache License 2.0 diff --git a/java/ai/pom.xml b/java/ai/pom.xml new file mode 100644 index 0000000000..7f61f7c95b --- /dev/null +++ b/java/ai/pom.xml @@ -0,0 +1,71 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + + + genkit-ai + jar + + Genkit AI + AI features for Genkit including models, embedders, tools, retrievers, and prompts + + + + + com.google.genkit + genkit-core + + + + + com.github.jknack + handlebars + + + + + org.junit.jupiter + junit-jupiter + test + + + org.mockito + mockito-core + test + + + org.mockito + mockito-junit-jupiter + test + + + ch.qos.logback + logback-classic + test + + + diff --git a/java/ai/src/main/java/com/google/genkit/ai/Agent.java b/java/ai/src/main/java/com/google/genkit/ai/Agent.java new file mode 100644 index 0000000000..0139562006 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Agent.java @@ -0,0 +1,235 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Represents an agent that can be used as a tool in multi-agent systems. + * + *

+ * An Agent wraps an AgentConfig and provides a Tool interface for delegation. + * When the model calls an agent as a tool, the agent's configuration (system + * prompt, model, tools) is applied to the conversation, effectively + * "transferring" control to the specialized agent. + * + *

+ * Example usage: + * + *

{@code
+ * // Define a specialized agent
+ * Agent reservationAgent = genkit
+ * 		.defineAgent(AgentConfig.builder().name("reservationAgent").description("Handles restaurant reservations")
+ * 				.system("You are a reservation specialist...").tools(List.of(reservationTool)).build());
+ *
+ * // Use in a parent agent
+ * Agent triageAgent = genkit.defineAgent(AgentConfig.builder().name("triageAgent").description("Routes requests")
+ * 		.system("Route customer requests to specialists").agents(List.of(reservationAgent.getConfig())).build());
+ *
+ * // Start chat with triage agent
+ * Chat chat = genkit.chat(triageAgent);
+ * }
+ */ +public class Agent { + + private final AgentConfig config; + private final Tool, AgentTransferResult> asTool; + + /** + * Creates a new Agent. + * + * @param config + * the agent configuration + */ + @SuppressWarnings("unchecked") + public Agent(AgentConfig config) { + this.config = config; + this.asTool = createAgentTool(); + } + + /** + * Gets the agent configuration. + * + * @return the config + */ + public AgentConfig getConfig() { + return config; + } + + /** + * Gets the agent name. + * + * @return the name + */ + public String getName() { + return config.getName(); + } + + /** + * Gets the agent description. + * + * @return the description + */ + public String getDescription() { + return config.getDescription(); + } + + /** + * Gets the system prompt. + * + * @return the system prompt + */ + public String getSystem() { + return config.getSystem(); + } + + /** + * Gets the model name. + * + * @return the model name + */ + public String getModel() { + return config.getModel(); + } + + /** + * Gets the tools available to this agent. + * + * @return the tools + */ + public List> getTools() { + return config.getTools(); + } + + /** + * Gets the sub-agents. + * + * @return the sub-agents + */ + public List getAgents() { + return config.getAgents(); + } + + /** + * Gets all tools including sub-agent tools for handoff pattern. + * + *

+ * This method collects all tools that should be available to the agent, + * including the agent's own tools and sub-agents as tools (for handoff). When a + * sub-agent tool is called, the Chat will handle the handoff by switching + * context to that agent. + * + * @param agentRegistry + * map of agent name to Agent instance + * @return combined list of all tools from this agent and sub-agents as tools + */ + public List> getAllTools(Map agentRegistry) { + List> allTools = new ArrayList<>(); + + // Add this agent's direct tools + if (config.getTools() != null) { + allTools.addAll(config.getTools()); + } + + // Add sub-agents as tools (for handoff pattern) + if (config.getAgents() != null) { + for (AgentConfig agentConfig : config.getAgents()) { + Agent agent = agentRegistry.get(agentConfig.getName()); + if (agent != null) { + // Add the sub-agent as a tool - when called, Chat will handle the handoff + allTools.add(agent.asTool()); + } + } + } + + return allTools; + } + + /** + * Returns this agent as a tool that can be used by other agents. + * + * @return the agent as a tool + */ + public Tool, AgentTransferResult> asTool() { + return asTool; + } + + /** + * Gets the tool definition for this agent. + * + * @return the tool definition + */ + public ToolDefinition getToolDefinition() { + return asTool.getDefinition(); + } + + /** Creates the agent-as-tool wrapper. */ + @SuppressWarnings("unchecked") + private Tool, AgentTransferResult> createAgentTool() { + // OpenAI requires "properties" field even if empty + Map inputSchema = new HashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", new HashMap()); + inputSchema.put("additionalProperties", true); + + Map outputSchema = new HashMap<>(); + outputSchema.put("type", "object"); + outputSchema.put("properties", + Map.of("transferredTo", Map.of("type", "string"), "transferred", Map.of("type", "boolean"))); + + return new Tool<>(config.getName(), + config.getDescription() != null ? config.getDescription() : "Transfer to " + config.getName(), + inputSchema, outputSchema, (Class>) (Class) Map.class, (ctx, input) -> { + // Throw handoff exception to signal the chat to switch to this agent + throw new AgentHandoffException(config.getName(), config, input); + }); + } + + /** Result of an agent transfer. */ + public static class AgentTransferResult { + private final String transferredTo; + private final boolean transferred; + + public AgentTransferResult(String agentName) { + this.transferredTo = agentName; + this.transferred = true; + } + + public String getTransferredTo() { + return transferredTo; + } + + public boolean isTransferred() { + return transferred; + } + + @Override + public String toString() { + return "transferred to " + transferredTo; + } + } + + @Override + public String toString() { + return "Agent{" + "name='" + config.getName() + '\'' + '}'; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/AgentConfig.java b/java/ai/src/main/java/com/google/genkit/ai/AgentConfig.java new file mode 100644 index 0000000000..3a40addabc --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/AgentConfig.java @@ -0,0 +1,351 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.List; + +/** + * Configuration for defining an agent (prompt as tool). + * + *

+ * An agent is a specialized prompt that can be used as a tool, enabling + * multi-agent systems where one agent can delegate tasks to other specialized + * agents. + * + *

+ * Example usage: + * + *

{@code
+ * // Define a specialized agent
+ * AgentConfig reservationAgent = AgentConfig.builder().name("reservationAgent")
+ * 		.description("Handles restaurant reservations")
+ * 		.system("You are a reservation specialist. Help users make and manage reservations.").model("openai/gpt-4o")
+ * 		.tools(List.of(reservationTool, cancelTool)).build();
+ *
+ * // Use as a tool in a triage agent
+ * AgentConfig triageAgent = AgentConfig.builder().name("triageAgent")
+ * 		.description("Routes customer requests to appropriate specialists")
+ * 		.system("You are a customer service triage agent...").agents(List.of(reservationAgent, menuAgent)) // Sub-agents
+ * 																											// as
+ * 																											// tools
+ * 		.build();
+ * }
+ */ +public class AgentConfig { + + private String name; + private String description; + private String system; + private String model; + private List> tools; + private List agents; + private GenerationConfig config; + private OutputConfig output; + + /** Default constructor. */ + public AgentConfig() { + } + + /** + * Gets the agent name. + * + * @return the name + */ + public String getName() { + return name; + } + + /** + * Sets the agent name. + * + * @param name + * the name + */ + public void setName(String name) { + this.name = name; + } + + /** + * Gets the description. + * + * @return the description + */ + public String getDescription() { + return description; + } + + /** + * Sets the description (used when agent is called as a tool). + * + * @param description + * the description + */ + public void setDescription(String description) { + this.description = description; + } + + /** + * Gets the system prompt. + * + * @return the system prompt + */ + public String getSystem() { + return system; + } + + /** + * Sets the system prompt. + * + * @param system + * the system prompt + */ + public void setSystem(String system) { + this.system = system; + } + + /** + * Gets the model name. + * + * @return the model name + */ + public String getModel() { + return model; + } + + /** + * Sets the model name. + * + * @param model + * the model name + */ + public void setModel(String model) { + this.model = model; + } + + /** + * Gets the tools available to this agent. + * + * @return the tools + */ + public List> getTools() { + return tools; + } + + /** + * Sets the tools available to this agent. + * + * @param tools + * the tools + */ + public void setTools(List> tools) { + this.tools = tools; + } + + /** + * Gets the sub-agents (agents that can be delegated to). + * + * @return the sub-agents + */ + public List getAgents() { + return agents; + } + + /** + * Sets the sub-agents. + * + * @param agents + * the sub-agents + */ + public void setAgents(List agents) { + this.agents = agents; + } + + /** + * Gets the generation config. + * + * @return the generation config + */ + public GenerationConfig getConfig() { + return config; + } + + /** + * Sets the generation config. + * + * @param config + * the generation config + */ + public void setConfig(GenerationConfig config) { + this.config = config; + } + + /** + * Gets the output config. + * + * @return the output config + */ + public OutputConfig getOutput() { + return output; + } + + /** + * Sets the output config. + * + * @param output + * the output config + */ + public void setOutput(OutputConfig output) { + this.output = output; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for AgentConfig. */ + public static class Builder { + private String name; + private String description; + private String system; + private String model; + private List> tools; + private List agents; + private GenerationConfig config; + private OutputConfig output; + + /** + * Sets the agent name. + * + * @param name + * the name + * @return this builder + */ + public Builder name(String name) { + this.name = name; + return this; + } + + /** + * Sets the description. + * + * @param description + * the description + * @return this builder + */ + public Builder description(String description) { + this.description = description; + return this; + } + + /** + * Sets the system prompt. + * + * @param system + * the system prompt + * @return this builder + */ + public Builder system(String system) { + this.system = system; + return this; + } + + /** + * Sets the model name. + * + * @param model + * the model name + * @return this builder + */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Sets the tools available to this agent. + * + * @param tools + * the tools + * @return this builder + */ + public Builder tools(List> tools) { + this.tools = tools; + return this; + } + + /** + * Sets the sub-agents. + * + * @param agents + * the sub-agents + * @return this builder + */ + public Builder agents(List agents) { + this.agents = agents; + return this; + } + + /** + * Sets the generation config. + * + * @param config + * the generation config + * @return this builder + */ + public Builder config(GenerationConfig config) { + this.config = config; + return this; + } + + /** + * Sets the output config. + * + * @param output + * the output config + * @return this builder + */ + public Builder output(OutputConfig output) { + this.output = output; + return this; + } + + /** + * Builds the AgentConfig. + * + * @return the built config + */ + public AgentConfig build() { + AgentConfig agentConfig = new AgentConfig(); + agentConfig.setName(name); + agentConfig.setDescription(description); + agentConfig.setSystem(system); + agentConfig.setModel(model); + agentConfig.setTools(tools); + agentConfig.setAgents(agents); + agentConfig.setConfig(config); + agentConfig.setOutput(output); + return agentConfig; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/AgentHandoffException.java b/java/ai/src/main/java/com/google/genkit/ai/AgentHandoffException.java new file mode 100644 index 0000000000..f9185a072b --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/AgentHandoffException.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +/** + * Exception thrown when an agent tool is called to signal a handoff. + * + *

+ * When the model calls an agent-as-tool, this exception is thrown to signal + * that the chat should switch context to the target agent. The Chat class + * catches this exception and updates its system prompt, tools, and model to + * those of the target agent. + * + *

+ * This enables the "handoff" pattern where conversations can be transferred + * between specialized agents. + */ +public class AgentHandoffException extends RuntimeException { + + private final String targetAgentName; + private final AgentConfig targetAgentConfig; + private final Map handoffInput; + + /** + * Creates a new AgentHandoffException. + * + * @param targetAgentName + * the name of the agent to hand off to + * @param targetAgentConfig + * the configuration of the target agent + * @param handoffInput + * the input passed to the agent tool (can be used for context) + */ + public AgentHandoffException(String targetAgentName, AgentConfig targetAgentConfig, + Map handoffInput) { + super("Handoff to agent: " + targetAgentName); + this.targetAgentName = targetAgentName; + this.targetAgentConfig = targetAgentConfig; + this.handoffInput = handoffInput; + } + + /** + * Gets the name of the target agent. + * + * @return the target agent name + */ + public String getTargetAgentName() { + return targetAgentName; + } + + /** + * Gets the configuration of the target agent. + * + * @return the target agent config + */ + public AgentConfig getTargetAgentConfig() { + return targetAgentConfig; + } + + /** + * Gets the input passed to the agent tool. + * + * @return the handoff input + */ + public Map getHandoffInput() { + return handoffInput; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Candidate.java b/java/ai/src/main/java/com/google/genkit/ai/Candidate.java new file mode 100644 index 0000000000..0f8dbb6c10 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Candidate.java @@ -0,0 +1,135 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Candidate represents a single model response candidate. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Candidate { + + @JsonProperty("index") + private int index; + + @JsonProperty("message") + private Message message; + + @JsonProperty("finishReason") + private FinishReason finishReason; + + @JsonProperty("finishMessage") + private String finishMessage; + + @JsonProperty("custom") + private Map custom; + + /** + * Default constructor. + */ + public Candidate() { + } + + /** + * Creates a Candidate with a message. + * + * @param message + * the candidate message + */ + public Candidate(Message message) { + this.message = message; + } + + /** + * Creates a Candidate with message and finish reason. + * + * @param message + * the candidate message + * @param finishReason + * the finish reason + */ + public Candidate(Message message, FinishReason finishReason) { + this.message = message; + this.finishReason = finishReason; + } + + // Getters and setters + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public Message getMessage() { + return message; + } + + public void setMessage(Message message) { + this.message = message; + } + + public FinishReason getFinishReason() { + return finishReason; + } + + public void setFinishReason(FinishReason finishReason) { + this.finishReason = finishReason; + } + + public String getFinishMessage() { + return finishMessage; + } + + public void setFinishMessage(String finishMessage) { + this.finishMessage = finishMessage; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + + /** + * Extracts the text content from this candidate. + * + * @return the concatenated text content + */ + public String text() { + if (message == null || message.getContent() == null) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (Part part : message.getContent()) { + if (part.getText() != null) { + sb.append(part.getText()); + } + } + return sb.toString(); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Document.java b/java/ai/src/main/java/com/google/genkit/ai/Document.java new file mode 100644 index 0000000000..4b27306749 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Document.java @@ -0,0 +1,149 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Document represents a document for use with embedders and retrievers. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Document { + + @JsonProperty("content") + private List content; + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor. + */ + public Document() { + this.content = new ArrayList<>(); + this.metadata = new HashMap<>(); + } + + /** + * Creates a Document with text content. + * + * @param text + * the text content + */ + public Document(String text) { + this(); + this.content.add(Part.text(text)); + } + + /** + * Creates a Document with parts. + * + * @param content + * the content parts + */ + public Document(List content) { + this.content = content != null ? content : new ArrayList<>(); + this.metadata = new HashMap<>(); + } + + /** + * Creates a text Document. + * + * @param text + * the text content + * @return a Document with text content + */ + public static Document fromText(String text) { + return new Document(text); + } + + /** + * Creates a Document with text and metadata. + * + * @param text + * the text content + * @param metadata + * the metadata + * @return a Document with text content and metadata + */ + public static Document fromText(String text, Map metadata) { + Document doc = new Document(text); + doc.metadata = metadata != null ? metadata : new HashMap<>(); + return doc; + } + + /** + * Gets the text content of this Document. + * + * @return the concatenated text content + */ + public String text() { + if (content == null) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (Part part : content) { + if (part.getText() != null) { + sb.append(part.getText()); + } + } + return sb.toString(); + } + + // Getters and setters + + public List getContent() { + return content; + } + + public void setContent(List content) { + this.content = content; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + + /** + * Adds metadata to this Document. + * + * @param key + * the metadata key + * @param value + * the metadata value + * @return this Document for chaining + */ + public Document withMetadata(String key, Object value) { + if (this.metadata == null) { + this.metadata = new HashMap<>(); + } + this.metadata.put(key, value); + return this; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/EmbedRequest.java b/java/ai/src/main/java/com/google/genkit/ai/EmbedRequest.java new file mode 100644 index 0000000000..47d8efe663 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/EmbedRequest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * EmbedRequest contains documents to embed. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EmbedRequest { + + @JsonProperty("input") + private List documents; + + @JsonProperty("options") + private Map options; + + /** + * Default constructor. + */ + public EmbedRequest() { + } + + /** + * Creates an EmbedRequest with documents. + * + * @param documents + * the documents to embed + */ + public EmbedRequest(List documents) { + this.documents = documents; + } + + // Getters and setters + + public List getDocuments() { + return documents; + } + + public void setDocuments(List documents) { + this.documents = documents; + } + + public Map getOptions() { + return options; + } + + public void setOptions(Map options) { + this.options = options; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/EmbedResponse.java b/java/ai/src/main/java/com/google/genkit/ai/EmbedResponse.java new file mode 100644 index 0000000000..9ec68857b2 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/EmbedResponse.java @@ -0,0 +1,94 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * EmbedResponse contains the embeddings generated from documents. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EmbedResponse { + + @JsonProperty("embeddings") + private List embeddings; + + /** + * Default constructor. + */ + public EmbedResponse() { + } + + /** + * Creates an EmbedResponse with embeddings. + * + * @param embeddings + * the embeddings + */ + public EmbedResponse(List embeddings) { + this.embeddings = embeddings; + } + + // Getters and setters + + public List getEmbeddings() { + return embeddings; + } + + public void setEmbeddings(List embeddings) { + this.embeddings = embeddings; + } + + /** + * Embedding represents a single embedding vector. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class Embedding { + + @JsonProperty("values") + private float[] values; + + /** + * Default constructor. + */ + public Embedding() { + } + + /** + * Creates an Embedding with the given values. + * + * @param values + * the embedding values + */ + public Embedding(float[] values) { + this.values = values; + } + + public float[] getValues() { + return values; + } + + public void setValues(float[] values) { + this.values = values; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Embedder.java b/java/ai/src/main/java/com/google/genkit/ai/Embedder.java new file mode 100644 index 0000000000..4182c614e3 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Embedder.java @@ -0,0 +1,261 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; + +/** + * Embedder is an action that generates embeddings from documents. + * + * Embedders convert text or other content into numerical vectors that can be + * used for similarity search and retrieval. + */ +public class Embedder implements Action { + + private final String name; + private final EmbedderInfo info; + private final BiFunction handler; + private final Map metadata; + + /** + * Creates a new Embedder. + * + * @param name + * the embedder name + * @param info + * the embedder info + * @param handler + * the embedding function + */ + public Embedder(String name, EmbedderInfo info, BiFunction handler) { + this.name = name; + this.info = info; + this.handler = handler; + this.metadata = new HashMap<>(); + if (info != null) { + this.metadata.put("info", info); + } + } + + /** + * Creates a builder for Embedder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public String getName() { + return name; + } + + @Override + public ActionType getType() { + return ActionType.EMBEDDER; + } + + @Override + public ActionDesc getDesc() { + return ActionDesc.builder().type(ActionType.EMBEDDER).name(name).inputSchema(getInputSchema()) + .outputSchema(getOutputSchema()).metadata(metadata).build(); + } + + @Override + public EmbedResponse run(ActionContext ctx, EmbedRequest input) throws GenkitException { + try { + return handler.apply(ctx, input); + } catch (Exception e) { + throw new GenkitException("Embedder execution failed: " + e.getMessage(), e); + } + } + + @Override + public EmbedResponse run(ActionContext ctx, EmbedRequest input, Consumer streamCallback) + throws GenkitException { + return run(ctx, input); + } + + @Override + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + EmbedRequest request = JsonUtils.fromJsonNode(input, EmbedRequest.class); + EmbedResponse response = run(ctx, request); + return JsonUtils.toJsonNode(response); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + public Map getInputSchema() { + // Define the input schema for embedders + // This follows the EmbedRequestSchema from genkit-tools + Map schema = new HashMap<>(); + schema.put("type", "object"); + + Map properties = new HashMap<>(); + + // input: array of documents (matching Document structure with content array) + Map inputProp = new HashMap<>(); + inputProp.put("type", "array"); + inputProp.put("description", "Array of documents to embed"); + + // Document schema + Map docItemSchema = new HashMap<>(); + docItemSchema.put("type", "object"); + Map docProps = new HashMap<>(); + + // content array in each document (array of Parts) + Map contentProp = new HashMap<>(); + contentProp.put("type", "array"); + Map partSchema = new HashMap<>(); + partSchema.put("type", "object"); + Map partProps = new HashMap<>(); + Map textProp = new HashMap<>(); + textProp.put("type", "string"); + textProp.put("description", "Text content to embed"); + partProps.put("text", textProp); + partSchema.put("properties", partProps); + contentProp.put("items", partSchema); + docProps.put("content", contentProp); + + // metadata in document + Map metaProp = new HashMap<>(); + metaProp.put("type", "object"); + metaProp.put("additionalProperties", true); + docProps.put("metadata", metaProp); + + docItemSchema.put("properties", docProps); + docItemSchema.put("required", java.util.List.of("content")); + inputProp.put("items", docItemSchema); + properties.put("input", inputProp); + + // options: optional configuration + Map optionsProp = new HashMap<>(); + optionsProp.put("type", "object"); + optionsProp.put("description", "Optional embedding configuration"); + properties.put("options", optionsProp); + + schema.put("properties", properties); + schema.put("required", java.util.List.of("input")); + + return schema; + } + + @Override + public Map getOutputSchema() { + // Define the output schema for embedders + Map schema = new HashMap<>(); + schema.put("type", "object"); + + Map properties = new HashMap<>(); + + // embeddings: array of embedding objects + Map embeddingsProp = new HashMap<>(); + embeddingsProp.put("type", "array"); + Map embeddingSchema = new HashMap<>(); + embeddingSchema.put("type", "object"); + Map embeddingProps = new HashMap<>(); + Map embeddingArrayProp = new HashMap<>(); + embeddingArrayProp.put("type", "array"); + Map numberItem = new HashMap<>(); + numberItem.put("type", "number"); + embeddingArrayProp.put("items", numberItem); + embeddingProps.put("values", embeddingArrayProp); + embeddingSchema.put("properties", embeddingProps); + embeddingsProp.put("items", embeddingSchema); + properties.put("embeddings", embeddingsProp); + + schema.put("properties", properties); + + return schema; + } + + @Override + public Map getMetadata() { + return metadata; + } + + @Override + public void register(Registry registry) { + registry.registerAction(ActionType.EMBEDDER.keyFromName(name), this); + } + + /** + * Gets the embedder info. + * + * @return the embedder info + */ + public EmbedderInfo getInfo() { + return info; + } + + /** + * Builder for Embedder. + */ + public static class Builder { + private String name; + private EmbedderInfo info; + private BiFunction handler; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder info(EmbedderInfo info) { + this.info = info; + return this; + } + + public Builder handler(BiFunction handler) { + this.handler = handler; + return this; + } + + public Embedder build() { + if (name == null) { + throw new IllegalStateException("Embedder name is required"); + } + if (handler == null) { + throw new IllegalStateException("Embedder handler is required"); + } + return new Embedder(name, info, handler); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/EmbedderInfo.java b/java/ai/src/main/java/com/google/genkit/ai/EmbedderInfo.java new file mode 100644 index 0000000000..ea169d02bf --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/EmbedderInfo.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * EmbedderInfo contains metadata about an embedder's capabilities. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EmbedderInfo { + + @JsonProperty("label") + private String label; + + @JsonProperty("dimensions") + private Integer dimensions; + + @JsonProperty("supports") + private EmbedderCapabilities supports; + + /** + * Default constructor. + */ + public EmbedderInfo() { + } + + // Getters and setters + + public String getLabel() { + return label; + } + + public void setLabel(String label) { + this.label = label; + } + + public Integer getDimensions() { + return dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public EmbedderCapabilities getSupports() { + return supports; + } + + public void setSupports(EmbedderCapabilities supports) { + this.supports = supports; + } + + /** + * EmbedderCapabilities describes what an embedder can do. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class EmbedderCapabilities { + + @JsonProperty("input") + private List input; + + public List getInput() { + return input; + } + + public void setInput(List input) { + this.input = input; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/FinishReason.java b/java/ai/src/main/java/com/google/genkit/ai/FinishReason.java new file mode 100644 index 0000000000..699f62faba --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/FinishReason.java @@ -0,0 +1,45 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * FinishReason indicates why the model stopped generating. + */ +public enum FinishReason { + + @JsonProperty("stop") + STOP, + + @JsonProperty("length") + LENGTH, + + @JsonProperty("blocked") + BLOCKED, + + @JsonProperty("interrupted") + INTERRUPTED, + + @JsonProperty("other") + OTHER, + + @JsonProperty("unknown") + UNKNOWN +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/GenerateAction.java b/java/ai/src/main/java/com/google/genkit/ai/GenerateAction.java new file mode 100644 index 0000000000..07b1479856 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/GenerateAction.java @@ -0,0 +1,390 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.core.*; +import com.google.genkit.core.tracing.SpanMetadata; +import com.google.genkit.core.tracing.Tracer; + +/** + * GenerateAction is a utility action that provides a unified interface for + * generating content from AI models. It's registered at /util/generate and is + * used by the Dev UI. + */ +public class GenerateAction implements Action { + + private static final Logger logger = LoggerFactory.getLogger(GenerateAction.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private final Registry registry; + + public GenerateAction(Registry registry) { + this.registry = registry; + } + + /** + * Defines and registers the generate utility action. + * + * @param registry + * the registry to register with + * @return the generate action + */ + public static GenerateAction define(Registry registry) { + GenerateAction action = new GenerateAction(registry); + registry.registerAction("/util/generate", action); + logger.debug("Registered utility action: /util/generate"); + return action; + } + + @Override + public String getName() { + return "generate"; + } + + @Override + public ActionType getType() { + return ActionType.UTIL; + } + + @Override + public ActionDesc getDesc() { + return ActionDesc.builder().name("generate").type(ActionType.UTIL) + .description("Utility action for generating content from AI models").build(); + } + + @Override + public ModelResponse run(ActionContext ctx, GenerateActionOptions options) throws GenkitException { + return run(ctx, options, null); + } + + @Override + public ModelResponse run(ActionContext ctx, GenerateActionOptions options, + Consumer streamCallback) throws GenkitException { + if (options == null) { + throw new GenkitException("GenerateActionOptions cannot be null"); + } + + String modelName = options.getModel(); + if (modelName == null || modelName.isEmpty()) { + throw new GenkitException("Model name is required"); + } + + // Resolve the model action key + String modelKey = resolveModelKey(modelName); + + // Look up the model in the registry + Action action = registry.lookupAction(modelKey); + if (action == null) { + throw new GenkitException("Model not found: " + modelName + " (key: " + modelKey + ")"); + } + + if (!(action instanceof Model)) { + throw new GenkitException("Action is not a model: " + modelKey); + } + + Model model = (Model) action; + + // Build the model request from the options + ModelRequest request = buildModelRequest(options); + + logger.debug("Generating with model: {}", modelKey); + + // Create span metadata for the model call + SpanMetadata spanMetadata = SpanMetadata.builder().name(modelName).type(ActionType.MODEL.getValue()) + .subtype("model").build(); + + String flowName = ctx.getFlowName(); + if (flowName != null) { + spanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + // Run the model wrapped in a span + return Tracer.runInNewSpan(ctx, spanMetadata, request, (spanCtx, req) -> { + ActionContext newCtx = ctx.withSpanContext(spanCtx); + if (streamCallback != null && model.supportsStreaming()) { + return model.run(newCtx, req, streamCallback); + } else { + return model.run(newCtx, req); + } + }); + } + + @Override + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + try { + GenerateActionOptions options = objectMapper.treeToValue(input, GenerateActionOptions.class); + Consumer chunkCallback = null; + if (streamCallback != null) { + chunkCallback = chunk -> { + try { + streamCallback.accept(objectMapper.valueToTree(chunk)); + } catch (Exception e) { + logger.error("Error streaming chunk", e); + } + }; + } + ModelResponse response = run(ctx, options, chunkCallback); + return objectMapper.valueToTree(response); + } catch (GenkitException e) { + throw e; + } catch (Exception e) { + throw new GenkitException("Failed to process generate action", e); + } + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + String traceId = UUID.randomUUID().toString(); + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, traceId, null); + } + + @Override + public Map getInputSchema() { + Map schema = new HashMap<>(); + schema.put("type", "object"); + Map props = new HashMap<>(); + props.put("model", Map.of("type", "string")); + props.put("messages", Map.of("type", "array")); + props.put("config", Map.of("type", "object")); + props.put("tools", Map.of("type", "array")); + schema.put("properties", props); + schema.put("required", List.of("messages")); + return schema; + } + + @Override + public Map getOutputSchema() { + Map schema = new HashMap<>(); + schema.put("type", "object"); + return schema; + } + + @Override + public Map getMetadata() { + Map metadata = new HashMap<>(); + metadata.put("type", "util"); + return metadata; + } + + @Override + public void register(Registry registry) { + registry.registerAction("/util/generate", this); + } + + /** + * Resolves a model name to a registry key. Handles formats like "openai/gpt-4o" + * -> "/model/openai/gpt-4o" + */ + private String resolveModelKey(String modelName) { + if (modelName.startsWith("/model/")) { + return modelName; + } + return "/model/" + modelName; + } + + /** + * Builds a ModelRequest from GenerateActionOptions. + */ + private ModelRequest buildModelRequest(GenerateActionOptions options) { + ModelRequest.Builder builder = ModelRequest.builder(); + + if (options.getMessages() != null) { + builder.messages(options.getMessages()); + } + + if (options.getConfig() != null) { + // Convert GenerationConfig to Map + Map configMap = objectMapper.convertValue(options.getConfig(), Map.class); + builder.config(configMap); + } + + if (options.getTools() != null && !options.getTools().isEmpty()) { + // Resolve tools from registry + List toolDefs = options.getTools().stream().map(this::resolveToolDefinition) + .filter(t -> t != null).toList(); + builder.tools(toolDefs); + } + + if (options.getOutput() != null) { + builder.output(options.getOutput()); + } + + return builder.build(); + } + + /** + * Resolves a tool name to its definition from the registry. + */ + private ToolDefinition resolveToolDefinition(String toolName) { + String toolKey = toolName.startsWith("/tool/") ? toolName : "/tool/" + toolName; + Action action = registry.lookupAction(toolKey); + if (action == null) { + logger.warn("Tool not found: {}", toolName); + return null; + } + + // Get tool definition from the action's desc + ActionDesc desc = action.getDesc(); + return new ToolDefinition(desc.getName(), desc.getDescription(), desc.getInputSchema(), null); + } + + /** + * Options for the generate utility action. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class GenerateActionOptions { + + @JsonProperty("model") + private String model; + + @JsonProperty("messages") + private List messages; + + @JsonProperty("tools") + private List tools; + + @JsonProperty("resources") + private List resources; + + @JsonProperty("toolChoice") + private String toolChoice; + + @JsonProperty("config") + private GenerationConfig config; + + @JsonProperty("output") + private OutputConfig output; + + @JsonProperty("docs") + private List docs; + + @JsonProperty("returnToolRequests") + private Boolean returnToolRequests; + + @JsonProperty("maxTurns") + private Integer maxTurns; + + @JsonProperty("stepName") + private String stepName; + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public List getMessages() { + return messages; + } + + public void setMessages(List messages) { + this.messages = messages; + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public List getResources() { + return resources; + } + + public void setResources(List resources) { + this.resources = resources; + } + + public String getToolChoice() { + return toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + public GenerationConfig getConfig() { + return config; + } + + public void setConfig(GenerationConfig config) { + this.config = config; + } + + public OutputConfig getOutput() { + return output; + } + + public void setOutput(OutputConfig output) { + this.output = output; + } + + public List getDocs() { + return docs; + } + + public void setDocs(List docs) { + this.docs = docs; + } + + public Boolean getReturnToolRequests() { + return returnToolRequests; + } + + public void setReturnToolRequests(Boolean returnToolRequests) { + this.returnToolRequests = returnToolRequests; + } + + public Integer getMaxTurns() { + return maxTurns; + } + + public void setMaxTurns(Integer maxTurns) { + this.maxTurns = maxTurns; + } + + public String getStepName() { + return stepName; + } + + public void setStepName(String stepName) { + this.stepName = stepName; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java b/java/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java new file mode 100644 index 0000000000..5019d57b6e --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java @@ -0,0 +1,363 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Options for text generation requests. + */ +public class GenerateOptions { + + private final String model; + private final String prompt; + private final List messages; + private final List docs; + private final String system; + private final List> tools; + private final Object toolChoice; + private final OutputConfig output; + private final GenerationConfig config; + private final Map context; + private final Integer maxTurns; + private final ResumeOptions resume; + + /** + * Creates new GenerateOptions. + * + * @param model + * the model name + * @param prompt + * simple text prompt + * @param messages + * conversation messages + * @param docs + * documents for context + * @param system + * system prompt + * @param tools + * available tools + * @param toolChoice + * tool selection strategy + * @param output + * output configuration + * @param config + * generation configuration + * @param context + * additional context + * @param maxTurns + * maximum conversation turns + * @param resume + * options for resuming after an interrupt + */ + public GenerateOptions(String model, String prompt, List messages, List docs, String system, + List> tools, Object toolChoice, OutputConfig output, GenerationConfig config, + Map context, Integer maxTurns, ResumeOptions resume) { + this.model = model; + this.prompt = prompt; + this.messages = messages; + this.docs = docs; + this.system = system; + this.tools = tools; + this.toolChoice = toolChoice; + this.output = output; + this.config = config; + this.context = context; + this.maxTurns = maxTurns; + this.resume = resume; + } + + /** + * Creates a builder for GenerateOptions. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the model name. + * + * @return the model name + */ + public String getModel() { + return model; + } + + /** + * Gets the text prompt. + * + * @return the prompt + */ + public String getPrompt() { + return prompt; + } + + /** + * Gets the conversation messages. + * + * @return the messages + */ + public List getMessages() { + return messages; + } + + /** + * Gets the documents for context. + * + * @return the documents + */ + public List getDocs() { + return docs; + } + + /** + * Gets the system prompt. + * + * @return the system prompt + */ + public String getSystem() { + return system; + } + + /** + * Gets the available tools. + * + * @return the tools + */ + public List> getTools() { + return tools; + } + + /** + * Gets the tool choice strategy. + * + * @return the tool choice + */ + public Object getToolChoice() { + return toolChoice; + } + + /** + * Gets the output configuration. + * + * @return the output config + */ + public OutputConfig getOutput() { + return output; + } + + /** + * Gets the generation configuration. + * + * @return the config + */ + public GenerationConfig getConfig() { + return config; + } + + /** + * Gets the additional context. + * + * @return the context + */ + public Map getContext() { + return context; + } + + /** + * Gets the maximum conversation turns. + * + * @return the max turns + */ + public Integer getMaxTurns() { + return maxTurns; + } + + /** + * Gets the resume options for continuing after an interrupt. + * + * @return the resume options, or null if not resuming + */ + public ResumeOptions getResume() { + return resume; + } + + /** + * Converts these options to a ModelRequest. + * + * @return a ModelRequest + */ + public ModelRequest toModelRequest() { + ModelRequest.Builder builder = ModelRequest.builder(); + + if (messages != null && !messages.isEmpty()) { + builder.messages(messages); + } else if (prompt != null) { + builder.addUserMessage(prompt); + } + + if (system != null) { + builder.addSystemMessage(system); + } + + if (tools != null && !tools.isEmpty()) { + List toolDefs = tools.stream().map(Tool::getDefinition).collect(Collectors.toList()); + builder.tools(toolDefs); + } + + if (config != null) { + // Convert GenerationConfig to a Map for the ModelRequest + Map configMap = new HashMap<>(); + if (config.getTemperature() != null) { + configMap.put("temperature", config.getTemperature()); + } + if (config.getMaxOutputTokens() != null) { + configMap.put("maxOutputTokens", config.getMaxOutputTokens()); + } + if (config.getTopP() != null) { + configMap.put("topP", config.getTopP()); + } + if (config.getTopK() != null) { + configMap.put("topK", config.getTopK()); + } + if (config.getStopSequences() != null) { + configMap.put("stopSequences", config.getStopSequences()); + } + if (config.getPresencePenalty() != null) { + configMap.put("presencePenalty", config.getPresencePenalty()); + } + if (config.getFrequencyPenalty() != null) { + configMap.put("frequencyPenalty", config.getFrequencyPenalty()); + } + if (config.getSeed() != null) { + configMap.put("seed", config.getSeed()); + } + // Include custom config for model-specific options (e.g., image generation) + if (config.getCustom() != null) { + configMap.putAll(config.getCustom()); + } + builder.config(configMap); + } + + if (output != null) { + builder.output(output); + } + + if (docs != null && !docs.isEmpty()) { + builder.context(docs); + } + + return builder.build(); + } + + /** + * Builder for GenerateOptions. + */ + public static class Builder { + private String model; + private String prompt; + private List messages; + private List docs; + private String system; + private List> tools; + private Object toolChoice; + private OutputConfig output; + private GenerationConfig config; + private Map context; + private Integer maxTurns; + private ResumeOptions resume; + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder prompt(String prompt) { + this.prompt = prompt; + return this; + } + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder docs(List docs) { + this.docs = docs; + return this; + } + + public Builder system(String system) { + this.system = system; + return this; + } + + public Builder tools(List> tools) { + this.tools = tools; + return this; + } + + public Builder toolChoice(Object toolChoice) { + this.toolChoice = toolChoice; + return this; + } + + public Builder output(OutputConfig output) { + this.output = output; + return this; + } + + public Builder config(GenerationConfig config) { + this.config = config; + return this; + } + + public Builder context(Map context) { + this.context = context; + return this; + } + + public Builder maxTurns(Integer maxTurns) { + this.maxTurns = maxTurns; + return this; + } + + /** + * Sets the resume options for continuing after an interrupt. + * + * @param resume + * the resume options + * @return this builder + */ + public Builder resume(ResumeOptions resume) { + this.resume = resume; + return this; + } + + public GenerateOptions build() { + return new GenerateOptions(model, prompt, messages, docs, system, tools, toolChoice, output, config, + context, maxTurns, resume); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/GenerationConfig.java b/java/ai/src/main/java/com/google/genkit/ai/GenerationConfig.java new file mode 100644 index 0000000000..6b514ebeea --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/GenerationConfig.java @@ -0,0 +1,201 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * GenerationConfig contains configuration for model generation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class GenerationConfig { + + @JsonProperty("temperature") + private Double temperature; + + @JsonProperty("maxOutputTokens") + private Integer maxOutputTokens; + + @JsonProperty("topK") + private Integer topK; + + @JsonProperty("topP") + private Double topP; + + @JsonProperty("stopSequences") + private String[] stopSequences; + + @JsonProperty("presencePenalty") + private Double presencePenalty; + + @JsonProperty("frequencyPenalty") + private Double frequencyPenalty; + + @JsonProperty("seed") + private Integer seed; + + @JsonProperty("custom") + private Map custom; + + /** + * Default constructor. + */ + public GenerationConfig() { + } + + /** + * Builder pattern for GenerationConfig. + */ + public static Builder builder() { + return new Builder(); + } + + // Getters and setters + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Integer getMaxOutputTokens() { + return maxOutputTokens; + } + + public void setMaxOutputTokens(Integer maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public String[] getStopSequences() { + return stopSequences; + } + + public void setStopSequences(String[] stopSequences) { + this.stopSequences = stopSequences; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + + /** + * Builder for GenerationConfig. + */ + public static class Builder { + private final GenerationConfig config = new GenerationConfig(); + + public Builder temperature(Double temperature) { + config.temperature = temperature; + return this; + } + + public Builder maxOutputTokens(Integer maxOutputTokens) { + config.maxOutputTokens = maxOutputTokens; + return this; + } + + public Builder topK(Integer topK) { + config.topK = topK; + return this; + } + + public Builder topP(Double topP) { + config.topP = topP; + return this; + } + + public Builder stopSequences(String... stopSequences) { + config.stopSequences = stopSequences; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + config.presencePenalty = presencePenalty; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + config.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder seed(Integer seed) { + config.seed = seed; + return this; + } + + public Builder custom(Map custom) { + config.custom = custom; + return this; + } + + public GenerationConfig build() { + return config; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Indexer.java b/java/ai/src/main/java/com/google/genkit/ai/Indexer.java new file mode 100644 index 0000000000..73be72e6af --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Indexer.java @@ -0,0 +1,213 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; + +/** + * Indexer is an action that indexes documents into a vector store. + * + * Indexers are used for RAG (Retrieval Augmented Generation) workflows to store + * documents that can later be retrieved. + */ +public class Indexer implements Action { + + private final String name; + private final BiFunction handler; + private final Map metadata; + + /** + * Creates a new Indexer. + * + * @param name + * the indexer name + * @param handler + * the indexing function + */ + public Indexer(String name, BiFunction handler) { + this.name = name; + this.handler = handler; + this.metadata = new HashMap<>(); + this.metadata.put("type", "indexer"); + } + + /** + * Creates a builder for Indexer. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public String getName() { + return name; + } + + @Override + public ActionType getType() { + return ActionType.INDEXER; + } + + @Override + public ActionDesc getDesc() { + return ActionDesc.builder().type(ActionType.INDEXER).name(name).inputSchema(getInputSchema()) + .outputSchema(getOutputSchema()).metadata(getMetadata()).build(); + } + + @Override + public IndexerResponse run(ActionContext ctx, IndexerRequest input) throws GenkitException { + try { + return handler.apply(ctx, input); + } catch (Exception e) { + throw new GenkitException("Indexer execution failed: " + e.getMessage(), e); + } + } + + @Override + public IndexerResponse run(ActionContext ctx, IndexerRequest input, Consumer streamCallback) + throws GenkitException { + return run(ctx, input); + } + + @Override + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + IndexerRequest request = JsonUtils.fromJsonNode(input, IndexerRequest.class); + IndexerResponse response = run(ctx, request); + return JsonUtils.toJsonNode(response); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + public Map getInputSchema() { + // Define the input schema to match genkit-tools IndexerRequest schema + Map schema = new HashMap<>(); + schema.put("type", "object"); + + Map properties = new HashMap<>(); + + // documents array property + Map docsProp = new HashMap<>(); + docsProp.put("type", "array"); + Map docItemSchema = new HashMap<>(); + docItemSchema.put("type", "object"); + Map docProps = new HashMap<>(); + + // content array in each document + Map contentProp = new HashMap<>(); + contentProp.put("type", "array"); + Map partSchema = new HashMap<>(); + partSchema.put("type", "object"); + Map partProps = new HashMap<>(); + Map textProp = new HashMap<>(); + textProp.put("type", "string"); + partProps.put("text", textProp); + partSchema.put("properties", partProps); + contentProp.put("items", partSchema); + docProps.put("content", contentProp); + + // metadata + Map metaProp = new HashMap<>(); + metaProp.put("type", "object"); + metaProp.put("additionalProperties", true); + docProps.put("metadata", metaProp); + + docItemSchema.put("properties", docProps); + docItemSchema.put("required", java.util.List.of("content")); + docsProp.put("items", docItemSchema); + properties.put("documents", docsProp); + + // options property + Map optionsProp = new HashMap<>(); + optionsProp.put("type", "object"); + properties.put("options", optionsProp); + + schema.put("properties", properties); + schema.put("required", java.util.List.of("documents")); + + return schema; + } + + @Override + public Map getOutputSchema() { + // Indexer returns void/empty response + Map schema = new HashMap<>(); + schema.put("type", "object"); + return schema; + } + + @Override + public Map getMetadata() { + return metadata; + } + + @Override + public void register(Registry registry) { + registry.registerAction(ActionType.INDEXER.keyFromName(name), this); + } + + /** + * Builder for Indexer. + */ + public static class Builder { + private String name; + private BiFunction handler; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder handler(BiFunction handler) { + this.handler = handler; + return this; + } + + public Indexer build() { + if (name == null) { + throw new IllegalStateException("Indexer name is required"); + } + if (handler == null) { + throw new IllegalStateException("Indexer handler is required"); + } + return new Indexer(name, handler); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/IndexerRequest.java b/java/ai/src/main/java/com/google/genkit/ai/IndexerRequest.java new file mode 100644 index 0000000000..1b747c9af2 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/IndexerRequest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Request to index documents into a vector store. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class IndexerRequest { + + @JsonProperty("documents") + private List documents; + + @JsonProperty("options") + private Object options; + + /** + * Default constructor. + */ + public IndexerRequest() { + this.documents = new ArrayList<>(); + } + + /** + * Creates a request with documents. + * + * @param documents + * the documents to index + */ + public IndexerRequest(List documents) { + this.documents = documents != null ? documents : new ArrayList<>(); + } + + /** + * Creates a request with documents and options. + * + * @param documents + * the documents to index + * @param options + * the indexing options + */ + public IndexerRequest(List documents, Object options) { + this.documents = documents != null ? documents : new ArrayList<>(); + this.options = options; + } + + /** + * Gets the documents to index. + * + * @return the documents + */ + public List getDocuments() { + return documents; + } + + /** + * Sets the documents to index. + * + * @param documents + * the documents + */ + public void setDocuments(List documents) { + this.documents = documents; + } + + /** + * Gets the indexing options. + * + * @return the options + */ + public Object getOptions() { + return options; + } + + /** + * Sets the indexing options. + * + * @param options + * the options + */ + public void setOptions(Object options) { + this.options = options; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/IndexerResponse.java b/java/ai/src/main/java/com/google/genkit/ai/IndexerResponse.java new file mode 100644 index 0000000000..8d3924794b --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/IndexerResponse.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * Response from an indexer operation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class IndexerResponse { + + /** + * Default constructor. + */ + public IndexerResponse() { + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/InterruptConfig.java b/java/ai/src/main/java/com/google/genkit/ai/InterruptConfig.java new file mode 100644 index 0000000000..5d029921ae --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/InterruptConfig.java @@ -0,0 +1,322 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; +import java.util.function.Function; + +/** + * Configuration for defining an interrupt tool. + * + *

+ * An interrupt is a special type of tool that pauses generation to request user + * input (human-in-the-loop pattern). When the model calls an interrupt tool, + * execution stops and the interrupt information is returned to the caller for + * handling. + * + *

+ * Example usage: + * + *

{@code
+ * InterruptConfig config = InterruptConfig.builder()
+ * 		.name("confirmAction").description("Ask user to confirm the action").inputType(ConfirmInput.class)
+ * 		.outputType(ConfirmOutput.class).requestMetadata(input -> Map.of("action", input.getAction())).build();
+ * }
+ * + * @param + * the input type + * @param + * the output type (response type) + */ +public class InterruptConfig { + + private String name; + private String description; + private Class inputType; + private Class outputType; + private Map inputSchema; + private Map outputSchema; + private Function> requestMetadata; + + /** Default constructor. */ + public InterruptConfig() { + } + + /** + * Gets the interrupt name. + * + * @return the name + */ + public String getName() { + return name; + } + + /** + * Sets the interrupt name. + * + * @param name + * the name + */ + public void setName(String name) { + this.name = name; + } + + /** + * Gets the description. + * + * @return the description + */ + public String getDescription() { + return description; + } + + /** + * Sets the description. + * + * @param description + * the description + */ + public void setDescription(String description) { + this.description = description; + } + + /** + * Gets the input type class. + * + * @return the input type class + */ + public Class getInputType() { + return inputType; + } + + /** + * Sets the input type class. + * + * @param inputType + * the input type class + */ + public void setInputType(Class inputType) { + this.inputType = inputType; + } + + /** + * Gets the output type class. + * + * @return the output type class + */ + public Class getOutputType() { + return outputType; + } + + /** + * Sets the output type class. + * + * @param outputType + * the output type class + */ + public void setOutputType(Class outputType) { + this.outputType = outputType; + } + + /** + * Gets the input schema. + * + * @return the input schema + */ + public Map getInputSchema() { + return inputSchema; + } + + /** + * Sets the input schema. + * + * @param inputSchema + * the input schema + */ + public void setInputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + } + + /** + * Gets the output schema. + * + * @return the output schema + */ + public Map getOutputSchema() { + return outputSchema; + } + + /** + * Sets the output schema. + * + * @param outputSchema + * the output schema + */ + public void setOutputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + } + + /** + * Gets the request metadata function. + * + * @return the request metadata function + */ + public Function> getRequestMetadata() { + return requestMetadata; + } + + /** + * Sets the request metadata function. + * + * @param requestMetadata + * the request metadata function + */ + public void setRequestMetadata(Function> requestMetadata) { + this.requestMetadata = requestMetadata; + } + + /** + * Creates a new builder. + * + * @param + * the input type + * @param + * the output type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** Builder for InterruptConfig. */ + public static class Builder { + private String name; + private String description; + private Class inputType; + private Class outputType; + private Map inputSchema; + private Map outputSchema; + private Function> requestMetadata; + + /** + * Sets the interrupt name. + * + * @param name + * the name + * @return this builder + */ + public Builder name(String name) { + this.name = name; + return this; + } + + /** + * Sets the description. + * + * @param description + * the description + * @return this builder + */ + public Builder description(String description) { + this.description = description; + return this; + } + + /** + * Sets the input type class. + * + * @param inputType + * the input type class + * @return this builder + */ + public Builder inputType(Class inputType) { + this.inputType = inputType; + return this; + } + + /** + * Sets the output type class. + * + * @param outputType + * the output type class + * @return this builder + */ + public Builder outputType(Class outputType) { + this.outputType = outputType; + return this; + } + + /** + * Sets the input schema. + * + * @param inputSchema + * the input schema + * @return this builder + */ + public Builder inputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + /** + * Sets the output schema. + * + * @param outputSchema + * the output schema + * @return this builder + */ + public Builder outputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + return this; + } + + /** + * Sets the request metadata function. + * + *

+ * This function is called with the tool input when the interrupt is triggered, + * and should return metadata that will be included in the interrupt. + * + * @param requestMetadata + * function to generate metadata from input + * @return this builder + */ + public Builder requestMetadata(Function> requestMetadata) { + this.requestMetadata = requestMetadata; + return this; + } + + /** + * Builds the InterruptConfig. + * + * @return the built config + */ + public InterruptConfig build() { + InterruptConfig config = new InterruptConfig<>(); + config.setName(name); + config.setDescription(description); + config.setInputType(inputType); + config.setOutputType(outputType); + config.setInputSchema(inputSchema); + config.setOutputSchema(outputSchema); + config.setRequestMetadata(requestMetadata); + return config; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/InterruptRequest.java b/java/ai/src/main/java/com/google/genkit/ai/InterruptRequest.java new file mode 100644 index 0000000000..255376b9fd --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/InterruptRequest.java @@ -0,0 +1,180 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; + +/** + * Represents an interrupt request from a tool. + * + *

+ * When a tool triggers an interrupt, this class captures the tool request + * information and any associated metadata. The caller can then respond to the + * interrupt or restart the tool. + * + *

+ * Example usage: + * + *

{@code
+ * // Check for interrupts in response
+ * List interrupts = response.getInterrupts();
+ * if (!interrupts.isEmpty()) {
+ * 	InterruptRequest interrupt = interrupts.get(0);
+ * 	// Present to user and get response
+ * 	String userResponse = getUserInput(interrupt.getToolRequest().getInput());
+ *
+ * 	// Resume with user's response
+ * 	ModelResponse resumed = chat.send(message, SendOptions.builder()
+ * 			.resume(ResumeOptions.builder().respond(interrupt.respond(userResponse)).build()).build());
+ * }
+ * }
+ */ +public class InterruptRequest { + + private final ToolRequest toolRequest; + private final Map metadata; + + /** + * Creates a new InterruptRequest. + * + * @param toolRequest + * the original tool request + * @param metadata + * the interrupt metadata + */ + public InterruptRequest(ToolRequest toolRequest, Map metadata) { + this.toolRequest = toolRequest; + this.metadata = metadata != null ? new HashMap<>(metadata) : new HashMap<>(); + // Mark as interrupt + this.metadata.put("interrupt", true); + } + + /** + * Gets the tool request that was interrupted. + * + * @return the tool request + */ + public ToolRequest getToolRequest() { + return toolRequest; + } + + /** + * Gets the interrupt metadata. + * + * @return the metadata + */ + public Map getMetadata() { + return metadata; + } + + /** + * Checks if this is an interrupt. + * + * @return true if this is an interrupt (always true for InterruptRequest) + */ + public boolean isInterrupt() { + return true; + } + + /** + * Creates a tool response to respond to this interrupt. + * + * @param output + * the output data to respond with + * @return a ToolResponse part + */ + public ToolResponse respond(Object output) { + return respond(output, null); + } + + /** + * Creates a tool response to respond to this interrupt with additional + * metadata. + * + * @param output + * the output data to respond with + * @param responseMetadata + * additional metadata for the response + * @return a ToolResponse part + */ + public ToolResponse respond(Object output, Map responseMetadata) { + ToolResponse response = new ToolResponse(); + response.setName(toolRequest.getName()); + response.setRef(toolRequest.getRef()); + response.setOutput(output); + + Map meta = new HashMap<>(); + meta.put("interruptResponse", responseMetadata != null ? responseMetadata : true); + response.setMetadata(meta); + + return response; + } + + /** + * Creates a tool request to restart this interrupt. + * + * @return a ToolRequest to restart execution + */ + public ToolRequest restart() { + return restart(null, null); + } + + /** + * Creates a tool request to restart this interrupt with new metadata. + * + * @param resumedMetadata + * metadata for the resumed execution + * @return a ToolRequest to restart execution + */ + public ToolRequest restart(Map resumedMetadata) { + return restart(resumedMetadata, null); + } + + /** + * Creates a tool request to restart this interrupt with new input. + * + * @param resumedMetadata + * metadata for the resumed execution + * @param replaceInput + * new input to replace the original + * @return a ToolRequest to restart execution + */ + public ToolRequest restart(Map resumedMetadata, Object replaceInput) { + ToolRequest request = new ToolRequest(); + request.setName(toolRequest.getName()); + request.setRef(toolRequest.getRef()); + request.setInput(replaceInput != null ? replaceInput : toolRequest.getInput()); + + Map meta = new HashMap<>(this.metadata); + meta.put("resumed", resumedMetadata != null ? resumedMetadata : true); + if (replaceInput != null) { + meta.put("replacedInput", toolRequest.getInput()); + } + request.setMetadata(meta); + + return request; + } + + @Override + public String toString() { + return "InterruptRequest{" + "toolName=" + (toolRequest != null ? toolRequest.getName() : null) + ", metadata=" + + metadata + '}'; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Media.java b/java/ai/src/main/java/com/google/genkit/ai/Media.java new file mode 100644 index 0000000000..b85f9fecf9 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Media.java @@ -0,0 +1,72 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Media represents media content in a message part. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Media { + + @JsonProperty("contentType") + private String contentType; + + @JsonProperty("url") + private String url; + + /** + * Default constructor. + */ + public Media() { + } + + /** + * Creates a Media with the given content type and URL. + * + * @param contentType + * the MIME type + * @param url + * the media URL or data URI + */ + public Media(String contentType, String url) { + this.contentType = contentType; + this.url = url; + } + + // Getters and setters + + public String getContentType() { + return contentType; + } + + public void setContentType(String contentType) { + this.contentType = contentType; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Message.java b/java/ai/src/main/java/com/google/genkit/ai/Message.java new file mode 100644 index 0000000000..bf01d3c0d1 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Message.java @@ -0,0 +1,214 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Message represents a message in a conversation with a generative AI model. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Message { + + @JsonProperty("role") + private Role role; + + @JsonProperty("content") + private List content = new ArrayList<>(); + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor. + */ + public Message() { + } + + /** + * Creates a message with the given role and content. + * + * @param role + * the message role + * @param content + * the content parts + */ + public Message(Role role, List content) { + this.role = role; + this.content = content != null ? new ArrayList<>(content) : new ArrayList<>(); + } + + /** + * Creates a user message with text content. + * + * @param text + * the text content + * @return a new user message + */ + public static Message user(String text) { + return new Message(Role.USER, Collections.singletonList(Part.text(text))); + } + + /** + * Creates a system message with text content. + * + * @param text + * the text content + * @return a new system message + */ + public static Message system(String text) { + return new Message(Role.SYSTEM, Collections.singletonList(Part.text(text))); + } + + /** + * Creates a model message with text content. + * + * @param text + * the text content + * @return a new model message + */ + public static Message model(String text) { + return new Message(Role.MODEL, Collections.singletonList(Part.text(text))); + } + + /** + * Creates a tool message with content. + * + * @param content + * the content parts + * @return a new tool message + */ + public static Message tool(List content) { + return new Message(Role.TOOL, content); + } + + /** + * Returns the text content from all text parts. + * + * @return the concatenated text + */ + public String getText() { + if (content == null || content.isEmpty()) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (Part part : content) { + if (part.getText() != null) { + sb.append(part.getText()); + } + } + return sb.toString(); + } + + // Getters and setters + + public Role getRole() { + return role; + } + + public void setRole(Role role) { + this.role = role; + } + + public List getContent() { + return content; + } + + public void setContent(List content) { + this.content = content; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + + /** + * Adds a part to the message content. + * + * @param part + * the part to add + * @return this message for chaining + */ + public Message addPart(Part part) { + if (this.content == null) { + this.content = new ArrayList<>(); + } + this.content.add(part); + return this; + } + + /** + * Creates a builder for Message. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for Message. + */ + public static class Builder { + private Role role; + private List content = new ArrayList<>(); + private Map metadata; + + public Builder role(Role role) { + this.role = role; + return this; + } + + public Builder content(List content) { + this.content = new ArrayList<>(content); + return this; + } + + public Builder addPart(Part part) { + this.content.add(part); + return this; + } + + public Builder addText(String text) { + this.content.add(Part.text(text)); + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Message build() { + Message message = new Message(role, content); + message.setMetadata(metadata); + return message; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Model.java b/java/ai/src/main/java/com/google/genkit/ai/Model.java new file mode 100644 index 0000000000..14d7bcb749 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Model.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; + +/** + * Model is the interface for AI model implementations. + * + * Models are registered as actions and can be invoked to generate responses + * from prompts. + */ +public interface Model extends Action { + + /** + * Gets information about the model's capabilities. + * + * @return the model info + */ + ModelInfo getInfo(); + + /** + * Generates a response from the given request. + * + * @param ctx + * the action context + * @param request + * the model request + * @return the model response + * @throws GenkitException + * if generation fails + */ + @Override + ModelResponse run(ActionContext ctx, ModelRequest request) throws GenkitException; + + /** + * Generates a streaming response from the given request. + * + * @param ctx + * the action context + * @param request + * the model request + * @param streamCallback + * callback for streaming chunks + * @return the final model response + * @throws GenkitException + * if generation fails + */ + @Override + default ModelResponse run(ActionContext ctx, ModelRequest request, Consumer streamCallback) + throws GenkitException { + // Default implementation doesn't support streaming + return run(ctx, request); + } + + /** + * Returns whether this model supports streaming. + * + * @return true if streaming is supported + */ + default boolean supportsStreaming() { + return false; + } + + @Override + default ActionType getType() { + return ActionType.MODEL; + } + + @Override + default ActionDesc getDesc() { + return ActionDesc.builder().type(ActionType.MODEL).name(getName()).metadata(getMetadata()).build(); + } + + @Override + default JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + ModelRequest request = JsonUtils.fromJsonNode(input, ModelRequest.class); + Consumer typedCallback = null; + if (streamCallback != null) { + typedCallback = chunk -> streamCallback.accept(JsonUtils.toJsonNode(chunk)); + } + ModelResponse response = run(ctx, request, typedCallback); + return JsonUtils.toJsonNode(response); + } + + @Override + default ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + default Map getInputSchema() { + return null; + } + + @Override + default Map getOutputSchema() { + return null; + } + + @Override + default Map getMetadata() { + Map metadata = new HashMap<>(); + if (getInfo() != null) { + metadata.put("model", getInfo()); + } + return metadata; + } + + @Override + default void register(Registry registry) { + registry.registerAction(ActionType.MODEL.keyFromName(getName()), this); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ModelInfo.java b/java/ai/src/main/java/com/google/genkit/ai/ModelInfo.java new file mode 100644 index 0000000000..36f94d48a6 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ModelInfo.java @@ -0,0 +1,159 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.List; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ModelInfo contains metadata about a model's capabilities. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ModelInfo { + + @JsonProperty("label") + private String label; + + @JsonProperty("supports") + private ModelCapabilities supports; + + @JsonProperty("versions") + private List versions; + + /** + * Default constructor. + */ + public ModelInfo() { + } + + // Getters and setters + + public String getLabel() { + return label; + } + + public void setLabel(String label) { + this.label = label; + } + + public ModelCapabilities getSupports() { + return supports; + } + + public void setSupports(ModelCapabilities supports) { + this.supports = supports; + } + + public List getVersions() { + return versions; + } + + public void setVersions(List versions) { + this.versions = versions; + } + + /** + * ModelCapabilities describes what a model can do. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class ModelCapabilities { + + @JsonProperty("multiturn") + private Boolean multiturn; + + @JsonProperty("media") + private Boolean media; + + @JsonProperty("tools") + private Boolean tools; + + @JsonProperty("systemRole") + private Boolean systemRole; + + @JsonProperty("output") + private Set output; + + @JsonProperty("context") + private Boolean context; + + @JsonProperty("contextCaching") + private Boolean contextCaching; + + // Getters and setters + + public Boolean getMultiturn() { + return multiturn; + } + + public void setMultiturn(Boolean multiturn) { + this.multiturn = multiturn; + } + + public Boolean getMedia() { + return media; + } + + public void setMedia(Boolean media) { + this.media = media; + } + + public Boolean getTools() { + return tools; + } + + public void setTools(Boolean tools) { + this.tools = tools; + } + + public Boolean getSystemRole() { + return systemRole; + } + + public void setSystemRole(Boolean systemRole) { + this.systemRole = systemRole; + } + + public Set getOutput() { + return output; + } + + public void setOutput(Set output) { + this.output = output; + } + + public Boolean getContext() { + return context; + } + + public void setContext(Boolean context) { + this.context = context; + } + + public Boolean getContextCaching() { + return contextCaching; + } + + public void setContextCaching(Boolean contextCaching) { + this.contextCaching = contextCaching; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ModelRequest.java b/java/ai/src/main/java/com/google/genkit/ai/ModelRequest.java new file mode 100644 index 0000000000..0742586423 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ModelRequest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ModelRequest represents a request to a generative AI model. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ModelRequest { + + @JsonProperty("messages") + private List messages = new ArrayList<>(); + + @JsonProperty("config") + private Map config; + + @JsonProperty("tools") + private List tools; + + @JsonProperty("output") + private OutputConfig output; + + @JsonProperty("context") + private List context; + + /** + * Default constructor. + */ + public ModelRequest() { + } + + /** + * Creates a ModelRequest with the given messages. + * + * @param messages + * the messages + */ + public ModelRequest(List messages) { + this.messages = messages != null ? new ArrayList<>(messages) : new ArrayList<>(); + } + + /** + * Creates a builder for ModelRequest. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + // Getters and setters + + public List getMessages() { + return messages; + } + + public void setMessages(List messages) { + this.messages = messages; + } + + public Map getConfig() { + return config; + } + + public void setConfig(Map config) { + this.config = config; + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public OutputConfig getOutput() { + return output; + } + + public void setOutput(OutputConfig output) { + this.output = output; + } + + public List getContext() { + return context; + } + + public void setContext(List context) { + this.context = context; + } + + /** + * Adds a message to the request. + * + * @param message + * the message to add + * @return this request for chaining + */ + public ModelRequest addMessage(Message message) { + if (this.messages == null) { + this.messages = new ArrayList<>(); + } + this.messages.add(message); + return this; + } + + /** + * Builder for ModelRequest. + */ + public static class Builder { + private List messages = new ArrayList<>(); + private Map config; + private List tools; + private OutputConfig output; + private List context; + + public Builder messages(List messages) { + this.messages = new ArrayList<>(messages); + return this; + } + + public Builder addMessage(Message message) { + this.messages.add(message); + return this; + } + + public Builder addUserMessage(String text) { + this.messages.add(Message.user(text)); + return this; + } + + public Builder addSystemMessage(String text) { + this.messages.add(Message.system(text)); + return this; + } + + public Builder config(Map config) { + this.config = config; + return this; + } + + public Builder tools(List tools) { + this.tools = tools; + return this; + } + + public Builder output(OutputConfig output) { + this.output = output; + return this; + } + + public Builder context(List context) { + this.context = context; + return this; + } + + public ModelRequest build() { + ModelRequest request = new ModelRequest(messages); + request.setConfig(config); + request.setTools(tools); + request.setOutput(output); + request.setContext(context); + return request; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ModelResponse.java b/java/ai/src/main/java/com/google/genkit/ai/ModelResponse.java new file mode 100644 index 0000000000..5e16e50c2e --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ModelResponse.java @@ -0,0 +1,313 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ModelResponse represents a response from a generative AI model. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ModelResponse { + + @JsonProperty("candidates") + private List candidates = new ArrayList<>(); + + @JsonProperty("usage") + private Usage usage; + + @JsonProperty("request") + private ModelRequest request; + + @JsonProperty("custom") + private Map custom; + + @JsonProperty("latencyMs") + private Long latencyMs; + + @JsonProperty("finishReason") + private FinishReason finishReason; + + @JsonProperty("finishMessage") + private String finishMessage; + + @JsonProperty("interrupts") + private List interrupts; + + /** + * Default constructor. + */ + public ModelResponse() { + } + + /** + * Creates a ModelResponse with the given candidates. + * + * @param candidates + * the candidates + */ + public ModelResponse(List candidates) { + this.candidates = candidates != null ? new ArrayList<>(candidates) : new ArrayList<>(); + } + + /** + * Creates a builder for ModelResponse. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Returns the text content from the first candidate's first text part. + * + * @return the text content, or null if no text content is available + */ + public String getText() { + if (candidates == null || candidates.isEmpty()) { + return null; + } + Candidate first = candidates.get(0); + if (first.getMessage() == null || first.getMessage().getContent() == null) { + return null; + } + return first.getMessage().getContent().stream().filter(part -> part.getText() != null).map(Part::getText) + .collect(Collectors.joining()); + } + + /** + * Returns the first candidate's message. + * + * @return the message, or null if no candidates + */ + public Message getMessage() { + if (candidates == null || candidates.isEmpty()) { + return null; + } + return candidates.get(0).getMessage(); + } + + /** + * Returns all messages including the model's response. + * + *

+ * This is useful when resuming after an interrupt - pass these messages to the + * next generate call to maintain context. + * + * @return list of all messages (request messages + model response) + */ + public List getMessages() { + List messages = new ArrayList<>(); + if (request != null && request.getMessages() != null) { + messages.addAll(request.getMessages()); + } + Message responseMessage = getMessage(); + if (responseMessage != null) { + messages.add(responseMessage); + } + return messages; + } + + /** + * Returns all tool request parts from the first candidate. + * + * @return list of tool requests + */ + public List getToolRequests() { + if (candidates == null || candidates.isEmpty()) { + return new ArrayList<>(); + } + Candidate first = candidates.get(0); + if (first.getMessage() == null || first.getMessage().getContent() == null) { + return new ArrayList<>(); + } + return first.getMessage().getContent().stream().filter(part -> part.getToolRequest() != null) + .collect(Collectors.toList()); + } + + // Getters and setters + + public List getCandidates() { + return candidates; + } + + public void setCandidates(List candidates) { + this.candidates = candidates; + } + + public Usage getUsage() { + return usage; + } + + public void setUsage(Usage usage) { + this.usage = usage; + } + + public ModelRequest getRequest() { + return request; + } + + public void setRequest(ModelRequest request) { + this.request = request; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + + public Long getLatencyMs() { + return latencyMs; + } + + public void setLatencyMs(Long latencyMs) { + this.latencyMs = latencyMs; + } + + public FinishReason getFinishReason() { + if (finishReason != null) { + return finishReason; + } + // Fall back to first candidate's finish reason + if (candidates != null && !candidates.isEmpty()) { + return candidates.get(0).getFinishReason(); + } + return null; + } + + public void setFinishReason(FinishReason finishReason) { + this.finishReason = finishReason; + } + + public String getFinishMessage() { + return finishMessage; + } + + public void setFinishMessage(String finishMessage) { + this.finishMessage = finishMessage; + } + + /** + * Returns the list of interrupt tool requests. + * + *

+ * When the model requests tools that are interrupts, this list contains the + * tool request parts with interrupt metadata. Check if this list is non-empty + * to determine if generation was interrupted. + * + * @return list of interrupt tool request parts, or empty list if none + */ + public List getInterrupts() { + return interrupts != null ? interrupts : new ArrayList<>(); + } + + public void setInterrupts(List interrupts) { + this.interrupts = interrupts; + } + + /** + * Checks if generation was interrupted. + * + * @return true if there are pending interrupts + */ + public boolean isInterrupted() { + return interrupts != null && !interrupts.isEmpty(); + } + + /** + * Builder for ModelResponse. + */ + public static class Builder { + private List candidates = new ArrayList<>(); + private Usage usage; + private ModelRequest request; + private Map custom; + private Long latencyMs; + private FinishReason finishReason; + private String finishMessage; + private List interrupts; + + public Builder candidates(List candidates) { + this.candidates = new ArrayList<>(candidates); + return this; + } + + public Builder addCandidate(Candidate candidate) { + this.candidates.add(candidate); + return this; + } + + public Builder usage(Usage usage) { + this.usage = usage; + return this; + } + + public Builder request(ModelRequest request) { + this.request = request; + return this; + } + + public Builder custom(Map custom) { + this.custom = custom; + return this; + } + + public Builder latencyMs(Long latencyMs) { + this.latencyMs = latencyMs; + return this; + } + + public Builder finishReason(FinishReason finishReason) { + this.finishReason = finishReason; + return this; + } + + public Builder finishMessage(String finishMessage) { + this.finishMessage = finishMessage; + return this; + } + + public Builder interrupts(List interrupts) { + this.interrupts = interrupts != null ? new ArrayList<>(interrupts) : null; + return this; + } + + public ModelResponse build() { + ModelResponse response = new ModelResponse(candidates); + response.setUsage(usage); + response.setRequest(request); + response.setCustom(custom); + response.setLatencyMs(latencyMs); + response.setFinishReason(finishReason); + response.setFinishMessage(finishMessage); + response.setInterrupts(interrupts); + return response; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ModelResponseChunk.java b/java/ai/src/main/java/com/google/genkit/ai/ModelResponseChunk.java new file mode 100644 index 0000000000..3e48d5bb1d --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ModelResponseChunk.java @@ -0,0 +1,103 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ModelResponseChunk represents a streaming chunk from a generative AI model. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ModelResponseChunk { + + @JsonProperty("content") + private List content = new ArrayList<>(); + + @JsonProperty("index") + private Integer index; + + /** + * Default constructor. + */ + public ModelResponseChunk() { + } + + /** + * Creates a ModelResponseChunk with the given content. + * + * @param content + * the content parts + */ + public ModelResponseChunk(List content) { + this.content = content != null ? new ArrayList<>(content) : new ArrayList<>(); + } + + /** + * Creates a ModelResponseChunk with text content. + * + * @param text + * the text content + * @return a new chunk + */ + public static ModelResponseChunk text(String text) { + ModelResponseChunk chunk = new ModelResponseChunk(); + chunk.content.add(Part.text(text)); + return chunk; + } + + /** + * Returns the text content from all text parts. + * + * @return the concatenated text, or null if no text content + */ + public String getText() { + if (content == null || content.isEmpty()) { + return null; + } + StringBuilder sb = new StringBuilder(); + for (Part part : content) { + if (part.getText() != null) { + sb.append(part.getText()); + } + } + return sb.length() > 0 ? sb.toString() : null; + } + + // Getters and setters + + public List getContent() { + return content; + } + + public void setContent(List content) { + this.content = content; + } + + public Integer getIndex() { + return index; + } + + public void setIndex(Integer index) { + this.index = index; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/OutputConfig.java b/java/ai/src/main/java/com/google/genkit/ai/OutputConfig.java new file mode 100644 index 0000000000..fde0d957c2 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/OutputConfig.java @@ -0,0 +1,129 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * OutputConfig contains configuration for model output generation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class OutputConfig { + + @JsonProperty("format") + private OutputFormat format; + + @JsonProperty("schema") + private Map schema; + + @JsonProperty("constrained") + private Boolean constrained; + + @JsonProperty("contentType") + private String contentType; + + @JsonProperty("instructions") + private String instructions; + + /** + * Default constructor. + */ + public OutputConfig() { + } + + /** + * Creates an OutputConfig with the given format. + * + * @param format + * the output format + */ + public OutputConfig(OutputFormat format) { + this.format = format; + } + + /** + * Creates an OutputConfig for JSON output with schema. + * + * @param schema + * the JSON schema + * @return an OutputConfig configured for JSON + */ + public static OutputConfig json(Map schema) { + OutputConfig config = new OutputConfig(); + config.format = OutputFormat.JSON; + config.schema = schema; + return config; + } + + /** + * Creates an OutputConfig for text output. + * + * @return an OutputConfig configured for text + */ + public static OutputConfig text() { + OutputConfig config = new OutputConfig(); + config.format = OutputFormat.TEXT; + return config; + } + + // Getters and setters + + public OutputFormat getFormat() { + return format; + } + + public void setFormat(OutputFormat format) { + this.format = format; + } + + public Map getSchema() { + return schema; + } + + public void setSchema(Map schema) { + this.schema = schema; + } + + public Boolean getConstrained() { + return constrained; + } + + public void setConstrained(Boolean constrained) { + this.constrained = constrained; + } + + public String getContentType() { + return contentType; + } + + public void setContentType(String contentType) { + this.contentType = contentType; + } + + public String getInstructions() { + return instructions; + } + + public void setInstructions(String instructions) { + this.instructions = instructions; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/OutputFormat.java b/java/ai/src/main/java/com/google/genkit/ai/OutputFormat.java new file mode 100644 index 0000000000..0139be6ab6 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/OutputFormat.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * OutputFormat specifies the format for model output. + */ +public enum OutputFormat { + + @JsonProperty("text") + TEXT, + + @JsonProperty("json") + JSON, + + @JsonProperty("media") + MEDIA +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Part.java b/java/ai/src/main/java/com/google/genkit/ai/Part.java new file mode 100644 index 0000000000..e0f5a695bf --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Part.java @@ -0,0 +1,218 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Part represents a part of a message content, which can be text, media, tool + * request, or tool response. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Part { + + @JsonProperty("text") + private String text; + + @JsonProperty("media") + private Media media; + + @JsonProperty("toolRequest") + private ToolRequest toolRequest; + + @JsonProperty("toolResponse") + private ToolResponse toolResponse; + + @JsonProperty("data") + private Object data; + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor. + */ + public Part() { + } + + /** + * Creates a text part. + * + * @param text + * the text content + * @return a new text part + */ + public static Part text(String text) { + Part part = new Part(); + part.text = text; + return part; + } + + /** + * Creates a media part. + * + * @param contentType + * the media content type + * @param url + * the media URL + * @return a new media part + */ + public static Part media(String contentType, String url) { + Part part = new Part(); + part.media = new Media(contentType, url); + return part; + } + + /** + * Creates a tool request part. + * + * @param toolRequest + * the tool request + * @return a new tool request part + */ + public static Part toolRequest(ToolRequest toolRequest) { + Part part = new Part(); + part.toolRequest = toolRequest; + return part; + } + + /** + * Creates a tool response part. + * + * @param toolResponse + * the tool response + * @return a new tool response part + */ + public static Part toolResponse(ToolResponse toolResponse) { + Part part = new Part(); + part.toolResponse = toolResponse; + return part; + } + + /** + * Creates a data part. + * + * @param data + * the structured data + * @return a new data part + */ + public static Part data(Object data) { + Part part = new Part(); + part.data = data; + return part; + } + + // Getters and setters + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public Media getMedia() { + return media; + } + + public void setMedia(Media media) { + this.media = media; + } + + public ToolRequest getToolRequest() { + return toolRequest; + } + + public void setToolRequest(ToolRequest toolRequest) { + this.toolRequest = toolRequest; + } + + public ToolResponse getToolResponse() { + return toolResponse; + } + + public void setToolResponse(ToolResponse toolResponse) { + this.toolResponse = toolResponse; + } + + public Object getData() { + return data; + } + + public void setData(Object data) { + this.data = data; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + + /** + * Returns true if this is a text part. + * + * @return true if text + */ + public boolean isText() { + return text != null; + } + + /** + * Returns true if this is a media part. + * + * @return true if media + */ + public boolean isMedia() { + return media != null; + } + + /** + * Returns true if this is a tool request part. + * + * @return true if tool request + */ + public boolean isToolRequest() { + return toolRequest != null; + } + + /** + * Returns true if this is a tool response part. + * + * @return true if tool response + */ + public boolean isToolResponse() { + return toolResponse != null; + } + + /** + * Returns true if this is a data part. + * + * @return true if data + */ + public boolean isData() { + return data != null; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Prompt.java b/java/ai/src/main/java/com/google/genkit/ai/Prompt.java new file mode 100644 index 0000000000..7a449d5d62 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Prompt.java @@ -0,0 +1,268 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; + +/** + * Prompt is a template that generates ModelRequests from input variables. + * + * Prompts are registered as actions and can be rendered with different input + * values to create model requests. + * + * @param + * the input type for the prompt + */ +public class Prompt implements Action { + + private final String name; + private final String model; + private final String template; + private final Map inputSchema; + private final GenerationConfig config; + private final BiFunction renderer; + private final Map metadata; + private final Class inputClass; + + /** + * Creates a new Prompt. + * + * @param name + * the prompt name + * @param model + * the default model name + * @param template + * the prompt template + * @param inputSchema + * the input JSON schema + * @param config + * the default generation config + * @param inputClass + * the input class for JSON deserialization + * @param renderer + * the function that renders the prompt + */ + public Prompt(String name, String model, String template, Map inputSchema, GenerationConfig config, + Class inputClass, BiFunction renderer) { + this.name = name; + this.model = model; + this.template = template; + this.inputSchema = inputSchema; + this.config = config; + this.inputClass = inputClass; + this.renderer = renderer; + + // Build metadata structure to match Go SDK format + // The metadata.type identifies this as an executable-prompt + // The metadata.prompt contains the prompt-specific metadata + this.metadata = new HashMap<>(); + this.metadata.put("type", ActionType.EXECUTABLE_PROMPT.getValue()); + + // Build the prompt sub-object with detailed metadata + Map promptMetadata = new HashMap<>(); + promptMetadata.put("name", name); + promptMetadata.put("model", model); + promptMetadata.put("template", template); + if (inputSchema != null) { + promptMetadata.put("input", Map.of("schema", inputSchema)); + } + if (config != null) { + promptMetadata.put("config", config); + } + this.metadata.put("prompt", promptMetadata); + } + + /** + * Creates a builder for Prompt. + * + * @param + * the input type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + @Override + public String getName() { + return name; + } + + @Override + public ActionType getType() { + return ActionType.EXECUTABLE_PROMPT; + } + + @Override + public ActionDesc getDesc() { + return ActionDesc.builder().type(ActionType.EXECUTABLE_PROMPT).name(name).inputSchema(inputSchema) + .metadata(metadata).build(); + } + + @Override + public ModelRequest run(ActionContext ctx, I input) throws GenkitException { + try { + return renderer.apply(ctx, input); + } catch (Exception e) { + throw new GenkitException("Prompt rendering failed: " + e.getMessage(), e); + } + } + + @Override + public ModelRequest run(ActionContext ctx, I input, Consumer streamCallback) throws GenkitException { + return run(ctx, input); + } + + @Override + @SuppressWarnings("unchecked") + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + I typedInput = inputClass != null ? JsonUtils.fromJsonNode(input, inputClass) : (I) input; + ModelRequest output = run(ctx, typedInput); + return JsonUtils.toJsonNode(output); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + public Map getInputSchema() { + return inputSchema; + } + + @Override + public Map getOutputSchema() { + return null; + } + + @Override + public Map getMetadata() { + return metadata; + } + + @Override + public void register(Registry registry) { + registry.registerAction(ActionType.EXECUTABLE_PROMPT.keyFromName(name), this); + } + + /** + * Gets the default model name. + * + * @return the model name + */ + public String getModel() { + return model; + } + + /** + * Gets the prompt template. + * + * @return the template + */ + public String getTemplate() { + return template; + } + + /** + * Gets the default generation config. + * + * @return the config + */ + public GenerationConfig getConfig() { + return config; + } + + /** + * Builder for Prompt. + * + * @param + * the input type + */ + public static class Builder { + private String name; + private String model; + private String template; + private Map inputSchema; + private GenerationConfig config; + private Class inputClass; + private BiFunction renderer; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder template(String template) { + this.template = template; + return this; + } + + public Builder inputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder config(GenerationConfig config) { + this.config = config; + return this; + } + + public Builder inputClass(Class inputClass) { + this.inputClass = inputClass; + return this; + } + + public Builder renderer(BiFunction renderer) { + this.renderer = renderer; + return this; + } + + public Prompt build() { + if (name == null) { + throw new IllegalStateException("Prompt name is required"); + } + if (renderer == null) { + throw new IllegalStateException("Prompt renderer is required"); + } + return new Prompt<>(name, model, template, inputSchema, config, inputClass, renderer); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ResumeOptions.java b/java/ai/src/main/java/com/google/genkit/ai/ResumeOptions.java new file mode 100644 index 0000000000..3c8ef21aa8 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ResumeOptions.java @@ -0,0 +1,170 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.ArrayList; +import java.util.List; + +/** + * Options for resuming after an interrupt. + * + *

+ * When generation is interrupted by a tool, you can resume by providing + * responses to the interrupted tool requests or by restarting them with new + * inputs. + * + *

+ * Example usage: + * + *

{@code
+ * // Respond to an interrupt
+ * ResumeOptions resume = ResumeOptions.builder().respond(interrupt.respond("user confirmed")).build();
+ *
+ * // Restart an interrupt with new input
+ * ResumeOptions resume = ResumeOptions.builder().restart(interrupt.restart(null, newInput)).build();
+ * }
+ */ +public class ResumeOptions { + + private List respond; + private List restart; + + /** Default constructor. */ + public ResumeOptions() { + } + + /** + * Gets the tool responses for interrupted requests. + * + * @return the tool responses + */ + public List getRespond() { + return respond; + } + + /** + * Sets the tool responses for interrupted requests. + * + * @param respond + * the tool responses + */ + public void setRespond(List respond) { + this.respond = respond; + } + + /** + * Gets the tool requests to restart. + * + * @return the tool requests to restart + */ + public List getRestart() { + return restart; + } + + /** + * Sets the tool requests to restart. + * + * @param restart + * the tool requests to restart + */ + public void setRestart(List restart) { + this.restart = restart; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for ResumeOptions. */ + public static class Builder { + private List respond; + private List restart; + + /** + * Adds a tool response. + * + * @param response + * the tool response + * @return this builder + */ + public Builder respond(ToolResponse response) { + if (this.respond == null) { + this.respond = new ArrayList<>(); + } + this.respond.add(response); + return this; + } + + /** + * Sets all tool responses. + * + * @param responses + * the tool responses + * @return this builder + */ + public Builder respond(List responses) { + this.respond = responses; + return this; + } + + /** + * Adds a tool request to restart. + * + * @param request + * the tool request + * @return this builder + */ + public Builder restart(ToolRequest request) { + if (this.restart == null) { + this.restart = new ArrayList<>(); + } + this.restart.add(request); + return this; + } + + /** + * Sets all tool requests to restart. + * + * @param requests + * the tool requests + * @return this builder + */ + public Builder restart(List requests) { + this.restart = requests; + return this; + } + + /** + * Builds the ResumeOptions. + * + * @return the built options + */ + public ResumeOptions build() { + ResumeOptions options = new ResumeOptions(); + options.setRespond(respond); + options.setRestart(restart); + return options; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Retriever.java b/java/ai/src/main/java/com/google/genkit/ai/Retriever.java new file mode 100644 index 0000000000..0fe39a4353 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Retriever.java @@ -0,0 +1,251 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; + +/** + * Retriever is an action that retrieves documents based on a query. + * + * Retrievers are used for RAG (Retrieval Augmented Generation) workflows to + * find relevant documents to include in model prompts. + */ +public class Retriever implements Action { + + private final String name; + private final BiFunction handler; + private final Map metadata; + + /** + * Creates a new Retriever. + * + * @param name + * the retriever name + * @param handler + * the retrieval function + */ + public Retriever(String name, BiFunction handler) { + this.name = name; + this.handler = handler; + this.metadata = new HashMap<>(); + this.metadata.put("type", "retriever"); + } + + /** + * Creates a builder for Retriever. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public String getName() { + return name; + } + + @Override + public ActionType getType() { + return ActionType.RETRIEVER; + } + + @Override + public ActionDesc getDesc() { + return ActionDesc.builder().type(ActionType.RETRIEVER).name(name).inputSchema(getInputSchema()) + .outputSchema(getOutputSchema()).metadata(getMetadata()).build(); + } + + @Override + public RetrieverResponse run(ActionContext ctx, RetrieverRequest input) throws GenkitException { + try { + return handler.apply(ctx, input); + } catch (Exception e) { + throw new GenkitException("Retriever execution failed: " + e.getMessage(), e); + } + } + + @Override + public RetrieverResponse run(ActionContext ctx, RetrieverRequest input, Consumer streamCallback) + throws GenkitException { + return run(ctx, input); + } + + @Override + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + RetrieverRequest request = JsonUtils.fromJsonNode(input, RetrieverRequest.class); + RetrieverResponse response = run(ctx, request); + return JsonUtils.toJsonNode(response); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + public Map getInputSchema() { + // Define the input schema to match genkit-tools RetrieverRequest schema + Map schema = new HashMap<>(); + schema.put("type", "object"); + + Map properties = new HashMap<>(); + + // query property - DocumentData + Map queryProp = new HashMap<>(); + queryProp.put("type", "object"); + Map queryProps = new HashMap<>(); + + // content array + Map contentProp = new HashMap<>(); + contentProp.put("type", "array"); + Map contentItemSchema = new HashMap<>(); + contentItemSchema.put("type", "object"); + Map partProps = new HashMap<>(); + Map textProp = new HashMap<>(); + textProp.put("type", "string"); + partProps.put("text", textProp); + contentItemSchema.put("properties", partProps); + contentProp.put("items", contentItemSchema); + queryProps.put("content", contentProp); + + // metadata + Map metaProp = new HashMap<>(); + metaProp.put("type", "object"); + metaProp.put("additionalProperties", true); + queryProps.put("metadata", metaProp); + + queryProp.put("properties", queryProps); + queryProp.put("required", java.util.List.of("content")); + properties.put("query", queryProp); + + // options property + Map optionsProp = new HashMap<>(); + optionsProp.put("type", "object"); + Map optionsProps = new HashMap<>(); + Map kProp = new HashMap<>(); + kProp.put("type", "integer"); + kProp.put("description", "Number of documents to retrieve"); + optionsProps.put("k", kProp); + optionsProp.put("properties", optionsProps); + properties.put("options", optionsProp); + + schema.put("properties", properties); + schema.put("required", java.util.List.of("query")); + + return schema; + } + + @Override + public Map getOutputSchema() { + // Define the output schema to match genkit-tools RetrieverResponse schema + Map schema = new HashMap<>(); + schema.put("type", "object"); + + Map properties = new HashMap<>(); + + // documents array + Map docsProp = new HashMap<>(); + docsProp.put("type", "array"); + Map docItemSchema = new HashMap<>(); + docItemSchema.put("type", "object"); + Map docProps = new HashMap<>(); + + // content array in each document + Map contentProp = new HashMap<>(); + contentProp.put("type", "array"); + Map partSchema = new HashMap<>(); + partSchema.put("type", "object"); + Map partProps = new HashMap<>(); + Map textProp = new HashMap<>(); + textProp.put("type", "string"); + partProps.put("text", textProp); + partSchema.put("properties", partProps); + contentProp.put("items", partSchema); + docProps.put("content", contentProp); + + // metadata + Map metaProp = new HashMap<>(); + metaProp.put("type", "object"); + docProps.put("metadata", metaProp); + + docItemSchema.put("properties", docProps); + docsProp.put("items", docItemSchema); + properties.put("documents", docsProp); + + schema.put("properties", properties); + schema.put("required", java.util.List.of("documents")); + + return schema; + } + + @Override + public Map getMetadata() { + return metadata; + } + + @Override + public void register(Registry registry) { + registry.registerAction(ActionType.RETRIEVER.keyFromName(name), this); + } + + /** + * Builder for Retriever. + */ + public static class Builder { + private String name; + private BiFunction handler; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder handler(BiFunction handler) { + this.handler = handler; + return this; + } + + public Retriever build() { + if (name == null) { + throw new IllegalStateException("Retriever name is required"); + } + if (handler == null) { + throw new IllegalStateException("Retriever handler is required"); + } + return new Retriever(name, handler); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/RetrieverRequest.java b/java/ai/src/main/java/com/google/genkit/ai/RetrieverRequest.java new file mode 100644 index 0000000000..fafbab3c3e --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/RetrieverRequest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * RetrieverRequest contains a query for document retrieval. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class RetrieverRequest { + + @JsonProperty("query") + private Document query; + + @JsonProperty("options") + private RetrieverOptions options; + + /** + * Default constructor. + */ + public RetrieverRequest() { + } + + /** + * Creates a RetrieverRequest with a query. + * + * @param query + * the query document + */ + public RetrieverRequest(Document query) { + this.query = query; + } + + /** + * Creates a RetrieverRequest with a text query. + * + * @param queryText + * the query text + * @return a RetrieverRequest + */ + public static RetrieverRequest fromText(String queryText) { + return new RetrieverRequest(Document.fromText(queryText)); + } + + // Getters and setters + + public Document getQuery() { + return query; + } + + public void setQuery(Document query) { + this.query = query; + } + + public RetrieverOptions getOptions() { + return options; + } + + public void setOptions(RetrieverOptions options) { + this.options = options; + } + + /** + * RetrieverOptions contains options for retrieval. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class RetrieverOptions { + + @JsonProperty("k") + private Integer k; + + @JsonProperty("custom") + private Map custom; + + public Integer getK() { + return k; + } + + public void setK(Integer k) { + this.k = k; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/RetrieverResponse.java b/java/ai/src/main/java/com/google/genkit/ai/RetrieverResponse.java new file mode 100644 index 0000000000..8758c4f781 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/RetrieverResponse.java @@ -0,0 +1,60 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * RetrieverResponse contains documents retrieved from a query. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class RetrieverResponse { + + @JsonProperty("documents") + private List documents; + + /** + * Default constructor. + */ + public RetrieverResponse() { + } + + /** + * Creates a RetrieverResponse with documents. + * + * @param documents + * the retrieved documents + */ + public RetrieverResponse(List documents) { + this.documents = documents; + } + + // Getters and setters + + public List getDocuments() { + return documents; + } + + public void setDocuments(List documents) { + this.documents = documents; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Role.java b/java/ai/src/main/java/com/google/genkit/ai/Role.java new file mode 100644 index 0000000000..eee099addb --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Role.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * Role represents the role of a message sender in a conversation. + */ +public enum Role { + USER("user"), MODEL("model"), SYSTEM("system"), TOOL("tool"); + + private final String value; + + Role(String value) { + this.value = value; + } + + @JsonValue + public String getValue() { + return value; + } + + @JsonCreator + public static Role fromValue(String value) { + for (Role role : values()) { + if (role.value.equalsIgnoreCase(value)) { + return role; + } + } + // Try matching "assistant" to MODEL for compatibility + if ("assistant".equalsIgnoreCase(value)) { + return MODEL; + } + throw new IllegalArgumentException("Unknown role: " + value); + } + + @Override + public String toString() { + return value; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Tool.java b/java/ai/src/main/java/com/google/genkit/ai/Tool.java new file mode 100644 index 0000000000..3f9d9ec101 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Tool.java @@ -0,0 +1,403 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; +import com.google.genkit.core.tracing.SpanMetadata; +import com.google.genkit.core.tracing.Tracer; + +/** + * Tool represents a function that can be called by an AI model. + * + * Tools allow models to interact with external systems and perform actions + * during generation. + * + * @param + * the input type + * @param + * the output type + */ +public class Tool implements Action { + + private final String name; + private final String description; + private final Map inputSchema; + private final Map outputSchema; + private final BiFunction handler; + private final Map metadata; + private final Class inputClass; + + /** + * Creates a new Tool. + * + * @param name + * the tool name + * @param description + * the tool description + * @param inputSchema + * the input JSON schema + * @param outputSchema + * the output JSON schema + * @param inputClass + * the input class for JSON deserialization + * @param handler + * the tool handler function + */ + public Tool(String name, String description, Map inputSchema, Map outputSchema, + Class inputClass, BiFunction handler) { + this.name = name; + this.description = description; + this.inputSchema = inputSchema; + this.outputSchema = outputSchema; + this.inputClass = inputClass; + this.handler = handler; + this.metadata = new HashMap<>(); + this.metadata.put("description", description); + } + + /** + * Creates a builder for Tool. + * + * @param + * the input type + * @param + * the output type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + @Override + public String getName() { + return name; + } + + @Override + public ActionType getType() { + return ActionType.TOOL; + } + + @Override + public ActionDesc getDesc() { + return ActionDesc.builder().type(ActionType.TOOL).name(name).description(description).inputSchema(inputSchema) + .outputSchema(outputSchema).build(); + } + + @Override + public O run(ActionContext ctx, I input) throws GenkitException { + SpanMetadata spanMetadata = SpanMetadata.builder().name(name).type(ActionType.TOOL.getValue()).subtype("tool") + .build(); + + String flowName = ctx.getFlowName(); + if (flowName != null) { + spanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + return Tracer.runInNewSpan(ctx, spanMetadata, input, (spanCtx, in) -> { + try { + O result = handler.apply(ctx.withSpanContext(spanCtx), in); + return result; + } catch (AgentHandoffException e) { + // Re-throw agent handoff exceptions for multi-agent pattern + throw e; + } catch (ToolInterruptException e) { + // Re-throw interrupt exceptions for human-in-the-loop pattern + throw e; + } catch (Exception e) { + if (e instanceof GenkitException) { + throw (GenkitException) e; + } + throw new GenkitException("Tool execution failed: " + e.getMessage(), e); + } + }); + } + + @Override + public O run(ActionContext ctx, I input, Consumer streamCallback) throws GenkitException { + return run(ctx, input); + } + + @Override + @SuppressWarnings("unchecked") + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + I typedInput = inputClass != null ? JsonUtils.fromJsonNode(input, inputClass) : (I) input; + O output = run(ctx, typedInput); + return JsonUtils.toJsonNode(output); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + public Map getInputSchema() { + return inputSchema; + } + + @Override + public Map getOutputSchema() { + return outputSchema; + } + + @Override + public Map getMetadata() { + return metadata; + } + + @Override + public void register(Registry registry) { + registry.registerAction(ActionType.TOOL.keyFromName(name), this); + } + + /** + * Gets the tool description. + * + * @return the description + */ + public String getDescription() { + return description; + } + + /** + * Gets the input class for JSON deserialization. + * + * @return the input class, or null if not specified + */ + public Class getInputClass() { + return inputClass; + } + + /** + * Gets the tool definition for use in model requests. + * + * @return the tool definition + */ + public ToolDefinition getDefinition() { + return new ToolDefinition(name, description, inputSchema, outputSchema); + } + + /** + * Constructs a tool response for an interrupted tool request. + * + *

+ * This method is used when resuming generation after an interrupt. It creates a + * tool response part that can be passed to {@link ResumeOptions#getRespond()}. + * + *

+ * Example usage: + * + *

{@code
+   * // Get interrupt from response
+   * Part interrupt = response.getInterrupts().get(0);
+   * 
+   * // Create response with user-provided data
+   * Part responseData = tool.respond(interrupt, userConfirmation);
+   * 
+   * // Resume generation
+   * ModelResponse resumed = genkit.generate(GenerateOptions.builder().messages(response.getMessages())
+   * 		.resume(ResumeOptions.builder().respond(responseData).build()).build());
+   * }
+ * + * @param interrupt + * the interrupted tool request part + * @param output + * the output data to respond with + * @return a tool response part + */ + public Part respond(Part interrupt, O output) { + return respond(interrupt, output, null); + } + + /** + * Constructs a tool response for an interrupted tool request with metadata. + * + * @param interrupt + * the interrupted tool request part + * @param output + * the output data to respond with + * @param metadata + * optional metadata to include in the response + * @return a tool response part + */ + public Part respond(Part interrupt, O output, Map metadata) { + if (interrupt == null || interrupt.getToolRequest() == null) { + throw new IllegalArgumentException("Interrupt must be a tool request part"); + } + + ToolRequest toolRequest = interrupt.getToolRequest(); + Part responsePart = new Part(); + ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolRequest.getName(), output); + responsePart.setToolResponse(toolResponse); + + // Add interruptResponse marker in metadata + Map responseMetadata = new HashMap<>(); + responseMetadata.put("interruptResponse", true); + if (metadata != null) { + responseMetadata.putAll(metadata); + } + responsePart.setMetadata(responseMetadata); + + return responsePart; + } + + /** + * Constructs a restart request for an interrupted tool. + * + *

+ * This method creates a tool request that will cause the tool to be + * re-executed. The resumed metadata will be passed to the tool handler. + * + *

+ * Example usage: + * + *

{@code
+   * // Get interrupt from response
+   * Part interrupt = response.getInterrupts().get(0);
+   * 
+   * // Create restart request with confirmation metadata
+   * Part restartRequest = tool.restart(interrupt, Map.of("confirmed", true));
+   * 
+   * // Resume generation
+   * ModelResponse resumed = genkit.generate(GenerateOptions.builder().messages(response.getMessages())
+   * 		.resume(ResumeOptions.builder().restart(restartRequest).build()).build());
+   * }
+ * + * @param interrupt + * the interrupted tool request part + * @param resumedMetadata + * metadata to pass to the tool handler's resumed context + * @return a tool request part for restart + */ + public Part restart(Part interrupt, Map resumedMetadata) { + return restart(interrupt, resumedMetadata, null); + } + + /** + * Constructs a restart request with replacement input. + * + * @param interrupt + * the interrupted tool request part + * @param resumedMetadata + * metadata to pass to the tool handler's resumed context + * @param replaceInput + * optional new input to use instead of the original + * @return a tool request part for restart + */ + public Part restart(Part interrupt, Map resumedMetadata, I replaceInput) { + if (interrupt == null || interrupt.getToolRequest() == null) { + throw new IllegalArgumentException("Interrupt must be a tool request part"); + } + + ToolRequest originalRequest = interrupt.getToolRequest(); + Part restartPart = new Part(); + + // Create new tool request with either original or replacement input + Object inputToUse = replaceInput != null ? replaceInput : originalRequest.getInput(); + ToolRequest restartRequest = new ToolRequest(originalRequest.getName(), inputToUse); + restartRequest.setRef(originalRequest.getRef()); + restartPart.setToolRequest(restartRequest); + + // Add resumed metadata + Map restartMetadata = new HashMap<>(); + restartMetadata.put("source", "restart"); + if (resumedMetadata != null) { + restartMetadata.put("resumed", resumedMetadata); + } else { + restartMetadata.put("resumed", true); + } + restartPart.setMetadata(restartMetadata); + + return restartPart; + } + + /** + * Builder for Tool. + * + * @param + * the input type + * @param + * the output type + */ + public static class Builder { + private String name; + private String description; + private Map inputSchema; + private Map outputSchema; + private Class inputClass; + private BiFunction handler; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder outputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + return this; + } + + public Builder inputClass(Class inputClass) { + this.inputClass = inputClass; + return this; + } + + public Builder handler(BiFunction handler) { + this.handler = handler; + return this; + } + + public Tool build() { + if (name == null) { + throw new IllegalStateException("Tool name is required"); + } + if (handler == null) { + throw new IllegalStateException("Tool handler is required"); + } + return new Tool<>(name, description, inputSchema, outputSchema, inputClass, handler); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ToolDefinition.java b/java/ai/src/main/java/com/google/genkit/ai/ToolDefinition.java new file mode 100644 index 0000000000..a7624b3bf3 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ToolDefinition.java @@ -0,0 +1,127 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ToolDefinition describes a tool that can be used by a model. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ToolDefinition { + + @JsonProperty("name") + private String name; + + @JsonProperty("description") + private String description; + + @JsonProperty("inputSchema") + private Map inputSchema; + + @JsonProperty("outputSchema") + private Map outputSchema; + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor. + */ + public ToolDefinition() { + } + + /** + * Creates a ToolDefinition with the given name and description. + * + * @param name + * the tool name + * @param description + * the tool description + */ + public ToolDefinition(String name, String description) { + this.name = name; + this.description = description; + } + + /** + * Creates a ToolDefinition with full parameters. + * + * @param name + * the tool name + * @param description + * the tool description + * @param inputSchema + * the input JSON schema + * @param outputSchema + * the output JSON schema + */ + public ToolDefinition(String name, String description, Map inputSchema, + Map outputSchema) { + this.name = name; + this.description = description; + this.inputSchema = inputSchema; + this.outputSchema = outputSchema; + } + + // Getters and setters + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public Map getInputSchema() { + return inputSchema; + } + + public void setInputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + } + + public Map getOutputSchema() { + return outputSchema; + } + + public void setOutputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ToolInterruptException.java b/java/ai/src/main/java/com/google/genkit/ai/ToolInterruptException.java new file mode 100644 index 0000000000..ecce96e8f8 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ToolInterruptException.java @@ -0,0 +1,89 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Exception thrown when a tool execution is interrupted. + * + *

+ * This exception is used to implement the interrupt pattern, which allows tools + * to pause execution and request user input (human-in-the-loop). When a tool + * throws this exception, the generation loop stops and returns the interrupt + * information to the caller. + * + *

+ * Example usage: + * + *

{@code
+ * Tool confirmTool = genkit.defineInterrupt(InterruptConfig.builder()
+ * 		.name("confirmAction").description("Ask user for confirmation before proceeding").inputSchema(Input.class)
+ * 		.outputSchema(Output.class).build());
+ * }
+ */ +public class ToolInterruptException extends RuntimeException { + + private final Map metadata; + + /** Creates a new ToolInterruptException with no metadata. */ + public ToolInterruptException() { + super("Tool execution interrupted"); + this.metadata = Collections.emptyMap(); + } + + /** + * Creates a new ToolInterruptException with metadata. + * + * @param metadata + * additional metadata about the interrupt + */ + public ToolInterruptException(Map metadata) { + super("Tool execution interrupted"); + this.metadata = metadata != null + ? Collections.unmodifiableMap(new HashMap<>(metadata)) + : Collections.emptyMap(); + } + + /** + * Creates a new ToolInterruptException with a message and metadata. + * + * @param message + * the exception message + * @param metadata + * additional metadata about the interrupt + */ + public ToolInterruptException(String message, Map metadata) { + super(message); + this.metadata = metadata != null + ? Collections.unmodifiableMap(new HashMap<>(metadata)) + : Collections.emptyMap(); + } + + /** + * Gets the interrupt metadata. + * + * @return the metadata, never null (returns empty map if not set) + */ + public Map getMetadata() { + return metadata; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ToolRequest.java b/java/ai/src/main/java/com/google/genkit/ai/ToolRequest.java new file mode 100644 index 0000000000..6b03b2b719 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ToolRequest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ToolRequest represents a request from the model to invoke a tool. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ToolRequest { + + @JsonProperty("ref") + private String ref; + + @JsonProperty("name") + private String name; + + @JsonProperty("input") + private Object input; + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor. + */ + public ToolRequest() { + } + + /** + * Creates a ToolRequest with the given name and input. + * + * @param name + * the tool name + * @param input + * the tool input + */ + public ToolRequest(String name, Object input) { + this.name = name; + this.input = input; + } + + // Getters and setters + + public String getRef() { + return ref; + } + + public void setRef(String ref) { + this.ref = ref; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public Object getInput() { + return input; + } + + public void setInput(Object input) { + this.input = input; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/ToolResponse.java b/java/ai/src/main/java/com/google/genkit/ai/ToolResponse.java new file mode 100644 index 0000000000..fa53ba7664 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/ToolResponse.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ToolResponse represents a response from a tool invocation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ToolResponse { + + @JsonProperty("ref") + private String ref; + + @JsonProperty("name") + private String name; + + @JsonProperty("output") + private Object output; + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor. + */ + public ToolResponse() { + } + + /** + * Creates a ToolResponse with the given name and output. + * + * @param name + * the tool name + * @param output + * the tool output + */ + public ToolResponse(String name, Object output) { + this.name = name; + this.output = output; + } + + /** + * Creates a ToolResponse with the given ref, name and output. + * + * @param ref + * the reference ID + * @param name + * the tool name + * @param output + * the tool output + */ + public ToolResponse(String ref, String name, Object output) { + this.ref = ref; + this.name = name; + this.output = output; + } + + // Getters and setters + + public String getRef() { + return ref; + } + + public void setRef(String ref) { + this.ref = ref; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public Object getOutput() { + return output; + } + + public void setOutput(Object output) { + this.output = output; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/Usage.java b/java/ai/src/main/java/com/google/genkit/ai/Usage.java new file mode 100644 index 0000000000..a150a5235c --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/Usage.java @@ -0,0 +1,196 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Usage represents token usage statistics from a model response. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Usage { + + @JsonProperty("inputTokens") + private Integer inputTokens; + + @JsonProperty("outputTokens") + private Integer outputTokens; + + @JsonProperty("totalTokens") + private Integer totalTokens; + + @JsonProperty("inputCharacters") + private Integer inputCharacters; + + @JsonProperty("outputCharacters") + private Integer outputCharacters; + + @JsonProperty("inputImages") + private Integer inputImages; + + @JsonProperty("outputImages") + private Integer outputImages; + + @JsonProperty("inputAudioFiles") + private Integer inputAudioFiles; + + @JsonProperty("outputAudioFiles") + private Integer outputAudioFiles; + + @JsonProperty("inputVideoFiles") + private Integer inputVideoFiles; + + @JsonProperty("outputVideoFiles") + private Integer outputVideoFiles; + + @JsonProperty("thoughtsTokens") + private Integer thoughtsTokens; + + @JsonProperty("cachedContentTokens") + private Integer cachedContentTokens; + + /** + * Default constructor. + */ + public Usage() { + } + + /** + * Creates a Usage with token counts. + * + * @param inputTokens + * number of input tokens + * @param outputTokens + * number of output tokens + * @param totalTokens + * total number of tokens + */ + public Usage(Integer inputTokens, Integer outputTokens, Integer totalTokens) { + this.inputTokens = inputTokens; + this.outputTokens = outputTokens; + this.totalTokens = totalTokens; + } + + // Getters and setters + + public Integer getInputTokens() { + return inputTokens; + } + + public void setInputTokens(Integer inputTokens) { + this.inputTokens = inputTokens; + } + + public Integer getOutputTokens() { + return outputTokens; + } + + public void setOutputTokens(Integer outputTokens) { + this.outputTokens = outputTokens; + } + + public Integer getTotalTokens() { + return totalTokens; + } + + public void setTotalTokens(Integer totalTokens) { + this.totalTokens = totalTokens; + } + + public Integer getInputCharacters() { + return inputCharacters; + } + + public void setInputCharacters(Integer inputCharacters) { + this.inputCharacters = inputCharacters; + } + + public Integer getOutputCharacters() { + return outputCharacters; + } + + public void setOutputCharacters(Integer outputCharacters) { + this.outputCharacters = outputCharacters; + } + + public Integer getInputImages() { + return inputImages; + } + + public void setInputImages(Integer inputImages) { + this.inputImages = inputImages; + } + + public Integer getOutputImages() { + return outputImages; + } + + public void setOutputImages(Integer outputImages) { + this.outputImages = outputImages; + } + + public Integer getInputAudioFiles() { + return inputAudioFiles; + } + + public void setInputAudioFiles(Integer inputAudioFiles) { + this.inputAudioFiles = inputAudioFiles; + } + + public Integer getOutputAudioFiles() { + return outputAudioFiles; + } + + public void setOutputAudioFiles(Integer outputAudioFiles) { + this.outputAudioFiles = outputAudioFiles; + } + + public Integer getInputVideoFiles() { + return inputVideoFiles; + } + + public void setInputVideoFiles(Integer inputVideoFiles) { + this.inputVideoFiles = inputVideoFiles; + } + + public Integer getOutputVideoFiles() { + return outputVideoFiles; + } + + public void setOutputVideoFiles(Integer outputVideoFiles) { + this.outputVideoFiles = outputVideoFiles; + } + + public Integer getThoughtsTokens() { + return thoughtsTokens; + } + + public void setThoughtsTokens(Integer thoughtsTokens) { + this.thoughtsTokens = thoughtsTokens; + } + + public Integer getCachedContentTokens() { + return cachedContentTokens; + } + + public void setCachedContentTokens(Integer cachedContentTokens) { + this.cachedContentTokens = cachedContentTokens; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/CreateDatasetRequest.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/CreateDatasetRequest.java new file mode 100644 index 0000000000..79f3f9ed72 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/CreateDatasetRequest.java @@ -0,0 +1,175 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; + +/** + * Request to create a new dataset. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class CreateDatasetRequest { + + /** + * The dataset samples. + */ + @JsonProperty("data") + private List data; + + /** + * Optional ID for the dataset. If not provided, one will be generated. + */ + @JsonProperty("datasetId") + private String datasetId; + + /** + * The type of dataset. + */ + @JsonProperty("datasetType") + private DatasetType datasetType; + + /** + * Optional schema for the dataset. + */ + @JsonProperty("schema") + private JsonNode schema; + + /** + * References to metrics/evaluators for this dataset. + */ + @JsonProperty("metricRefs") + private List metricRefs; + + /** + * The target action this dataset is designed for. + */ + @JsonProperty("targetAction") + private String targetAction; + + public CreateDatasetRequest() { + } + + private CreateDatasetRequest(Builder builder) { + this.data = builder.data; + this.datasetId = builder.datasetId; + this.datasetType = builder.datasetType; + this.schema = builder.schema; + this.metricRefs = builder.metricRefs; + this.targetAction = builder.targetAction; + } + + public static Builder builder() { + return new Builder(); + } + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + + public String getDatasetId() { + return datasetId; + } + + public void setDatasetId(String datasetId) { + this.datasetId = datasetId; + } + + public DatasetType getDatasetType() { + return datasetType; + } + + public void setDatasetType(DatasetType datasetType) { + this.datasetType = datasetType; + } + + public JsonNode getSchema() { + return schema; + } + + public void setSchema(JsonNode schema) { + this.schema = schema; + } + + public List getMetricRefs() { + return metricRefs; + } + + public void setMetricRefs(List metricRefs) { + this.metricRefs = metricRefs; + } + + public String getTargetAction() { + return targetAction; + } + + public void setTargetAction(String targetAction) { + this.targetAction = targetAction; + } + + public static class Builder { + private List data; + private String datasetId; + private DatasetType datasetType = DatasetType.UNKNOWN; + private JsonNode schema; + private List metricRefs; + private String targetAction; + + public Builder data(List data) { + this.data = data; + return this; + } + + public Builder datasetId(String datasetId) { + this.datasetId = datasetId; + return this; + } + + public Builder datasetType(DatasetType datasetType) { + this.datasetType = datasetType; + return this; + } + + public Builder schema(JsonNode schema) { + this.schema = schema; + return this; + } + + public Builder metricRefs(List metricRefs) { + this.metricRefs = metricRefs; + return this; + } + + public Builder targetAction(String targetAction) { + this.targetAction = targetAction; + return this; + } + + public CreateDatasetRequest build() { + return new CreateDatasetRequest(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetMetadata.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetMetadata.java new file mode 100644 index 0000000000..6f332f287b --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetMetadata.java @@ -0,0 +1,238 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; + +/** + * Metadata about a dataset stored in the dataset store. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class DatasetMetadata { + + /** + * Unique identifier for the dataset. + */ + @JsonProperty("datasetId") + private String datasetId; + + /** + * Number of samples in the dataset. + */ + @JsonProperty("size") + private int size; + + /** + * Optional schema definition for the dataset. + */ + @JsonProperty("schema") + private JsonNode schema; + + /** + * The type of dataset (FLOW, MODEL, etc.). + */ + @JsonProperty("datasetType") + private DatasetType datasetType; + + /** + * The action this dataset is designed for. + */ + @JsonProperty("targetAction") + private String targetAction; + + /** + * References to metrics/evaluators to use with this dataset. + */ + @JsonProperty("metricRefs") + private List metricRefs; + + /** + * Version number of the dataset. + */ + @JsonProperty("version") + private int version; + + /** + * Timestamp when the dataset was created. + */ + @JsonProperty("createTime") + private String createTime; + + /** + * Timestamp when the dataset was last updated. + */ + @JsonProperty("updateTime") + private String updateTime; + + public DatasetMetadata() { + } + + private DatasetMetadata(Builder builder) { + this.datasetId = builder.datasetId; + this.size = builder.size; + this.schema = builder.schema; + this.datasetType = builder.datasetType; + this.targetAction = builder.targetAction; + this.metricRefs = builder.metricRefs; + this.version = builder.version; + this.createTime = builder.createTime; + this.updateTime = builder.updateTime; + } + + public static Builder builder() { + return new Builder(); + } + + public String getDatasetId() { + return datasetId; + } + + public void setDatasetId(String datasetId) { + this.datasetId = datasetId; + } + + public int getSize() { + return size; + } + + public void setSize(int size) { + this.size = size; + } + + public JsonNode getSchema() { + return schema; + } + + public void setSchema(JsonNode schema) { + this.schema = schema; + } + + public DatasetType getDatasetType() { + return datasetType; + } + + public void setDatasetType(DatasetType datasetType) { + this.datasetType = datasetType; + } + + public String getTargetAction() { + return targetAction; + } + + public void setTargetAction(String targetAction) { + this.targetAction = targetAction; + } + + public List getMetricRefs() { + return metricRefs; + } + + public void setMetricRefs(List metricRefs) { + this.metricRefs = metricRefs; + } + + public int getVersion() { + return version; + } + + public void setVersion(int version) { + this.version = version; + } + + public String getCreateTime() { + return createTime; + } + + public void setCreateTime(String createTime) { + this.createTime = createTime; + } + + public String getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(String updateTime) { + this.updateTime = updateTime; + } + + public static class Builder { + private String datasetId; + private int size; + private JsonNode schema; + private DatasetType datasetType = DatasetType.UNKNOWN; + private String targetAction; + private List metricRefs; + private int version = 1; + private String createTime; + private String updateTime; + + public Builder datasetId(String datasetId) { + this.datasetId = datasetId; + return this; + } + + public Builder size(int size) { + this.size = size; + return this; + } + + public Builder schema(JsonNode schema) { + this.schema = schema; + return this; + } + + public Builder datasetType(DatasetType datasetType) { + this.datasetType = datasetType; + return this; + } + + public Builder targetAction(String targetAction) { + this.targetAction = targetAction; + return this; + } + + public Builder metricRefs(List metricRefs) { + this.metricRefs = metricRefs; + return this; + } + + public Builder version(int version) { + this.version = version; + return this; + } + + public Builder createTime(String createTime) { + this.createTime = createTime; + return this; + } + + public Builder updateTime(String updateTime) { + this.updateTime = updateTime; + return this; + } + + public DatasetMetadata build() { + return new DatasetMetadata(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetSample.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetSample.java new file mode 100644 index 0000000000..13f7403f5f --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetSample.java @@ -0,0 +1,113 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents a single sample in an inference dataset. + * + *

+ * A sample contains the input to run through the AI system and an optional + * reference output for comparison during evaluation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class DatasetSample { + + /** + * Optional identifier for this test case. + */ + @JsonProperty("testCaseId") + private String testCaseId; + + /** + * The input to provide to the AI system. + */ + @JsonProperty("input") + private Object input; + + /** + * The expected/reference output for comparison. + */ + @JsonProperty("reference") + private Object reference; + + public DatasetSample() { + } + + private DatasetSample(Builder builder) { + this.testCaseId = builder.testCaseId; + this.input = builder.input; + this.reference = builder.reference; + } + + public static Builder builder() { + return new Builder(); + } + + public String getTestCaseId() { + return testCaseId; + } + + public void setTestCaseId(String testCaseId) { + this.testCaseId = testCaseId; + } + + public Object getInput() { + return input; + } + + public void setInput(Object input) { + this.input = input; + } + + public Object getReference() { + return reference; + } + + public void setReference(Object reference) { + this.reference = reference; + } + + public static class Builder { + private String testCaseId; + private Object input; + private Object reference; + + public Builder testCaseId(String testCaseId) { + this.testCaseId = testCaseId; + return this; + } + + public Builder input(Object input) { + this.input = input; + return this; + } + + public Builder reference(Object reference) { + this.reference = reference; + return this; + } + + public DatasetSample build() { + return new DatasetSample(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetStore.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetStore.java new file mode 100644 index 0000000000..584d045bbb --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetStore.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +/** + * Interface for dataset storage operations. + * + *

+ * Implementations of this interface handle CRUD operations for datasets used in + * evaluation workflows. + */ +public interface DatasetStore { + + /** + * Creates a new dataset. + * + * @param request + * the create request containing dataset data and metadata + * @return metadata about the created dataset + * @throws Exception + * if creation fails + */ + DatasetMetadata createDataset(CreateDatasetRequest request) throws Exception; + + /** + * Updates an existing dataset. + * + * @param request + * the update request containing dataset ID and new data + * @return metadata about the updated dataset + * @throws Exception + * if update fails or dataset not found + */ + DatasetMetadata updateDataset(UpdateDatasetRequest request) throws Exception; + + /** + * Gets the data for a dataset. + * + * @param datasetId + * the dataset ID + * @return the list of dataset samples + * @throws Exception + * if retrieval fails or dataset not found + */ + List getDataset(String datasetId) throws Exception; + + /** + * Lists all datasets. + * + * @return list of dataset metadata + * @throws Exception + * if listing fails + */ + List listDatasets() throws Exception; + + /** + * Deletes a dataset. + * + * @param datasetId + * the dataset ID to delete + * @throws Exception + * if deletion fails + */ + void deleteDataset(String datasetId) throws Exception; +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetType.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetType.java new file mode 100644 index 0000000000..1b0e093470 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/DatasetType.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Type of dataset based on the target action. + */ +public enum DatasetType { + @JsonProperty("UNKNOWN") + UNKNOWN, + + @JsonProperty("FLOW") + FLOW, + + @JsonProperty("MODEL") + MODEL, + + @JsonProperty("EXECUTABLE_PROMPT") + EXECUTABLE_PROMPT +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalDataPoint.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalDataPoint.java new file mode 100644 index 0000000000..9845f61dc3 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalDataPoint.java @@ -0,0 +1,221 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents a single data point for evaluation. + * + *

+ * A data point contains the input, output, and optional context and reference + * data that is used to evaluate the quality of an AI system's output. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalDataPoint { + + /** + * Unique identifier for this test case. + */ + @JsonProperty("testCaseId") + private String testCaseId; + + /** + * The input provided to the AI system. + */ + @JsonProperty("input") + private Object input; + + /** + * The output generated by the AI system. + */ + @JsonProperty("output") + private Object output; + + /** + * Error message if the AI system failed to generate output. + */ + @JsonProperty("error") + private String error; + + /** + * Additional context provided to the AI system (e.g., retrieved documents). + */ + @JsonProperty("context") + private List context; + + /** + * The expected/reference output for comparison. + */ + @JsonProperty("reference") + private Object reference; + + /** + * Custom fields for domain-specific evaluation. + */ + @JsonProperty("custom") + private Map custom; + + /** + * Trace IDs associated with this evaluation. + */ + @JsonProperty("traceIds") + private List traceIds; + + public EvalDataPoint() { + } + + private EvalDataPoint(Builder builder) { + this.testCaseId = builder.testCaseId; + this.input = builder.input; + this.output = builder.output; + this.error = builder.error; + this.context = builder.context; + this.reference = builder.reference; + this.custom = builder.custom; + this.traceIds = builder.traceIds; + } + + public static Builder builder() { + return new Builder(); + } + + public String getTestCaseId() { + return testCaseId; + } + + public void setTestCaseId(String testCaseId) { + this.testCaseId = testCaseId; + } + + public Object getInput() { + return input; + } + + public void setInput(Object input) { + this.input = input; + } + + public Object getOutput() { + return output; + } + + public void setOutput(Object output) { + this.output = output; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public List getContext() { + return context; + } + + public void setContext(List context) { + this.context = context; + } + + public Object getReference() { + return reference; + } + + public void setReference(Object reference) { + this.reference = reference; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + + public List getTraceIds() { + return traceIds; + } + + public void setTraceIds(List traceIds) { + this.traceIds = traceIds; + } + + public static class Builder { + private String testCaseId; + private Object input; + private Object output; + private String error; + private List context; + private Object reference; + private Map custom; + private List traceIds; + + public Builder testCaseId(String testCaseId) { + this.testCaseId = testCaseId; + return this; + } + + public Builder input(Object input) { + this.input = input; + return this; + } + + public Builder output(Object output) { + this.output = output; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder context(List context) { + this.context = context; + return this; + } + + public Builder reference(Object reference) { + this.reference = reference; + return this; + } + + public Builder custom(Map custom) { + this.custom = custom; + return this; + } + + public Builder traceIds(List traceIds) { + this.traceIds = traceIds; + return this; + } + + public EvalDataPoint build() { + return new EvalDataPoint(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalMetric.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalMetric.java new file mode 100644 index 0000000000..88c90567a4 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalMetric.java @@ -0,0 +1,214 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents a single metric score from an evaluator. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalMetric { + + /** + * Name of the evaluator that produced this metric. + */ + @JsonProperty("evaluator") + private String evaluator; + + /** + * Optional ID for multi-score evaluators. + */ + @JsonProperty("scoreId") + private String scoreId; + + /** + * The score value. + */ + @JsonProperty("score") + private Object score; + + /** + * The evaluation status. + */ + @JsonProperty("status") + private EvalStatus status; + + /** + * Reasoning/explanation for the score. + */ + @JsonProperty("rationale") + private String rationale; + + /** + * Error message if evaluation failed. + */ + @JsonProperty("error") + private String error; + + /** + * Trace ID associated with this evaluation. + */ + @JsonProperty("traceId") + private String traceId; + + /** + * Span ID within the trace. + */ + @JsonProperty("spanId") + private String spanId; + + public EvalMetric() { + } + + private EvalMetric(Builder builder) { + this.evaluator = builder.evaluator; + this.scoreId = builder.scoreId; + this.score = builder.score; + this.status = builder.status; + this.rationale = builder.rationale; + this.error = builder.error; + this.traceId = builder.traceId; + this.spanId = builder.spanId; + } + + public static Builder builder() { + return new Builder(); + } + + public String getEvaluator() { + return evaluator; + } + + public void setEvaluator(String evaluator) { + this.evaluator = evaluator; + } + + public String getScoreId() { + return scoreId; + } + + public void setScoreId(String scoreId) { + this.scoreId = scoreId; + } + + public Object getScore() { + return score; + } + + public void setScore(Object score) { + this.score = score; + } + + public EvalStatus getStatus() { + return status; + } + + public void setStatus(EvalStatus status) { + this.status = status; + } + + public String getRationale() { + return rationale; + } + + public void setRationale(String rationale) { + this.rationale = rationale; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public String getTraceId() { + return traceId; + } + + public void setTraceId(String traceId) { + this.traceId = traceId; + } + + public String getSpanId() { + return spanId; + } + + public void setSpanId(String spanId) { + this.spanId = spanId; + } + + public static class Builder { + private String evaluator; + private String scoreId; + private Object score; + private EvalStatus status; + private String rationale; + private String error; + private String traceId; + private String spanId; + + public Builder evaluator(String evaluator) { + this.evaluator = evaluator; + return this; + } + + public Builder scoreId(String scoreId) { + this.scoreId = scoreId; + return this; + } + + public Builder score(Object score) { + this.score = score; + return this; + } + + public Builder status(EvalStatus status) { + this.status = status; + return this; + } + + public Builder rationale(String rationale) { + this.rationale = rationale; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder traceId(String traceId) { + this.traceId = traceId; + return this; + } + + public Builder spanId(String spanId) { + this.spanId = spanId; + return this; + } + + public EvalMetric build() { + return new EvalMetric(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRequest.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRequest.java new file mode 100644 index 0000000000..ffebfd57db --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRequest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Request to run an evaluator on a dataset. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalRequest { + + /** + * The dataset to evaluate. + */ + @JsonProperty("dataset") + private List dataset; + + /** + * Unique identifier for this evaluation run. + */ + @JsonProperty("evalRunId") + private String evalRunId; + + /** + * Options for the evaluator. + */ + @JsonProperty("options") + private Object options; + + /** + * Number of data points to process in each batch. + */ + @JsonProperty("batchSize") + private Integer batchSize; + + public EvalRequest() { + } + + private EvalRequest(Builder builder) { + this.dataset = builder.dataset; + this.evalRunId = builder.evalRunId; + this.options = builder.options; + this.batchSize = builder.batchSize; + } + + public static Builder builder() { + return new Builder(); + } + + public List getDataset() { + return dataset; + } + + public void setDataset(List dataset) { + this.dataset = dataset; + } + + public String getEvalRunId() { + return evalRunId; + } + + public void setEvalRunId(String evalRunId) { + this.evalRunId = evalRunId; + } + + public Object getOptions() { + return options; + } + + public void setOptions(Object options) { + this.options = options; + } + + public Integer getBatchSize() { + return batchSize; + } + + public void setBatchSize(Integer batchSize) { + this.batchSize = batchSize; + } + + public static class Builder { + private List dataset; + private String evalRunId; + private Object options; + private Integer batchSize; + + public Builder dataset(List dataset) { + this.dataset = dataset; + return this; + } + + public Builder evalRunId(String evalRunId) { + this.evalRunId = evalRunId; + return this; + } + + public Builder options(Object options) { + this.options = options; + return this; + } + + public Builder batchSize(Integer batchSize) { + this.batchSize = batchSize; + return this; + } + + public EvalRequest build() { + return new EvalRequest(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalResponse.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalResponse.java new file mode 100644 index 0000000000..cc4d2f11e9 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalResponse.java @@ -0,0 +1,184 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Response from an evaluator for a single test case. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalResponse { + + /** + * Index of this sample in the batch (optional). + */ + @JsonProperty("sampleIndex") + private Integer sampleIndex; + + /** + * The test case ID that was evaluated. + */ + @JsonProperty("testCaseId") + private String testCaseId; + + /** + * The trace ID associated with this evaluation. + */ + @JsonProperty("traceId") + private String traceId; + + /** + * The span ID within the trace. + */ + @JsonProperty("spanId") + private String spanId; + + /** + * The evaluation score(s). Can be a single Score or a list of Scores for + * multi-metric evaluators. + */ + @JsonProperty("evaluation") + private Object evaluation; + + public EvalResponse() { + } + + private EvalResponse(Builder builder) { + this.sampleIndex = builder.sampleIndex; + this.testCaseId = builder.testCaseId; + this.traceId = builder.traceId; + this.spanId = builder.spanId; + this.evaluation = builder.evaluation; + } + + public static Builder builder() { + return new Builder(); + } + + public Integer getSampleIndex() { + return sampleIndex; + } + + public void setSampleIndex(Integer sampleIndex) { + this.sampleIndex = sampleIndex; + } + + public String getTestCaseId() { + return testCaseId; + } + + public void setTestCaseId(String testCaseId) { + this.testCaseId = testCaseId; + } + + public String getTraceId() { + return traceId; + } + + public void setTraceId(String traceId) { + this.traceId = traceId; + } + + public String getSpanId() { + return spanId; + } + + public void setSpanId(String spanId) { + this.spanId = spanId; + } + + public Object getEvaluation() { + return evaluation; + } + + /** + * Gets the evaluation as a single Score. + * + * @return the score, or null if the evaluation is a list + */ + public Score getEvaluationAsScore() { + if (evaluation instanceof Score) { + return (Score) evaluation; + } + return null; + } + + /** + * Gets the evaluation as a list of Scores. + * + * @return the list of scores, or null if the evaluation is a single score + */ + @SuppressWarnings("unchecked") + public List getEvaluationAsScoreList() { + if (evaluation instanceof List) { + return (List) evaluation; + } + return null; + } + + public void setEvaluation(Object evaluation) { + this.evaluation = evaluation; + } + + public static class Builder { + private Integer sampleIndex; + private String testCaseId; + private String traceId; + private String spanId; + private Object evaluation; + + public Builder sampleIndex(Integer sampleIndex) { + this.sampleIndex = sampleIndex; + return this; + } + + public Builder testCaseId(String testCaseId) { + this.testCaseId = testCaseId; + return this; + } + + public Builder traceId(String traceId) { + this.traceId = traceId; + return this; + } + + public Builder spanId(String spanId) { + this.spanId = spanId; + return this; + } + + public Builder evaluation(Score evaluation) { + this.evaluation = evaluation; + return this; + } + + public Builder evaluation(List evaluation) { + this.evaluation = evaluation; + return this; + } + + public EvalResponse build() { + return new EvalResponse(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalResult.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalResult.java new file mode 100644 index 0000000000..3fc093c913 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalResult.java @@ -0,0 +1,238 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A single evaluation result combining input data with metric scores. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalResult { + + /** + * The test case ID. + */ + @JsonProperty("testCaseId") + private String testCaseId; + + /** + * The input to the evaluated action. + */ + @JsonProperty("input") + private Object input; + + /** + * The output from the evaluated action. + */ + @JsonProperty("output") + private Object output; + + /** + * Error from the evaluated action. + */ + @JsonProperty("error") + private String error; + + /** + * Context used during evaluation. + */ + @JsonProperty("context") + private List context; + + /** + * Reference output for comparison. + */ + @JsonProperty("reference") + private Object reference; + + /** + * Custom fields. + */ + @JsonProperty("custom") + private Map custom; + + /** + * Trace IDs associated with this result. + */ + @JsonProperty("traceIds") + private List traceIds; + + /** + * Metrics from all evaluators. + */ + @JsonProperty("metrics") + private List metrics; + + public EvalResult() { + } + + private EvalResult(Builder builder) { + this.testCaseId = builder.testCaseId; + this.input = builder.input; + this.output = builder.output; + this.error = builder.error; + this.context = builder.context; + this.reference = builder.reference; + this.custom = builder.custom; + this.traceIds = builder.traceIds; + this.metrics = builder.metrics; + } + + public static Builder builder() { + return new Builder(); + } + + public String getTestCaseId() { + return testCaseId; + } + + public void setTestCaseId(String testCaseId) { + this.testCaseId = testCaseId; + } + + public Object getInput() { + return input; + } + + public void setInput(Object input) { + this.input = input; + } + + public Object getOutput() { + return output; + } + + public void setOutput(Object output) { + this.output = output; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public List getContext() { + return context; + } + + public void setContext(List context) { + this.context = context; + } + + public Object getReference() { + return reference; + } + + public void setReference(Object reference) { + this.reference = reference; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + + public List getTraceIds() { + return traceIds; + } + + public void setTraceIds(List traceIds) { + this.traceIds = traceIds; + } + + public List getMetrics() { + return metrics; + } + + public void setMetrics(List metrics) { + this.metrics = metrics; + } + + public static class Builder { + private String testCaseId; + private Object input; + private Object output; + private String error; + private List context; + private Object reference; + private Map custom; + private List traceIds; + private List metrics; + + public Builder testCaseId(String testCaseId) { + this.testCaseId = testCaseId; + return this; + } + + public Builder input(Object input) { + this.input = input; + return this; + } + + public Builder output(Object output) { + this.output = output; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder context(List context) { + this.context = context; + return this; + } + + public Builder reference(Object reference) { + this.reference = reference; + return this; + } + + public Builder custom(Map custom) { + this.custom = custom; + return this; + } + + public Builder traceIds(List traceIds) { + this.traceIds = traceIds; + return this; + } + + public Builder metrics(List metrics) { + this.metrics = metrics; + return this; + } + + public EvalResult build() { + return new EvalResult(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRun.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRun.java new file mode 100644 index 0000000000..4b00fe6a03 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRun.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents a complete evaluation run with results. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalRun { + + /** + * Key identifying this evaluation run. + */ + @JsonProperty("key") + private EvalRunKey key; + + /** + * Results for all test cases. + */ + @JsonProperty("results") + private List results; + + public EvalRun() { + } + + private EvalRun(Builder builder) { + this.key = builder.key; + this.results = builder.results; + } + + public static Builder builder() { + return new Builder(); + } + + public EvalRunKey getKey() { + return key; + } + + public void setKey(EvalRunKey key) { + this.key = key; + } + + public List getResults() { + return results; + } + + public void setResults(List results) { + this.results = results; + } + + public static class Builder { + private EvalRunKey key; + private List results; + + public Builder key(EvalRunKey key) { + this.key = key; + return this; + } + + public Builder results(List results) { + this.results = results; + return this; + } + + public EvalRun build() { + return new EvalRun(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRunKey.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRunKey.java new file mode 100644 index 0000000000..487e1a6669 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalRunKey.java @@ -0,0 +1,172 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Key that uniquely identifies an evaluation run. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvalRunKey { + + /** + * The action that was evaluated. + */ + @JsonProperty("actionRef") + private String actionRef; + + /** + * The dataset used for evaluation. + */ + @JsonProperty("datasetId") + private String datasetId; + + /** + * The version of the dataset used. + */ + @JsonProperty("datasetVersion") + private Integer datasetVersion; + + /** + * Unique identifier for this evaluation run. + */ + @JsonProperty("evalRunId") + private String evalRunId; + + /** + * When the evaluation was created. + */ + @JsonProperty("createdAt") + private String createdAt; + + /** + * Configuration used for the action. + */ + @JsonProperty("actionConfig") + private Object actionConfig; + + public EvalRunKey() { + } + + private EvalRunKey(Builder builder) { + this.actionRef = builder.actionRef; + this.datasetId = builder.datasetId; + this.datasetVersion = builder.datasetVersion; + this.evalRunId = builder.evalRunId; + this.createdAt = builder.createdAt; + this.actionConfig = builder.actionConfig; + } + + public static Builder builder() { + return new Builder(); + } + + public String getActionRef() { + return actionRef; + } + + public void setActionRef(String actionRef) { + this.actionRef = actionRef; + } + + public String getDatasetId() { + return datasetId; + } + + public void setDatasetId(String datasetId) { + this.datasetId = datasetId; + } + + public Integer getDatasetVersion() { + return datasetVersion; + } + + public void setDatasetVersion(Integer datasetVersion) { + this.datasetVersion = datasetVersion; + } + + public String getEvalRunId() { + return evalRunId; + } + + public void setEvalRunId(String evalRunId) { + this.evalRunId = evalRunId; + } + + public String getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(String createdAt) { + this.createdAt = createdAt; + } + + public Object getActionConfig() { + return actionConfig; + } + + public void setActionConfig(Object actionConfig) { + this.actionConfig = actionConfig; + } + + public static class Builder { + private String actionRef; + private String datasetId; + private Integer datasetVersion; + private String evalRunId; + private String createdAt; + private Object actionConfig; + + public Builder actionRef(String actionRef) { + this.actionRef = actionRef; + return this; + } + + public Builder datasetId(String datasetId) { + this.datasetId = datasetId; + return this; + } + + public Builder datasetVersion(Integer datasetVersion) { + this.datasetVersion = datasetVersion; + return this; + } + + public Builder evalRunId(String evalRunId) { + this.evalRunId = evalRunId; + return this; + } + + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder actionConfig(Object actionConfig) { + this.actionConfig = actionConfig; + return this; + } + + public EvalRunKey build() { + return new EvalRunKey(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalStatus.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalStatus.java new file mode 100644 index 0000000000..e285859262 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalStatus.java @@ -0,0 +1,35 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Status of an evaluation result. + */ +public enum EvalStatus { + @JsonProperty("UNKNOWN") + UNKNOWN, + + @JsonProperty("PASS") + PASS, + + @JsonProperty("FAIL") + FAIL +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalStore.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalStore.java new file mode 100644 index 0000000000..4a88edcd61 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvalStore.java @@ -0,0 +1,80 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +/** + * Interface for storing and retrieving evaluation runs. + */ +public interface EvalStore { + + /** + * Saves an evaluation run. + * + * @param evalRun + * the evaluation run to save + * @throws Exception + * if save fails + */ + void save(EvalRun evalRun) throws Exception; + + /** + * Loads an evaluation run by ID. + * + * @param evalRunId + * the evaluation run ID + * @return the evaluation run, or null if not found + * @throws Exception + * if load fails + */ + EvalRun load(String evalRunId) throws Exception; + + /** + * Lists all evaluation run keys. + * + * @return list of evaluation run keys + * @throws Exception + * if listing fails + */ + List list() throws Exception; + + /** + * Lists evaluation run keys with optional filtering. + * + * @param actionRef + * filter by action reference + * @param datasetId + * filter by dataset ID + * @return filtered list of evaluation run keys + * @throws Exception + * if listing fails + */ + List list(String actionRef, String datasetId) throws Exception; + + /** + * Deletes an evaluation run. + * + * @param evalRunId + * the evaluation run ID to delete + * @throws Exception + * if deletion fails + */ + void delete(String evalRunId) throws Exception; +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluationManager.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluationManager.java new file mode 100644 index 0000000000..8bdafc6a10 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluationManager.java @@ -0,0 +1,308 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.time.Instant; +import java.util.*; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.core.*; + +/** + * Manages the execution of evaluations. + * + *

+ * The EvaluationManager coordinates running evaluations by: + *

    + *
  • Loading datasets from the dataset store
  • + *
  • Running inference on the target action
  • + *
  • Executing evaluators on the results
  • + *
  • Storing evaluation results
  • + *
+ */ +public class EvaluationManager { + + private static final Logger logger = LoggerFactory.getLogger(EvaluationManager.class); + + private final Registry registry; + private final DatasetStore datasetStore; + private final EvalStore evalStore; + + /** + * Creates a new EvaluationManager. + * + * @param registry + * the Genkit registry + */ + public EvaluationManager(Registry registry) { + this(registry, LocalFileDatasetStore.getInstance(), LocalFileEvalStore.getInstance()); + } + + /** + * Creates a new EvaluationManager with custom stores. + * + * @param registry + * the Genkit registry + * @param datasetStore + * the dataset store + * @param evalStore + * the eval store + */ + public EvaluationManager(Registry registry, DatasetStore datasetStore, EvalStore evalStore) { + this.registry = registry; + this.datasetStore = datasetStore; + this.evalStore = evalStore; + } + + /** + * Runs a new evaluation. + * + * @param request + * the evaluation request + * @return the evaluation run key + * @throws Exception + * if evaluation fails + */ + public EvalRunKey runEvaluation(RunEvaluationRequest request) throws Exception { + String evalRunId = UUID.randomUUID().toString(); + String actionRef = request.getTargetAction(); + + // 1. Load or parse dataset + List dataset; + String datasetId = null; + Integer datasetVersion = null; + + if (request.getDataSource().getDatasetId() != null) { + datasetId = request.getDataSource().getDatasetId(); + dataset = datasetStore.getDataset(datasetId); + List metadataList = datasetStore.listDatasets(); + for (DatasetMetadata m : metadataList) { + if (m.getDatasetId().equals(datasetId)) { + datasetVersion = m.getVersion(); + break; + } + } + } else { + dataset = request.getDataSource().getData(); + } + + if (dataset == null || dataset.isEmpty()) { + throw new IllegalArgumentException("Dataset is empty"); + } + + // Ensure all samples have testCaseIds + for (int i = 0; i < dataset.size(); i++) { + DatasetSample sample = dataset.get(i); + if (sample.getTestCaseId() == null) { + sample.setTestCaseId("test_case_" + (i + 1)); + } + } + + // 2. Run inference on the target action + List evalDataset = runInference(actionRef, dataset, request.getOptions()); + + // 3. Get matching evaluator actions + List evaluatorNames = request.getEvaluators(); + if (evaluatorNames == null || evaluatorNames.isEmpty()) { + // Get all evaluators + evaluatorNames = getAllEvaluatorNames(); + } + + // 4. Run evaluation + Map> allScores = new HashMap<>(); + int batchSize = request.getOptions() != null && request.getOptions().getBatchSize() != null + ? request.getOptions().getBatchSize() + : 10; + + for (String evaluatorName : evaluatorNames) { + String evalKey = ActionType.EVALUATOR.keyFromName(evaluatorName); + Action evaluatorAction = registry.lookupAction(evalKey); + + if (evaluatorAction == null) { + logger.warn("Evaluator not found: {}", evaluatorName); + continue; + } + + try { + // Filter out data points with errors + List validDataPoints = new ArrayList<>(); + for (EvalDataPoint dp : evalDataset) { + if (dp.getError() == null) { + validDataPoints.add(dp); + } + } + + EvalRequest evalRequest = EvalRequest.builder().dataset(validDataPoints).evalRunId(evalRunId) + .batchSize(batchSize).build(); + + ActionContext ctx = new ActionContext(registry); + @SuppressWarnings("unchecked") + List responses = ((Action, ?>) evaluatorAction).run(ctx, + evalRequest); + + allScores.put(evaluatorName, responses); + } catch (Exception e) { + logger.error("Error running evaluator: {}", evaluatorName, e); + } + } + + // 5. Combine scores with dataset + List results = combineResults(evalDataset, allScores); + + // 6. Create and save eval run + EvalRunKey key = EvalRunKey.builder().evalRunId(evalRunId).actionRef(actionRef).datasetId(datasetId) + .datasetVersion(datasetVersion).createdAt(Instant.now().toString()) + .actionConfig(request.getOptions() != null ? request.getOptions().getActionConfig() : null).build(); + + EvalRun evalRun = EvalRun.builder().key(key).results(results).build(); + + evalStore.save(evalRun); + + logger.info("Completed evaluation run: {} with {} results", evalRunId, results.size()); + return key; + } + + /** + * Runs inference on the target action for all dataset samples. + */ + private List runInference(String actionRef, List dataset, + RunEvaluationRequest.EvaluationOptions options) { + + List evalDataset = new ArrayList<>(); + Action action = registry.lookupAction(actionRef); + + for (DatasetSample sample : dataset) { + EvalDataPoint.Builder dpBuilder = EvalDataPoint.builder().testCaseId(sample.getTestCaseId()) + .input(sample.getInput()).reference(sample.getReference()).traceIds(new ArrayList<>()); + + if (action != null) { + try { + ActionContext ctx = new ActionContext(registry); + Object input = sample.getInput(); + + @SuppressWarnings("unchecked") + Object output = ((Action) action).run(ctx, input); + dpBuilder.output(output); + } catch (Exception e) { + logger.error("Error running inference for test case: {}", sample.getTestCaseId(), e); + dpBuilder.error(e.getMessage()); + } + } else { + logger.warn("Action not found: {}. Using input as output.", actionRef); + dpBuilder.output(sample.getInput()); + } + + evalDataset.add(dpBuilder.build()); + } + + return evalDataset; + } + + /** + * Gets all registered evaluator names. + */ + private List getAllEvaluatorNames() { + List names = new ArrayList<>(); + for (Action action : registry.listActions()) { + if (action.getType() == ActionType.EVALUATOR) { + names.add(action.getName()); + } + } + return names; + } + + /** + * Combines evaluation results with scores from all evaluators. + */ + private List combineResults(List evalDataset, + Map> allScores) { + + // Create a map of testCaseId to EvalResult + Map resultBuilders = new LinkedHashMap<>(); + + for (EvalDataPoint dp : evalDataset) { + resultBuilders.put(dp.getTestCaseId(), + EvalResult.builder().testCaseId(dp.getTestCaseId()).input(dp.getInput()).output(dp.getOutput()) + .error(dp.getError()).context(dp.getContext()).reference(dp.getReference()) + .traceIds(dp.getTraceIds()).metrics(new ArrayList<>())); + } + + // Add scores from each evaluator + for (Map.Entry> entry : allScores.entrySet()) { + String evaluatorName = entry.getKey(); + + for (EvalResponse response : entry.getValue()) { + EvalResult.Builder builder = resultBuilders.get(response.getTestCaseId()); + if (builder == null) { + continue; + } + + Object evaluation = response.getEvaluation(); + if (evaluation instanceof Score) { + Score score = (Score) evaluation; + EvalMetric metric = scoreToMetric(evaluatorName, score, response); + // Need to build, modify, and rebuild since metrics is already set + EvalResult temp = builder.build(); + temp.getMetrics().add(metric); + } else if (evaluation instanceof List) { + @SuppressWarnings("unchecked") + List scores = (List) evaluation; + for (Score score : scores) { + EvalMetric metric = scoreToMetric(evaluatorName, score, response); + EvalResult temp = builder.build(); + temp.getMetrics().add(metric); + } + } + } + } + + List results = new ArrayList<>(); + for (EvalResult.Builder builder : resultBuilders.values()) { + results.add(builder.build()); + } + return results; + } + + private EvalMetric scoreToMetric(String evaluatorName, Score score, EvalResponse response) { + String rationale = null; + if (score.getDetails() != null) { + rationale = score.getDetails().getReasoning(); + } + + return EvalMetric.builder().evaluator(evaluatorName).scoreId(score.getId()).score(score.getScore()) + .status(score.getStatus()).rationale(rationale).error(score.getError()).traceId(response.getTraceId()) + .spanId(response.getSpanId()).build(); + } + + /** + * Gets the dataset store. + */ + public DatasetStore getDatasetStore() { + return datasetStore; + } + + /** + * Gets the eval store. + */ + public EvalStore getEvalStore() { + return evalStore; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/Evaluator.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/Evaluator.java new file mode 100644 index 0000000000..8b14cadadf --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/Evaluator.java @@ -0,0 +1,374 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.*; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.*; + +/** + * Evaluator represents an evaluation action that assesses the quality of AI + * outputs. + * + *

+ * Evaluators are a core primitive in Genkit that allow you to measure and track + * the quality of your AI applications. They can be used to: + *

    + *
  • Score outputs based on various criteria (accuracy, relevance, etc.)
  • + *
  • Compare outputs against reference data
  • + *
  • Run automated quality checks in CI/CD pipelines
  • + *
  • Monitor production quality over time
  • + *
+ * + * @param + * the type of evaluator-specific options + */ +public class Evaluator implements Action, Void> { + + private static final Logger logger = LoggerFactory.getLogger(Evaluator.class); + + public static final String METADATA_KEY_DISPLAY_NAME = "evaluatorDisplayName"; + public static final String METADATA_KEY_DEFINITION = "evaluatorDefinition"; + public static final String METADATA_KEY_IS_BILLED = "evaluatorIsBilled"; + + private final String name; + private final EvaluatorInfo info; + private final EvaluatorFn evaluatorFn; + private final Class optionsClass; + private final ActionDesc desc; + + private Evaluator(Builder builder) { + this.name = builder.name; + this.info = EvaluatorInfo.builder().displayName(builder.displayName).definition(builder.definition) + .isBilled(builder.isBilled).build(); + this.evaluatorFn = builder.evaluatorFn; + this.optionsClass = builder.optionsClass; + + // Build metadata + Map metadata = new HashMap<>(); + metadata.put("type", "evaluator"); + Map evaluatorMetadata = new HashMap<>(); + evaluatorMetadata.put(METADATA_KEY_DISPLAY_NAME, builder.displayName); + evaluatorMetadata.put(METADATA_KEY_DEFINITION, builder.definition); + evaluatorMetadata.put(METADATA_KEY_IS_BILLED, builder.isBilled); + metadata.put("evaluator", evaluatorMetadata); + + this.desc = ActionDesc.builder().type(ActionType.EVALUATOR).name(name).description(builder.definition) + .metadata(metadata).build(); + } + + /** + * Creates a new Evaluator builder. + * + * @param + * the options type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Defines a new evaluator and registers it with the registry. + * + * @param + * the options type + * @param registry + * the registry to register with + * @param name + * the evaluator name + * @param displayName + * the display name shown in the UI + * @param definition + * description of what the evaluator measures + * @param evaluatorFn + * the evaluation function + * @return the created evaluator + */ + public static Evaluator define(Registry registry, String name, String displayName, String definition, + EvaluatorFn evaluatorFn) { + return define(registry, name, displayName, definition, true, null, evaluatorFn); + } + + /** + * Defines a new evaluator with full options and registers it with the registry. + * + * @param + * the options type + * @param registry + * the registry to register with + * @param name + * the evaluator name + * @param displayName + * the display name shown in the UI + * @param definition + * description of what the evaluator measures + * @param isBilled + * whether using this evaluator incurs costs + * @param optionsClass + * the class for evaluator-specific options + * @param evaluatorFn + * the evaluation function + * @return the created evaluator + */ + public static Evaluator define(Registry registry, String name, String displayName, String definition, + boolean isBilled, Class optionsClass, EvaluatorFn evaluatorFn) { + + Evaluator evaluator = Evaluator.builder().name(name).displayName(displayName).definition(definition) + .isBilled(isBilled).optionsClass(optionsClass).evaluatorFn(evaluatorFn).build(); + + evaluator.register(registry); + return evaluator; + } + + @Override + public String getName() { + return name; + } + + @Override + public ActionType getType() { + return ActionType.EVALUATOR; + } + + @Override + public ActionDesc getDesc() { + return desc; + } + + /** + * Gets the evaluator info. + * + * @return the evaluator info + */ + public EvaluatorInfo getInfo() { + return info; + } + + @Override + public List run(ActionContext ctx, EvalRequest input) throws GenkitException { + return run(ctx, input, null); + } + + @Override + public List run(ActionContext ctx, EvalRequest input, Consumer streamCallback) + throws GenkitException { + List responses = new ArrayList<>(); + List dataset = input.getDataset(); + + if (dataset == null || dataset.isEmpty()) { + return responses; + } + + int batchSize = input.getBatchSize() != null ? input.getBatchSize() : 10; + + // Process in batches + List> batches = batchList(dataset, batchSize); + int sampleIndex = 0; + + for (List batch : batches) { + for (EvalDataPoint dataPoint : batch) { + try { + @SuppressWarnings("unchecked") + O options = input.getOptions() != null && optionsClass != null + ? JsonUtils.getObjectMapper().convertValue(input.getOptions(), optionsClass) + : null; + + EvalResponse response = evaluatorFn.evaluate(dataPoint, options); + if (response.getSampleIndex() == null) { + response.setSampleIndex(sampleIndex); + } + responses.add(response); + } catch (Exception e) { + logger.error("Error evaluating data point: {}", dataPoint.getTestCaseId(), e); + // Create an error response + Score errorScore = Score.builder().error(e.getMessage()).status(EvalStatus.UNKNOWN).build(); + responses.add(EvalResponse.builder().testCaseId(dataPoint.getTestCaseId()).sampleIndex(sampleIndex) + .evaluation(errorScore).build()); + } + sampleIndex++; + } + } + + return responses; + } + + @Override + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + EvalRequest request = JsonUtils.fromJsonNode(input, EvalRequest.class); + List result = run(ctx, request, null); + return JsonUtils.toJsonNode(result); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + JsonNode result = runJson(ctx, input, streamCallback); + return new ActionRunResult<>(result, null, null); + } + + @Override + public void register(Registry registry) { + String key = ActionType.EVALUATOR.keyFromName(name); + registry.registerAction(key, this); + logger.info("Registered evaluator: {}", key); + } + + @Override + public Map getInputSchema() { + // Define the input schema for EvalRequest + Map schema = new HashMap<>(); + schema.put("type", "object"); + + Map properties = new HashMap<>(); + + // dataset property + Map datasetProp = new HashMap<>(); + datasetProp.put("type", "array"); + Map dataPointSchema = new HashMap<>(); + dataPointSchema.put("type", "object"); + datasetProp.put("items", dataPointSchema); + properties.put("dataset", datasetProp); + + // evalRunId property + Map evalRunIdProp = new HashMap<>(); + evalRunIdProp.put("type", "string"); + properties.put("evalRunId", evalRunIdProp); + + // batchSize property + Map batchSizeProp = new HashMap<>(); + batchSizeProp.put("type", "integer"); + properties.put("batchSize", batchSizeProp); + + schema.put("properties", properties); + schema.put("required", Arrays.asList("dataset")); + + return schema; + } + + @Override + public Map getOutputSchema() { + // Define the output schema for List + Map schema = new HashMap<>(); + schema.put("type", "array"); + + Map itemSchema = new HashMap<>(); + itemSchema.put("type", "object"); + + Map itemProps = new HashMap<>(); + + // testCaseId + Map testCaseIdProp = new HashMap<>(); + testCaseIdProp.put("type", "string"); + itemProps.put("testCaseId", testCaseIdProp); + + // evaluation (Score) + Map evaluationProp = new HashMap<>(); + evaluationProp.put("type", "object"); + itemProps.put("evaluation", evaluationProp); + + itemSchema.put("properties", itemProps); + schema.put("items", itemSchema); + + return schema; + } + + @Override + public Map getMetadata() { + return desc.getMetadata(); + } + + /** + * Splits a list into batches. + */ + private static List> batchList(List list, int batchSize) { + List> batches = new ArrayList<>(); + for (int i = 0; i < list.size(); i += batchSize) { + batches.add(list.subList(i, Math.min(i + batchSize, list.size()))); + } + return batches; + } + + /** + * Builder for creating Evaluator instances. + * + * @param + * the options type + */ + public static class Builder { + private String name; + private String displayName; + private String definition; + private boolean isBilled = true; + private Class optionsClass; + private EvaluatorFn evaluatorFn; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder displayName(String displayName) { + this.displayName = displayName; + return this; + } + + public Builder definition(String definition) { + this.definition = definition; + return this; + } + + public Builder isBilled(boolean isBilled) { + this.isBilled = isBilled; + return this; + } + + public Builder optionsClass(Class optionsClass) { + this.optionsClass = optionsClass; + return this; + } + + public Builder evaluatorFn(EvaluatorFn evaluatorFn) { + this.evaluatorFn = evaluatorFn; + return this; + } + + public Evaluator build() { + if (name == null || name.isEmpty()) { + throw new IllegalArgumentException("Evaluator name is required"); + } + if (displayName == null) { + displayName = name; + } + if (definition == null) { + definition = ""; + } + if (evaluatorFn == null) { + throw new IllegalArgumentException("Evaluator function is required"); + } + return new Evaluator<>(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluatorFn.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluatorFn.java new file mode 100644 index 0000000000..0c2fbe0961 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluatorFn.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +/** + * Functional interface for evaluator functions. + * + *

+ * An evaluator function takes a data point and optional options, and returns an + * evaluation response containing scores. + * + * @param + * the type of evaluator-specific options + */ +@FunctionalInterface +public interface EvaluatorFn { + + /** + * Evaluates a single data point. + * + * @param dataPoint + * the data point to evaluate + * @param options + * optional evaluator-specific options + * @return the evaluation response + * @throws Exception + * if evaluation fails + */ + EvalResponse evaluate(EvalDataPoint dataPoint, O options) throws Exception; +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluatorInfo.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluatorInfo.java new file mode 100644 index 0000000000..a96b9eb8ef --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/EvaluatorInfo.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Information about an evaluator including display metadata and metrics. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class EvaluatorInfo { + + /** + * Display name for the evaluator. + */ + @JsonProperty("displayName") + private String displayName; + + /** + * Description of what the evaluator measures. + */ + @JsonProperty("definition") + private String definition; + + /** + * Whether using this evaluator incurs costs (e.g., LLM API calls). + */ + @JsonProperty("isBilled") + private Boolean isBilled; + + public EvaluatorInfo() { + } + + private EvaluatorInfo(Builder builder) { + this.displayName = builder.displayName; + this.definition = builder.definition; + this.isBilled = builder.isBilled; + } + + public static Builder builder() { + return new Builder(); + } + + public String getDisplayName() { + return displayName; + } + + public void setDisplayName(String displayName) { + this.displayName = displayName; + } + + public String getDefinition() { + return definition; + } + + public void setDefinition(String definition) { + this.definition = definition; + } + + public Boolean getIsBilled() { + return isBilled; + } + + public void setIsBilled(Boolean isBilled) { + this.isBilled = isBilled; + } + + public static class Builder { + private String displayName; + private String definition; + private Boolean isBilled; + + public Builder displayName(String displayName) { + this.displayName = displayName; + return this; + } + + public Builder definition(String definition) { + this.definition = definition; + return this; + } + + public Builder isBilled(Boolean isBilled) { + this.isBilled = isBilled; + return this; + } + + public EvaluatorInfo build() { + return new EvaluatorInfo(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/LocalFileDatasetStore.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/LocalFileDatasetStore.java new file mode 100644 index 0000000000..eacd124e17 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/LocalFileDatasetStore.java @@ -0,0 +1,258 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.genkit.core.JsonUtils; + +/** + * File-based implementation of DatasetStore. + * + *

+ * Stores datasets in the .genkit/datasets directory with: + *

    + *
  • index.json - metadata for all datasets
  • + *
  • {datasetId}.json - actual dataset data
  • + *
+ */ +public class LocalFileDatasetStore implements DatasetStore { + + private static final Logger logger = LoggerFactory.getLogger(LocalFileDatasetStore.class); + + private static final String GENKIT_DIR = ".genkit"; + private static final String DATASETS_DIR = "datasets"; + private static final String INDEX_FILE = "index.json"; + + private static LocalFileDatasetStore instance; + + private final Path storeRoot; + private final Path indexFile; + private final Map indexCache; + + /** + * Gets the singleton instance of the dataset store. + * + * @return the dataset store instance + */ + public static synchronized LocalFileDatasetStore getInstance() { + if (instance == null) { + instance = new LocalFileDatasetStore(); + } + return instance; + } + + /** + * Creates a new LocalFileDatasetStore using the default location. + */ + public LocalFileDatasetStore() { + this(Paths.get(System.getProperty("user.dir"), GENKIT_DIR, DATASETS_DIR)); + } + + /** + * Creates a new LocalFileDatasetStore with a custom root path. + * + * @param storeRoot + * the root directory for storing datasets + */ + public LocalFileDatasetStore(Path storeRoot) { + this.storeRoot = storeRoot; + this.indexFile = storeRoot.resolve(INDEX_FILE); + this.indexCache = new ConcurrentHashMap<>(); + initializeStore(); + } + + private void initializeStore() { + try { + Files.createDirectories(storeRoot); + if (!Files.exists(indexFile)) { + saveIndex(new HashMap<>()); + } + loadIndex(); + } catch (IOException e) { + logger.error("Failed to initialize dataset store", e); + throw new RuntimeException("Failed to initialize dataset store", e); + } + } + + private void loadIndex() throws IOException { + if (Files.exists(indexFile)) { + String content = Files.readString(indexFile); + Map index = JsonUtils.getObjectMapper().readValue(content, + new TypeReference>() { + }); + indexCache.clear(); + indexCache.putAll(index); + } + } + + private void saveIndex(Map index) throws IOException { + String json = JsonUtils.toJson(index); + Files.writeString(indexFile, json); + } + + private Path getDatasetFile(String datasetId) { + return storeRoot.resolve(datasetId + ".json"); + } + + private String generateDatasetId() { + return "dataset_" + UUID.randomUUID().toString().replace("-", "").substring(0, 12); + } + + @Override + public DatasetMetadata createDataset(CreateDatasetRequest request) throws Exception { + String datasetId = request.getDatasetId(); + if (datasetId == null || datasetId.isEmpty()) { + datasetId = generateDatasetId(); + } + + if (indexCache.containsKey(datasetId)) { + throw new IllegalArgumentException("Dataset already exists: " + datasetId); + } + + List data = request.getData(); + if (data == null) { + data = new ArrayList<>(); + } + + // Ensure all samples have testCaseIds + for (int i = 0; i < data.size(); i++) { + DatasetSample sample = data.get(i); + if (sample.getTestCaseId() == null || sample.getTestCaseId().isEmpty()) { + sample.setTestCaseId("test_case_" + (i + 1)); + } + } + + String now = Instant.now().toString(); + DatasetMetadata metadata = DatasetMetadata.builder().datasetId(datasetId).size(data.size()) + .schema(request.getSchema()) + .datasetType(request.getDatasetType() != null ? request.getDatasetType() : DatasetType.UNKNOWN) + .targetAction(request.getTargetAction()) + .metricRefs(request.getMetricRefs() != null ? request.getMetricRefs() : new ArrayList<>()).version(1) + .createTime(now).updateTime(now).build(); + + // Save the dataset data + String dataJson = JsonUtils.toJson(data); + Files.writeString(getDatasetFile(datasetId), dataJson); + + // Update the index + indexCache.put(datasetId, metadata); + saveIndex(indexCache); + + logger.info("Created dataset: {} with {} samples", datasetId, data.size()); + return metadata; + } + + @Override + public DatasetMetadata updateDataset(UpdateDatasetRequest request) throws Exception { + String datasetId = request.getDatasetId(); + if (datasetId == null || datasetId.isEmpty()) { + throw new IllegalArgumentException("Dataset ID is required"); + } + + DatasetMetadata existing = indexCache.get(datasetId); + if (existing == null) { + throw new IllegalArgumentException("Dataset not found: " + datasetId); + } + + List data = request.getData(); + int size = existing.getSize(); + + if (data != null) { + // Ensure all samples have testCaseIds + for (int i = 0; i < data.size(); i++) { + DatasetSample sample = data.get(i); + if (sample.getTestCaseId() == null || sample.getTestCaseId().isEmpty()) { + sample.setTestCaseId("test_case_" + (i + 1)); + } + } + size = data.size(); + + // Save the updated dataset data + String dataJson = JsonUtils.toJson(data); + Files.writeString(getDatasetFile(datasetId), dataJson); + } + + String now = Instant.now().toString(); + DatasetMetadata updated = DatasetMetadata.builder().datasetId(datasetId).size(size) + .schema(request.getSchema() != null ? request.getSchema() : existing.getSchema()) + .datasetType(existing.getDatasetType()) + .targetAction( + request.getTargetAction() != null ? request.getTargetAction() : existing.getTargetAction()) + .metricRefs(request.getMetricRefs() != null ? request.getMetricRefs() : existing.getMetricRefs()) + .version(existing.getVersion() + 1).createTime(existing.getCreateTime()).updateTime(now).build(); + + // Update the index + indexCache.put(datasetId, updated); + saveIndex(indexCache); + + logger.info("Updated dataset: {} (version {})", datasetId, updated.getVersion()); + return updated; + } + + @Override + public List getDataset(String datasetId) throws Exception { + if (!indexCache.containsKey(datasetId)) { + throw new IllegalArgumentException("Dataset not found: " + datasetId); + } + + Path datasetFile = getDatasetFile(datasetId); + if (!Files.exists(datasetFile)) { + throw new IllegalArgumentException("Dataset data file not found: " + datasetId); + } + + String content = Files.readString(datasetFile); + return JsonUtils.getObjectMapper().readValue(content, new TypeReference>() { + }); + } + + @Override + public List listDatasets() throws Exception { + // Reload index to get latest data + loadIndex(); + return new ArrayList<>(indexCache.values()); + } + + @Override + public void deleteDataset(String datasetId) throws Exception { + if (!indexCache.containsKey(datasetId)) { + throw new IllegalArgumentException("Dataset not found: " + datasetId); + } + + // Delete the data file + Path datasetFile = getDatasetFile(datasetId); + Files.deleteIfExists(datasetFile); + + // Update the index + indexCache.remove(datasetId); + saveIndex(indexCache); + + logger.info("Deleted dataset: {}", datasetId); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/LocalFileEvalStore.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/LocalFileEvalStore.java new file mode 100644 index 0000000000..666f1c97e8 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/LocalFileEvalStore.java @@ -0,0 +1,194 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.genkit.core.JsonUtils; + +/** + * File-based implementation of EvalStore. + * + *

+ * Stores evaluation runs in the .genkit/evals directory with: + *

    + *
  • index.json - metadata for all eval runs
  • + *
  • {evalRunId}.json - actual eval run data
  • + *
+ */ +public class LocalFileEvalStore implements EvalStore { + + private static final Logger logger = LoggerFactory.getLogger(LocalFileEvalStore.class); + + private static final String GENKIT_DIR = ".genkit"; + private static final String EVALS_DIR = "evals"; + private static final String INDEX_FILE = "index.json"; + + private static LocalFileEvalStore instance; + + private final Path storeRoot; + private final Path indexFile; + private final Map indexCache; + + /** + * Gets the singleton instance of the eval store. + * + * @return the eval store instance + */ + public static synchronized LocalFileEvalStore getInstance() { + if (instance == null) { + instance = new LocalFileEvalStore(); + } + return instance; + } + + /** + * Creates a new LocalFileEvalStore using the default location. + */ + public LocalFileEvalStore() { + this(Paths.get(System.getProperty("user.dir"), GENKIT_DIR, EVALS_DIR)); + } + + /** + * Creates a new LocalFileEvalStore with a custom root path. + * + * @param storeRoot + * the root directory for storing eval runs + */ + public LocalFileEvalStore(Path storeRoot) { + this.storeRoot = storeRoot; + this.indexFile = storeRoot.resolve(INDEX_FILE); + this.indexCache = new ConcurrentHashMap<>(); + initializeStore(); + } + + private void initializeStore() { + try { + Files.createDirectories(storeRoot); + if (!Files.exists(indexFile)) { + saveIndex(new HashMap<>()); + } + loadIndex(); + } catch (IOException e) { + logger.error("Failed to initialize eval store", e); + throw new RuntimeException("Failed to initialize eval store", e); + } + } + + private void loadIndex() throws IOException { + if (Files.exists(indexFile)) { + String content = Files.readString(indexFile); + Map index = JsonUtils.getObjectMapper().readValue(content, + new TypeReference>() { + }); + indexCache.clear(); + indexCache.putAll(index); + } + } + + private void saveIndex(Map index) throws IOException { + String json = JsonUtils.toJson(index); + Files.writeString(indexFile, json); + } + + private Path getEvalRunFile(String evalRunId) { + return storeRoot.resolve(evalRunId + ".json"); + } + + @Override + public void save(EvalRun evalRun) throws Exception { + if (evalRun.getKey() == null || evalRun.getKey().getEvalRunId() == null) { + throw new IllegalArgumentException("EvalRun must have a key with evalRunId"); + } + + String evalRunId = evalRun.getKey().getEvalRunId(); + + // Save the eval run data + String dataJson = JsonUtils.toJson(evalRun); + Files.writeString(getEvalRunFile(evalRunId), dataJson); + + // Update the index + indexCache.put(evalRunId, evalRun.getKey()); + saveIndex(indexCache); + + logger.info("Saved eval run: {}", evalRunId); + } + + @Override + public EvalRun load(String evalRunId) throws Exception { + Path evalRunFile = getEvalRunFile(evalRunId); + if (!Files.exists(evalRunFile)) { + return null; + } + + String content = Files.readString(evalRunFile); + return JsonUtils.fromJson(content, EvalRun.class); + } + + @Override + public List list() throws Exception { + return list(null, null); + } + + @Override + public List list(String actionRef, String datasetId) throws Exception { + // Reload index to get latest data + loadIndex(); + + return indexCache.values().stream().filter(key -> { + if (actionRef != null && !actionRef.equals(key.getActionRef())) { + return false; + } + if (datasetId != null && !datasetId.equals(key.getDatasetId())) { + return false; + } + return true; + }).sorted((a, b) -> { + // Sort by createdAt descending + if (a.getCreatedAt() == null) + return 1; + if (b.getCreatedAt() == null) + return -1; + return b.getCreatedAt().compareTo(a.getCreatedAt()); + }).collect(Collectors.toList()); + } + + @Override + public void delete(String evalRunId) throws Exception { + // Delete the data file + Path evalRunFile = getEvalRunFile(evalRunId); + Files.deleteIfExists(evalRunFile); + + // Update the index + indexCache.remove(evalRunId); + saveIndex(indexCache); + + logger.info("Deleted eval run: {}", evalRunId); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/RunEvaluationRequest.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/RunEvaluationRequest.java new file mode 100644 index 0000000000..66edb08753 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/RunEvaluationRequest.java @@ -0,0 +1,205 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Request to run a new evaluation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class RunEvaluationRequest { + + /** + * The data source for evaluation. + */ + @JsonProperty("dataSource") + private DataSource dataSource; + + /** + * The action to evaluate (e.g., "/flow/myFlow"). + */ + @JsonProperty("targetAction") + private String targetAction; + + /** + * The evaluators to run. + */ + @JsonProperty("evaluators") + private List evaluators; + + /** + * Options for the evaluation. + */ + @JsonProperty("options") + private EvaluationOptions options; + + public RunEvaluationRequest() { + } + + private RunEvaluationRequest(Builder builder) { + this.dataSource = builder.dataSource; + this.targetAction = builder.targetAction; + this.evaluators = builder.evaluators; + this.options = builder.options; + } + + public static Builder builder() { + return new Builder(); + } + + public DataSource getDataSource() { + return dataSource; + } + + public void setDataSource(DataSource dataSource) { + this.dataSource = dataSource; + } + + public String getTargetAction() { + return targetAction; + } + + public void setTargetAction(String targetAction) { + this.targetAction = targetAction; + } + + public List getEvaluators() { + return evaluators; + } + + public void setEvaluators(List evaluators) { + this.evaluators = evaluators; + } + + public EvaluationOptions getOptions() { + return options; + } + + public void setOptions(EvaluationOptions options) { + this.options = options; + } + + /** + * Data source for evaluation - either a dataset ID or inline data. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class DataSource { + @JsonProperty("datasetId") + private String datasetId; + + @JsonProperty("data") + private List data; + + public DataSource() { + } + + public String getDatasetId() { + return datasetId; + } + + public void setDatasetId(String datasetId) { + this.datasetId = datasetId; + } + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + } + + /** + * Options for evaluation. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class EvaluationOptions { + @JsonProperty("context") + private String context; + + @JsonProperty("actionConfig") + private Object actionConfig; + + @JsonProperty("batchSize") + private Integer batchSize; + + public EvaluationOptions() { + } + + public String getContext() { + return context; + } + + public void setContext(String context) { + this.context = context; + } + + public Object getActionConfig() { + return actionConfig; + } + + public void setActionConfig(Object actionConfig) { + this.actionConfig = actionConfig; + } + + public Integer getBatchSize() { + return batchSize; + } + + public void setBatchSize(Integer batchSize) { + this.batchSize = batchSize; + } + } + + public static class Builder { + private DataSource dataSource; + private String targetAction; + private List evaluators; + private EvaluationOptions options; + + public Builder dataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + public Builder targetAction(String targetAction) { + this.targetAction = targetAction; + return this; + } + + public Builder evaluators(List evaluators) { + this.evaluators = evaluators; + return this; + } + + public Builder options(EvaluationOptions options) { + this.options = options; + return this; + } + + public RunEvaluationRequest build() { + return new RunEvaluationRequest(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/Score.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/Score.java new file mode 100644 index 0000000000..c7c9639dbd --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/Score.java @@ -0,0 +1,191 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Score represents the result of an evaluation. + * + *

+ * A score can contain a numeric value, a string value, or a boolean value, + * along with an optional status and detailed information about the evaluation. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Score { + + /** + * Optional identifier to differentiate scores in multi-score evaluations. + */ + @JsonProperty("id") + private String id; + + /** + * The numeric score value. Can be null if using string or boolean score. + */ + @JsonProperty("score") + private Object score; + + /** + * The status of the evaluation (PASS, FAIL, UNKNOWN). + */ + @JsonProperty("status") + private EvalStatus status; + + /** + * Error message if the evaluation failed. + */ + @JsonProperty("error") + private String error; + + /** + * Additional details about the evaluation including reasoning. + */ + @JsonProperty("details") + private ScoreDetails details; + + public Score() { + } + + private Score(Builder builder) { + this.id = builder.id; + this.score = builder.score; + this.status = builder.status; + this.error = builder.error; + this.details = builder.details; + } + + public static Builder builder() { + return new Builder(); + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public Object getScore() { + return score; + } + + public void setScore(Object score) { + this.score = score; + } + + public Double getScoreAsDouble() { + if (score instanceof Number) { + return ((Number) score).doubleValue(); + } + return null; + } + + public String getScoreAsString() { + if (score instanceof String) { + return (String) score; + } + return score != null ? score.toString() : null; + } + + public Boolean getScoreAsBoolean() { + if (score instanceof Boolean) { + return (Boolean) score; + } + return null; + } + + public EvalStatus getStatus() { + return status; + } + + public void setStatus(EvalStatus status) { + this.status = status; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public ScoreDetails getDetails() { + return details; + } + + public void setDetails(ScoreDetails details) { + this.details = details; + } + + public static class Builder { + private String id; + private Object score; + private EvalStatus status; + private String error; + private ScoreDetails details; + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder score(double score) { + this.score = score; + return this; + } + + public Builder score(String score) { + this.score = score; + return this; + } + + public Builder score(boolean score) { + this.score = score; + return this; + } + + public Builder status(EvalStatus status) { + this.status = status; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder details(ScoreDetails details) { + this.details = details; + return this; + } + + public Builder reasoning(String reasoning) { + this.details = ScoreDetails.builder().reasoning(reasoning).build(); + return this; + } + + public Score build() { + return new Score(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/ScoreDetails.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/ScoreDetails.java new file mode 100644 index 0000000000..2877d2332b --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/ScoreDetails.java @@ -0,0 +1,91 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.HashMap; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonAnyGetter; +import com.fasterxml.jackson.annotation.JsonAnySetter; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Details about an evaluation score, including reasoning. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ScoreDetails { + + @JsonProperty("reasoning") + private String reasoning; + + /** + * Additional properties that may be included in the details. + */ + private Map additionalProperties = new HashMap<>(); + + public ScoreDetails() { + } + + private ScoreDetails(Builder builder) { + this.reasoning = builder.reasoning; + this.additionalProperties = builder.additionalProperties; + } + + public static Builder builder() { + return new Builder(); + } + + public String getReasoning() { + return reasoning; + } + + public void setReasoning(String reasoning) { + this.reasoning = reasoning; + } + + @JsonAnyGetter + public Map getAdditionalProperties() { + return additionalProperties; + } + + @JsonAnySetter + public void setAdditionalProperty(String name, Object value) { + this.additionalProperties.put(name, value); + } + + public static class Builder { + private String reasoning; + private Map additionalProperties = new HashMap<>(); + + public Builder reasoning(String reasoning) { + this.reasoning = reasoning; + return this; + } + + public Builder additionalProperty(String name, Object value) { + this.additionalProperties.put(name, value); + return this; + } + + public ScoreDetails build() { + return new ScoreDetails(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/evaluation/UpdateDatasetRequest.java b/java/ai/src/main/java/com/google/genkit/ai/evaluation/UpdateDatasetRequest.java new file mode 100644 index 0000000000..dddc96182c --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/evaluation/UpdateDatasetRequest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.evaluation; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; + +/** + * Request to update an existing dataset. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class UpdateDatasetRequest { + + /** + * The ID of the dataset to update. + */ + @JsonProperty("datasetId") + private String datasetId; + + /** + * New dataset samples (replaces existing data). + */ + @JsonProperty("data") + private List data; + + /** + * New schema for the dataset. + */ + @JsonProperty("schema") + private JsonNode schema; + + /** + * New metric references. + */ + @JsonProperty("metricRefs") + private List metricRefs; + + /** + * New target action. + */ + @JsonProperty("targetAction") + private String targetAction; + + public UpdateDatasetRequest() { + } + + private UpdateDatasetRequest(Builder builder) { + this.datasetId = builder.datasetId; + this.data = builder.data; + this.schema = builder.schema; + this.metricRefs = builder.metricRefs; + this.targetAction = builder.targetAction; + } + + public static Builder builder() { + return new Builder(); + } + + public String getDatasetId() { + return datasetId; + } + + public void setDatasetId(String datasetId) { + this.datasetId = datasetId; + } + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + + public JsonNode getSchema() { + return schema; + } + + public void setSchema(JsonNode schema) { + this.schema = schema; + } + + public List getMetricRefs() { + return metricRefs; + } + + public void setMetricRefs(List metricRefs) { + this.metricRefs = metricRefs; + } + + public String getTargetAction() { + return targetAction; + } + + public void setTargetAction(String targetAction) { + this.targetAction = targetAction; + } + + public static class Builder { + private String datasetId; + private List data; + private JsonNode schema; + private List metricRefs; + private String targetAction; + + public Builder datasetId(String datasetId) { + this.datasetId = datasetId; + return this; + } + + public Builder data(List data) { + this.data = data; + return this; + } + + public Builder schema(JsonNode schema) { + this.schema = schema; + return this; + } + + public Builder metricRefs(List metricRefs) { + this.metricRefs = metricRefs; + return this; + } + + public Builder targetAction(String targetAction) { + this.targetAction = targetAction; + return this; + } + + public UpdateDatasetRequest build() { + return new UpdateDatasetRequest(this); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/Chat.java b/java/ai/src/main/java/com/google/genkit/ai/session/Chat.java new file mode 100644 index 0000000000..9ee57a776d --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/Chat.java @@ -0,0 +1,1019 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.google.genkit.ai.*; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; + +/** + * Chat represents a conversation within a session thread. + * + *

+ * Chat provides a simple interface for multi-turn conversations with automatic + * history management. Messages are persisted to the session store after each + * interaction. + * + *

+ * Example usage: + * + *

{@code
+ * // Simple chat
+ * Chat chat = session.chat();
+ * ModelResponse response = chat.send("Hello!");
+ * 
+ * // Chat with system prompt
+ * Chat chat = session
+ * 		.chat(ChatOptions.builder().model("openai/gpt-4o").system("You are a helpful assistant.").build());
+ * 
+ * // Multi-turn conversation
+ * chat.send("What is the capital of France?");
+ * chat.send("And what about Germany?"); // Context is preserved
+ * }
+ * + * @param + * the type of the session state + */ +public class Chat { + + private static final String PREAMBLE_KEY = "preamble"; + + private final Session session; + private final String threadName; + private final ChatOptions originalOptions; + private final Registry registry; + private final Map effectiveAgentRegistry; + private List history; + + /** Pending interrupt requests from the last send. */ + private List pendingInterrupts; + + /** Current agent context (mutable for handoffs). */ + private String currentAgentName; + private String currentSystem; + private String currentModel; + private List> currentTools; + + /** + * Creates a new Chat instance. + * + * @param session + * the parent session + * @param threadName + * the thread name + * @param options + * the chat options + * @param registry + * the Genkit registry + * @param sessionAgentRegistry + * the agent registry from the session (may be null) + */ + Chat(Session session, String threadName, ChatOptions options, Registry registry, + Map sessionAgentRegistry) { + this.session = session; + this.threadName = threadName; + this.originalOptions = options; + this.registry = registry; + // Use agent registry from options if provided, otherwise fall back to session's + // registry + this.effectiveAgentRegistry = options.getAgentRegistry() != null + ? options.getAgentRegistry() + : sessionAgentRegistry; + this.history = new ArrayList<>(session.getMessages(threadName)); + this.pendingInterrupts = new ArrayList<>(); + + // Initialize current context from options + this.currentAgentName = null; + this.currentSystem = options.getSystem(); + this.currentModel = options.getModel(); + this.currentTools = options.getTools(); + } + + /** + * Sends a message and gets a response. + * + *

+ * This method: + *

    + *
  1. Adds the user message to history
  2. + *
  3. Builds a request with all conversation history
  4. + *
  5. Sends to the model and gets a response
  6. + *
  7. Adds the model response to history
  8. + *
  9. Persists the updated history
  10. + *
+ * + * @param text + * the user message + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse send(String text) throws GenkitException { + return send(Message.user(text)); + } + + /** + * Sends a message and gets a response. + * + * @param message + * the message to send + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse send(Message message) throws GenkitException { + return send(message, null); + } + + /** + * Sends a message with send options and gets a response. + * + * @param text + * the user message + * @param sendOptions + * additional options for this send + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse send(String text, SendOptions sendOptions) throws GenkitException { + return send(Message.user(text), sendOptions); + } + + /** + * Sends a message with send options and gets a response. + * + * @param message + * the message to send + * @param sendOptions + * additional options for this send + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse send(Message message, SendOptions sendOptions) throws GenkitException { + // Clear any pending interrupts from previous send + pendingInterrupts.clear(); + + // Check if we're resuming from an interrupt + ResumeOptions resumeOptions = (sendOptions != null) ? sendOptions.getResumeOptions() : null; + if (resumeOptions != null) { + return resumeFromInterrupt(resumeOptions, sendOptions); + } + + // Add user message to history + history.add(message); + + return executeGenerationLoop(sendOptions); + } + + /** + * Resumes generation after an interrupt. + * + * @param resumeOptions + * the resume options containing tool responses + * @param sendOptions + * additional send options + * @return the model response + * @throws GenkitException + * if generation fails + */ + private ModelResponse resumeFromInterrupt(ResumeOptions resumeOptions, SendOptions sendOptions) + throws GenkitException { + // Add tool responses to history + if (resumeOptions.getRespond() != null && !resumeOptions.getRespond().isEmpty()) { + List responseParts = new ArrayList<>(); + for (ToolResponse response : resumeOptions.getRespond()) { + Part part = new Part(); + part.setToolResponse(response); + responseParts.add(part); + } + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(responseParts); + history.add(toolResponseMessage); + } + + // Handle restart requests by re-executing those tools + if (resumeOptions.getRestart() != null && !resumeOptions.getRestart().isEmpty()) { + ActionContext ctx = new ActionContext(registry); + List restartParts = executeToolsWithInterruptHandling(ctx, resumeOptions.getRestart(), sendOptions); + + // Check if any restarts also triggered interrupts + if (!pendingInterrupts.isEmpty()) { + persistHistory(); + return createInterruptResponse(); + } + + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(restartParts); + history.add(toolResponseMessage); + } + + return executeGenerationLoop(sendOptions); + } + + /** + * Executes the main generation loop with tool handling. + */ + private ModelResponse executeGenerationLoop(SendOptions sendOptions) throws GenkitException { + // Build the request + ModelRequest request = buildRequest(sendOptions); + + // Get the model + String modelName = resolveModelName(sendOptions); + if (modelName == null) { + throw new GenkitException("No model specified. Set model in ChatOptions or SendOptions."); + } + + Model model = getModel(modelName); + ActionContext ctx = new ActionContext(registry); + + // Handle tool execution loop with session context + int maxTurns = resolveMaxTurns(sendOptions); + int turn = 0; + + while (turn < maxTurns) { + // Make request effectively final for lambda + final ModelRequest finalRequest = request; + ModelResponse response; + try { + response = SessionContext.runWithSession(session, () -> model.run(ctx, finalRequest)); + } catch (GenkitException e) { + throw e; + } catch (Exception e) { + throw new GenkitException("Error during model execution", e); + } + + // Check for tool requests + List toolRequests = extractToolRequests(response); + if (toolRequests.isEmpty()) { + // No tool calls, add response to history and persist + Message responseMessage = response.getMessage(); + if (responseMessage != null) { + history.add(responseMessage); + persistHistory(); + } + return response; + } + + // Execute tools with interrupt handling + List toolResponseParts = executeToolsWithInterruptHandling(ctx, toolRequests, sendOptions); + + // Add assistant message with tool requests to history + Message assistantMessage = response.getMessage(); + if (assistantMessage != null) { + history.add(assistantMessage); + } + + // Check if any tools triggered interrupts + if (!pendingInterrupts.isEmpty()) { + persistHistory(); + return createInterruptResponse(); + } + + // Add tool response message to history + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResponseParts); + history.add(toolResponseMessage); + + // Rebuild request with updated history + request = buildRequest(sendOptions); + turn++; + } + + throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + } + + /** + * Sends a message with streaming response. + * + * @param text + * the user message + * @param streamCallback + * callback for each response chunk + * @return the final aggregated response + * @throws GenkitException + * if generation fails + */ + public ModelResponse sendStream(String text, Consumer streamCallback) throws GenkitException { + return sendStream(Message.user(text), null, streamCallback); + } + + /** + * Sends a message with streaming response. + * + * @param message + * the message to send + * @param sendOptions + * additional options for this send + * @param streamCallback + * callback for each response chunk + * @return the final aggregated response + * @throws GenkitException + * if generation fails + */ + public ModelResponse sendStream(Message message, SendOptions sendOptions, + Consumer streamCallback) throws GenkitException { + // Add user message to history + history.add(message); + + // Build the request + ModelRequest request = buildRequest(sendOptions); + + // Get the model + String modelName = resolveModelName(sendOptions); + if (modelName == null) { + throw new GenkitException("No model specified. Set model in ChatOptions or SendOptions."); + } + + Model model = getModel(modelName); + if (!model.supportsStreaming()) { + throw new GenkitException("Model " + modelName + " does not support streaming"); + } + + ActionContext ctx = new ActionContext(registry); + ModelResponse response = model.run(ctx, request, streamCallback); + + // Add response to history and persist + Message responseMessage = response.getMessage(); + if (responseMessage != null) { + history.add(responseMessage); + persistHistory(); + } + + return response; + } + + /** + * Gets the current conversation history. + * + * @return a copy of the message history + */ + public List getHistory() { + return new ArrayList<>(history); + } + + /** + * Gets the session. + * + * @return the parent session + */ + public Session getSession() { + return session; + } + + /** + * Gets the thread name. + * + * @return the thread name + */ + public String getThreadName() { + return threadName; + } + + /** + * Gets the pending interrupt requests from the last send. + * + *

+ * If the last {@link #send} call returned with interrupts, this method returns + * the list of pending interrupts that need to be resolved before continuing. + * + * @return the list of pending interrupt requests (empty if none) + */ + public List getPendingInterrupts() { + return new ArrayList<>(pendingInterrupts); + } + + /** + * Checks if there are pending interrupts. + * + * @return true if there are pending interrupts + */ + public boolean hasPendingInterrupts() { + return !pendingInterrupts.isEmpty(); + } + + /** + * Gets the current agent name. + * + *

+ * Returns null if no agent handoff has occurred, otherwise returns the name of + * the agent that the conversation was most recently handed off to. + * + * @return the current agent name, or null if no handoff has occurred + */ + public String getCurrentAgentName() { + return currentAgentName; + } + + /** + * Builds a ModelRequest from current history and options. + */ + private ModelRequest buildRequest(SendOptions sendOptions) { + ModelRequest.Builder builder = ModelRequest.builder(); + + // Build messages list with system prompt (preamble) + List messages = new ArrayList<>(); + + // Add system prompt if specified (use current context for handoffs) + String systemPrompt = currentSystem; + if (systemPrompt != null && !systemPrompt.isEmpty()) { + Message systemMessage = Message.system(systemPrompt); + // Mark as preamble in metadata + Map metadata = new HashMap<>(); + metadata.put(PREAMBLE_KEY, true); + systemMessage.setMetadata(metadata); + messages.add(systemMessage); + } + + // Add conversation history (excluding any existing preamble) + for (Message msg : history) { + if (!isPreamble(msg)) { + messages.add(msg); + } + } + + builder.messages(messages); + + // Add tools (use current context for handoffs) + List> tools = resolveTools(sendOptions); + if (tools != null && !tools.isEmpty()) { + List toolDefs = new ArrayList<>(); + for (Tool tool : tools) { + toolDefs.add(tool.getDefinition()); + } + builder.tools(toolDefs); + } + + // Add config + if (originalOptions.getConfig() != null) { + builder.config(convertConfigToMap(originalOptions.getConfig())); + } + + // Add output config + if (originalOptions.getOutput() != null) { + builder.output(originalOptions.getOutput()); + } + + return builder.build(); + } + + /** + * Checks if a message is a preamble (system prompt). + */ + private boolean isPreamble(Message message) { + if (message.getMetadata() == null) { + return false; + } + Object preamble = message.getMetadata().get(PREAMBLE_KEY); + return Boolean.TRUE.equals(preamble); + } + + /** + * Resolves the model name from options. + */ + private String resolveModelName(SendOptions sendOptions) { + if (sendOptions != null && sendOptions.getModel() != null) { + return sendOptions.getModel(); + } + // Use current context (which may have been updated by handoff) + return currentModel; + } + + /** + * Resolves the max turns from options. + */ + private int resolveMaxTurns(SendOptions sendOptions) { + if (sendOptions != null && sendOptions.getMaxTurns() != null) { + return sendOptions.getMaxTurns(); + } + if (originalOptions.getMaxTurns() != null) { + return originalOptions.getMaxTurns(); + } + return 5; // Default + } + + /** + * Resolves the tools from options. + */ + private List> resolveTools(SendOptions sendOptions) { + if (sendOptions != null && sendOptions.getTools() != null) { + return sendOptions.getTools(); + } + // Use current context (which may have been updated by handoff) + return currentTools; + } + + /** + * Gets a model by name from the registry. + */ + private Model getModel(String name) { + Action action = registry.lookupAction(ActionType.MODEL, name); + if (action == null) { + throw new GenkitException("Model not found: " + name); + } + return (Model) action; + } + + /** + * Extracts tool requests from a model response. + */ + private List extractToolRequests(ModelResponse response) { + List requests = new ArrayList<>(); + if (response.getCandidates() != null) { + for (Candidate candidate : response.getCandidates()) { + if (candidate.getMessage() != null && candidate.getMessage().getContent() != null) { + for (Part part : candidate.getMessage().getContent()) { + if (part.getToolRequest() != null) { + requests.add(part.getToolRequest()); + } + } + } + } + } + return requests; + } + + /** + * Executes tools and returns response parts. + */ + private List executeTools(ActionContext ctx, List toolRequests, SendOptions sendOptions) { + List responseParts = new ArrayList<>(); + List> tools = resolveTools(sendOptions); + + for (ToolRequest toolRequest : toolRequests) { + String toolName = toolRequest.getName(); + Object toolInput = toolRequest.getInput(); + + Tool tool = findTool(toolName, tools); + if (tool == null) { + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Collections.singletonMap("error", "Tool not found: " + toolName)); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + continue; + } + + try { + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + + // Convert the input to the expected type if necessary + final Object convertedInput; + Class inputClass = typedTool.getInputClass(); + if (inputClass != null && toolInput != null && !inputClass.isInstance(toolInput)) { + convertedInput = JsonUtils.convert(toolInput, inputClass); + } else { + convertedInput = toolInput; + } + + Object result = SessionContext.runWithSession(session, () -> typedTool.run(ctx, convertedInput)); + + Part responsePart = new Part(); + ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolName, result); + responsePart.setToolResponse(toolResponse); + responseParts.add(responsePart); + } catch (Exception e) { + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Collections.singletonMap("error", "Tool execution failed: " + e.getMessage())); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + } + } + + return responseParts; + } + + /** + * Executes tools with interrupt handling and returns response parts. + * + *

+ * When a tool throws {@link ToolInterruptException}, the interrupt is captured + * and added to the pending interrupts list. The tool execution continues for + * other tools, and an interrupt response is returned after all tools have been + * processed. + * + *

+ * When a tool throws {@link AgentHandoffException}, the chat context is + * switched to the target agent (system prompt, tools, model), enabling + * multi-agent conversations. + */ + private List executeToolsWithInterruptHandling(ActionContext ctx, List toolRequests, + SendOptions sendOptions) { + List responseParts = new ArrayList<>(); + List> tools = resolveTools(sendOptions); + + for (ToolRequest toolRequest : toolRequests) { + String toolName = toolRequest.getName(); + Object toolInput = toolRequest.getInput(); + + Tool tool = findTool(toolName, tools); + if (tool == null) { + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Collections.singletonMap("error", "Tool not found: " + toolName)); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + continue; + } + + try { + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + + // Convert the input to the expected type if necessary + final Object convertedInput; + Class inputClass = typedTool.getInputClass(); + if (inputClass != null && toolInput != null && !inputClass.isInstance(toolInput)) { + convertedInput = JsonUtils.convert(toolInput, inputClass); + } else { + convertedInput = toolInput; + } + + Object result = SessionContext.runWithSession(session, () -> typedTool.run(ctx, convertedInput)); + + Part responsePart = new Part(); + ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolName, result); + responsePart.setToolResponse(toolResponse); + responseParts.add(responsePart); + } catch (AgentHandoffException e) { + // Handle agent handoff - switch context to the target agent + handleAgentHandoff(e); + + // Add a response indicating the handoff + Part handoffPart = new Part(); + Map handoffOutput = new HashMap<>(); + handoffOutput.put("transferred", true); + handoffOutput.put("transferredTo", e.getTargetAgentName()); + handoffOutput.put("message", "Conversation transferred to " + e.getTargetAgentName()); + ToolResponse handoffResponse = new ToolResponse(toolRequest.getRef(), toolName, handoffOutput); + handoffPart.setToolResponse(handoffResponse); + responseParts.add(handoffPart); + } catch (ToolInterruptException e) { + // Capture the interrupt + InterruptRequest interruptRequest = new InterruptRequest(toolRequest, e.getMetadata()); + pendingInterrupts.add(interruptRequest); + + // Add a placeholder response indicating interruption + Part interruptPart = new Part(); + Map interruptOutput = new HashMap<>(); + interruptOutput.put("__interrupt", true); + interruptOutput.put("metadata", e.getMetadata()); + ToolResponse interruptResponse = new ToolResponse(toolRequest.getRef(), toolName, interruptOutput); + interruptPart.setToolResponse(interruptResponse); + responseParts.add(interruptPart); + } catch (Exception e) { + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Collections.singletonMap("error", "Tool execution failed: " + e.getMessage())); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + } + } + + return responseParts; + } + + /** + * Handles an agent handoff by switching the chat context. + */ + private void handleAgentHandoff(AgentHandoffException handoff) { + AgentConfig targetConfig = handoff.getTargetAgentConfig(); + currentAgentName = handoff.getTargetAgentName(); + + // Update system prompt + if (targetConfig.getSystem() != null) { + currentSystem = targetConfig.getSystem(); + } + + // Update model if specified + if (targetConfig.getModel() != null) { + currentModel = targetConfig.getModel(); + } + + // Update tools - include the agent's tools plus sub-agent tools + List> newTools = new ArrayList<>(); + if (targetConfig.getTools() != null) { + newTools.addAll(targetConfig.getTools()); + } + + // Add sub-agents as tools if agent registry is available + if (targetConfig.getAgents() != null && effectiveAgentRegistry != null) { + for (AgentConfig subAgentConfig : targetConfig.getAgents()) { + Agent subAgent = effectiveAgentRegistry.get(subAgentConfig.getName()); + if (subAgent != null) { + newTools.add(subAgent.asTool()); + } + } + } + + currentTools = newTools; + } + + /** + * Creates a response indicating the generation was interrupted. + */ + private ModelResponse createInterruptResponse() { + ModelResponse response = new ModelResponse(); + + // Create a message indicating interruption + Message interruptMessage = new Message(); + interruptMessage.setRole(Role.MODEL); + + Part textPart = new Part(); + textPart.setText("[Generation interrupted - awaiting user input]"); + interruptMessage.setContent(List.of(textPart)); + + // Add interrupt metadata + Map metadata = new HashMap<>(); + metadata.put("interrupted", true); + metadata.put("interruptCount", pendingInterrupts.size()); + List> interruptData = new ArrayList<>(); + for (InterruptRequest interrupt : pendingInterrupts) { + Map data = new HashMap<>(); + data.put("toolName", interrupt.getToolRequest().getName()); + data.put("toolRef", interrupt.getToolRequest().getRef()); + data.put("metadata", interrupt.getMetadata()); + interruptData.add(data); + } + metadata.put("interrupts", interruptData); + interruptMessage.setMetadata(metadata); + + // Create candidate + Candidate candidate = new Candidate(); + candidate.setMessage(interruptMessage); + candidate.setFinishReason(FinishReason.OTHER); + + response.setCandidates(List.of(candidate)); + + return response; + } + + /** + * Finds a tool by name. + */ + private Tool findTool(String toolName, List> tools) { + if (tools != null) { + for (Tool tool : tools) { + if (tool.getName().equals(toolName)) { + return tool; + } + } + } + + // Try registry + Action action = registry.lookupAction(ActionType.TOOL, toolName); + if (action instanceof Tool) { + return (Tool) action; + } + + return null; + } + + /** + * Persists the current history to the session store. + */ + private void persistHistory() { + session.updateMessages(threadName, history).join(); + } + + /** + * Converts GenerationConfig to a Map for the ModelRequest. + */ + private Map convertConfigToMap(GenerationConfig config) { + Map configMap = new HashMap<>(); + if (config.getTemperature() != null) { + configMap.put("temperature", config.getTemperature()); + } + if (config.getMaxOutputTokens() != null) { + configMap.put("maxOutputTokens", config.getMaxOutputTokens()); + } + if (config.getTopP() != null) { + configMap.put("topP", config.getTopP()); + } + if (config.getTopK() != null) { + configMap.put("topK", config.getTopK()); + } + if (config.getStopSequences() != null) { + configMap.put("stopSequences", config.getStopSequences()); + } + if (config.getPresencePenalty() != null) { + configMap.put("presencePenalty", config.getPresencePenalty()); + } + if (config.getFrequencyPenalty() != null) { + configMap.put("frequencyPenalty", config.getFrequencyPenalty()); + } + if (config.getSeed() != null) { + configMap.put("seed", config.getSeed()); + } + if (config.getCustom() != null) { + configMap.putAll(config.getCustom()); + } + return configMap; + } + + /** + * Options for individual send operations. + */ + public static class SendOptions { + private String model; + private List> tools; + private Integer maxTurns; + private ResumeOptions resumeOptions; + + /** + * Default constructor. + */ + public SendOptions() { + } + + /** + * Gets the model name. + * + * @return the model name + */ + public String getModel() { + return model; + } + + /** + * Sets the model name. + * + * @param model + * the model name + */ + public void setModel(String model) { + this.model = model; + } + + /** + * Gets the tools. + * + * @return the tools + */ + public List> getTools() { + return tools; + } + + /** + * Sets the tools. + * + * @param tools + * the tools + */ + public void setTools(List> tools) { + this.tools = tools; + } + + /** + * Gets the max turns. + * + * @return the max turns + */ + public Integer getMaxTurns() { + return maxTurns; + } + + /** + * Sets the max turns. + * + * @param maxTurns + * the max turns + */ + public void setMaxTurns(Integer maxTurns) { + this.maxTurns = maxTurns; + } + + /** + * Gets the resume options. + * + * @return the resume options + */ + public ResumeOptions getResumeOptions() { + return resumeOptions; + } + + /** + * Sets the resume options. + * + * @param resumeOptions + * the resume options + */ + public void setResumeOptions(ResumeOptions resumeOptions) { + this.resumeOptions = resumeOptions; + } + + /** + * Creates a builder for SendOptions. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for SendOptions. + */ + public static class Builder { + private String model; + private List> tools; + private Integer maxTurns; + private ResumeOptions resumeOptions; + + /** + * Sets the model name. + * + * @param model + * the model name + * @return this builder + */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Sets the tools. + * + * @param tools + * the tools + * @return this builder + */ + public Builder tools(List> tools) { + this.tools = tools; + return this; + } + + /** + * Sets the max turns. + * + * @param maxTurns + * the max turns + * @return this builder + */ + public Builder maxTurns(Integer maxTurns) { + this.maxTurns = maxTurns; + return this; + } + + /** + * Sets the resume options for resuming after an interrupt. + * + * @param resumeOptions + * the resume options + * @return this builder + */ + public Builder resumeOptions(ResumeOptions resumeOptions) { + this.resumeOptions = resumeOptions; + return this; + } + + /** + * Builds the SendOptions. + * + * @return the built SendOptions + */ + public SendOptions build() { + SendOptions options = new SendOptions(); + options.setModel(model); + options.setTools(tools); + options.setMaxTurns(maxTurns); + options.setResumeOptions(resumeOptions); + return options; + } + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/ChatOptions.java b/java/ai/src/main/java/com/google/genkit/ai/session/ChatOptions.java new file mode 100644 index 0000000000..67042c5cdf --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/ChatOptions.java @@ -0,0 +1,345 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.List; +import java.util.Map; + +import com.google.genkit.ai.Agent; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.OutputConfig; +import com.google.genkit.ai.Tool; + +/** + * ChatOptions provides configuration options for creating a Chat instance. + * + * @param + * the type of the session state + */ +public class ChatOptions { + + private String model; + private String system; + private List> tools; + private OutputConfig output; + private GenerationConfig config; + private Map context; + private Integer maxTurns; + private Map agentRegistry; + + /** + * Default constructor. + */ + public ChatOptions() { + } + + /** + * Gets the model name. + * + * @return the model name + */ + public String getModel() { + return model; + } + + /** + * Sets the model name. + * + * @param model + * the model name + */ + public void setModel(String model) { + this.model = model; + } + + /** + * Gets the system prompt. + * + * @return the system prompt + */ + public String getSystem() { + return system; + } + + /** + * Sets the system prompt. + * + * @param system + * the system prompt + */ + public void setSystem(String system) { + this.system = system; + } + + /** + * Gets the available tools. + * + * @return the tools + */ + public List> getTools() { + return tools; + } + + /** + * Sets the available tools. + * + * @param tools + * the tools + */ + public void setTools(List> tools) { + this.tools = tools; + } + + /** + * Gets the output configuration. + * + * @return the output configuration + */ + public OutputConfig getOutput() { + return output; + } + + /** + * Sets the output configuration. + * + * @param output + * the output configuration + */ + public void setOutput(OutputConfig output) { + this.output = output; + } + + /** + * Gets the generation configuration. + * + * @return the generation configuration + */ + public GenerationConfig getConfig() { + return config; + } + + /** + * Sets the generation configuration. + * + * @param config + * the generation configuration + */ + public void setConfig(GenerationConfig config) { + this.config = config; + } + + /** + * Gets the additional context. + * + * @return the context + */ + public Map getContext() { + return context; + } + + /** + * Sets the additional context. + * + * @param context + * the context + */ + public void setContext(Map context) { + this.context = context; + } + + /** + * Gets the maximum conversation turns. + * + * @return the max turns + */ + public Integer getMaxTurns() { + return maxTurns; + } + + /** + * Sets the maximum conversation turns. + * + * @param maxTurns + * the max turns + */ + public void setMaxTurns(Integer maxTurns) { + this.maxTurns = maxTurns; + } + + /** + * Gets the agent registry for multi-agent handoffs. + * + * @return the agent registry + */ + public Map getAgentRegistry() { + return agentRegistry; + } + + /** + * Sets the agent registry for multi-agent handoffs. + * + * @param agentRegistry + * the agent registry + */ + public void setAgentRegistry(Map agentRegistry) { + this.agentRegistry = agentRegistry; + } + + /** + * Creates a builder for ChatOptions. + * + * @param + * the state type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder for ChatOptions. + * + * @param + * the state type + */ + public static class Builder { + private String model; + private String system; + private List> tools; + private OutputConfig output; + private GenerationConfig config; + private Map context; + private Integer maxTurns; + private Map agentRegistry; + + /** + * Sets the model name. + * + * @param model + * the model name + * @return this builder + */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Sets the system prompt. + * + * @param system + * the system prompt + * @return this builder + */ + public Builder system(String system) { + this.system = system; + return this; + } + + /** + * Sets the available tools. + * + * @param tools + * the tools + * @return this builder + */ + public Builder tools(List> tools) { + this.tools = tools; + return this; + } + + /** + * Sets the output configuration. + * + * @param output + * the output configuration + * @return this builder + */ + public Builder output(OutputConfig output) { + this.output = output; + return this; + } + + /** + * Sets the generation configuration. + * + * @param config + * the generation configuration + * @return this builder + */ + public Builder config(GenerationConfig config) { + this.config = config; + return this; + } + + /** + * Sets the additional context. + * + * @param context + * the context + * @return this builder + */ + public Builder context(Map context) { + this.context = context; + return this; + } + + /** + * Sets the maximum conversation turns. + * + * @param maxTurns + * the max turns + * @return this builder + */ + public Builder maxTurns(Integer maxTurns) { + this.maxTurns = maxTurns; + return this; + } + + /** + * Sets the agent registry for multi-agent handoffs. + * + * @param agentRegistry + * the agent registry + * @return this builder + */ + public Builder agentRegistry(Map agentRegistry) { + this.agentRegistry = agentRegistry; + return this; + } + + /** + * Builds the ChatOptions. + * + * @return the built ChatOptions + */ + public ChatOptions build() { + ChatOptions options = new ChatOptions<>(); + options.setModel(model); + options.setSystem(system); + options.setTools(tools); + options.setOutput(output); + options.setConfig(config); + options.setContext(context); + options.setMaxTurns(maxTurns); + options.setAgentRegistry(agentRegistry); + return options; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/InMemorySessionStore.java b/java/ai/src/main/java/com/google/genkit/ai/session/InMemorySessionStore.java new file mode 100644 index 0000000000..fb3bebe959 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/InMemorySessionStore.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +/** + * InMemorySessionStore is a simple in-memory implementation of SessionStore. + * + *

+ * This implementation is suitable for: + *

    + *
  • Development and testing
  • + *
  • Single-instance deployments
  • + *
  • Prototyping
  • + *
+ * + *

+ * Note: Sessions are lost when the application restarts. For production + * use cases requiring persistence, implement a database-backed SessionStore. + * + * @param + * the type of the custom session state + */ +public class InMemorySessionStore implements SessionStore { + + private final Map> data = new ConcurrentHashMap<>(); + + /** + * Creates a new InMemorySessionStore. + */ + public InMemorySessionStore() { + } + + @Override + public CompletableFuture> get(String sessionId) { + return CompletableFuture.completedFuture(data.get(sessionId)); + } + + @Override + public CompletableFuture save(String sessionId, SessionData sessionData) { + data.put(sessionId, sessionData); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture delete(String sessionId) { + data.remove(sessionId); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture exists(String sessionId) { + return CompletableFuture.completedFuture(data.containsKey(sessionId)); + } + + /** + * Returns the number of sessions currently stored. + * + * @return the session count + */ + public int size() { + return data.size(); + } + + /** + * Clears all sessions from the store. + */ + public void clear() { + data.clear(); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/Session.java b/java/ai/src/main/java/com/google/genkit/ai/session/Session.java new file mode 100644 index 0000000000..f24ed6f217 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/Session.java @@ -0,0 +1,342 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + +import com.google.genkit.ai.Agent; +import com.google.genkit.ai.Message; +import com.google.genkit.core.Registry; + +/** + * Session represents a stateful chat session that persists conversation history + * and custom state across multiple interactions. + * + *

+ * Sessions provide: + *

    + *
  • Persistent conversation threads
  • + *
  • Custom state management
  • + *
  • Multiple named chat threads within a session
  • + *
  • Automatic history management
  • + *
+ * + *

+ * Example usage: + * + *

{@code
+ * // Create a session with initial state
+ * Session session = genkit
+ * 		.createSession(SessionOptions.builder().initialState(new MyState("John")).build());
+ *
+ * // Create a chat and interact
+ * Chat chat = session.chat();
+ * ModelResponse response = chat.send("Hello!");
+ *
+ * // Access session state
+ * MyState state = session.getState();
+ * }
+ * + * @param + * the type of the custom session state + */ +public class Session { + + /** Default thread name for chat conversations. */ + public static final String DEFAULT_THREAD = "main"; + + private final String id; + private SessionData sessionData; + private final SessionStore store; + private final Registry registry; + private final Supplier> chatFactory; + private final Map agentRegistry; + + /** + * Creates a new Session. + * + * @param registry + * the Genkit registry + * @param store + * the session store + * @param sessionData + * the initial session data + * @param chatFactory + * factory for creating Chat instances + * @param agentRegistry + * the agent registry for multi-agent handoffs (may be null) + */ + Session(Registry registry, SessionStore store, SessionData sessionData, Supplier> chatFactory, + Map agentRegistry) { + this.registry = registry; + this.store = store; + this.sessionData = sessionData; + this.id = sessionData.getId(); + this.chatFactory = chatFactory; + this.agentRegistry = agentRegistry; + } + + /** + * Gets the session ID. + * + * @return the unique session identifier + */ + public String getId() { + return id; + } + + /** + * Gets the current session state. + * + * @return the session state, or null if not set + */ + public S getState() { + return sessionData.getState(); + } + + /** + * Updates the session state and persists it. + * + * @param state + * the new state + * @return a CompletableFuture that completes when the state is saved + */ + public CompletableFuture updateState(S state) { + sessionData.setState(state); + return store.save(id, sessionData); + } + + /** + * Gets the message history for a thread. + * + * @param threadName + * the thread name + * @return the list of messages in the thread + */ + public List getMessages(String threadName) { + List messages = sessionData.getThread(threadName); + return messages != null ? new ArrayList<>(messages) : new ArrayList<>(); + } + + /** + * Gets the message history for the default thread. + * + * @return the list of messages + */ + public List getMessages() { + return getMessages(DEFAULT_THREAD); + } + + /** + * Updates the messages for a thread and persists them. + * + * @param threadName + * the thread name + * @param messages + * the messages to save + * @return a CompletableFuture that completes when saved + */ + public CompletableFuture updateMessages(String threadName, List messages) { + sessionData.setThread(threadName, messages); + return store.save(id, sessionData); + } + + /** + * Creates a new Chat instance for the default thread. + * + * @return a new Chat instance + */ + public Chat chat() { + return chat(DEFAULT_THREAD, ChatOptions.builder().build()); + } + + /** + * Creates a new Chat instance with options. + * + * @param options + * the chat options + * @return a new Chat instance + */ + public Chat chat(ChatOptions options) { + return chat(DEFAULT_THREAD, options); + } + + /** + * Creates a new Chat instance for a specific thread. + * + * @param threadName + * the thread name + * @return a new Chat instance + */ + public Chat chat(String threadName) { + return chat(threadName, ChatOptions.builder().build()); + } + + /** + * Creates a new Chat instance for a specific thread with options. + * + * @param threadName + * the thread name + * @param options + * the chat options + * @return a new Chat instance + */ + public Chat chat(String threadName, ChatOptions options) { + return new Chat<>(this, threadName, options, registry, agentRegistry); + } + + /** + * Gets the session store. + * + * @return the session store + */ + public SessionStore getStore() { + return store; + } + + /** + * Gets the registry. + * + * @return the registry + */ + public Registry getRegistry() { + return registry; + } + + /** + * Gets the agent registry for multi-agent handoffs. + * + * @return the agent registry, or null if not set + */ + public Map getAgentRegistry() { + return agentRegistry; + } + + /** + * Gets the session data. + * + * @return the session data + */ + public SessionData getSessionData() { + return sessionData; + } + + /** + * Serializes the session to JSON-compatible data. + * + * @return the session data + */ + public SessionData toJSON() { + return sessionData; + } + + /** + * Creates a new Session with a generated ID. + * + * @param + * the state type + * @param registry + * the Genkit registry + * @param options + * the session options + * @return a new Session + */ + public static Session create(Registry registry, SessionOptions options) { + return create(registry, options, null); + } + + /** + * Creates a new Session with a generated ID and agent registry. + * + * @param + * the state type + * @param registry + * the Genkit registry + * @param options + * the session options + * @param agentRegistry + * the agent registry for multi-agent handoffs (may be null) + * @return a new Session + */ + public static Session create(Registry registry, SessionOptions options, + Map agentRegistry) { + String sessionId = options.getSessionId() != null ? options.getSessionId() : UUID.randomUUID().toString(); + + SessionStore store = options.getStore() != null ? options.getStore() : new InMemorySessionStore<>(); + + SessionData data = SessionData.builder().id(sessionId).state(options.getInitialState()).build(); + + // Save initial session data + store.save(sessionId, data).join(); + + return new Session<>(registry, store, data, null, agentRegistry); + } + + /** + * Loads an existing session from a store. + * + * @param + * the state type + * @param registry + * the Genkit registry + * @param sessionId + * the session ID to load + * @param options + * the session options (must include store) + * @return a CompletableFuture containing the loaded session, or null if not + * found + */ + public static CompletableFuture> load(Registry registry, String sessionId, + SessionOptions options) { + return load(registry, sessionId, options, null); + } + + /** + * Loads an existing session from a store with agent registry. + * + * @param + * the state type + * @param registry + * the Genkit registry + * @param sessionId + * the session ID to load + * @param options + * the session options (must include store) + * @param agentRegistry + * the agent registry for multi-agent handoffs (may be null) + * @return a CompletableFuture containing the loaded session, or null if not + * found + */ + public static CompletableFuture> load(Registry registry, String sessionId, SessionOptions options, + Map agentRegistry) { + SessionStore store = options.getStore() != null ? options.getStore() : new InMemorySessionStore<>(); + + return store.get(sessionId).thenApply(data -> { + if (data == null) { + return null; + } + return new Session<>(registry, store, data, null, agentRegistry); + }); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/SessionContext.java b/java/ai/src/main/java/com/google/genkit/ai/session/SessionContext.java new file mode 100644 index 0000000000..b75b51857b --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/SessionContext.java @@ -0,0 +1,180 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.concurrent.Callable; + +import com.google.genkit.core.GenkitException; + +/** + * Provides access to the current session context. + * + *

+ * This class uses ThreadLocal to store the current session, making it + * accessible from within tool execution. This enables tools to access and + * modify session state during their execution. + * + *

+ * Example usage in a tool: + * + *

{@code
+ * Tool myTool = genkit.defineTool(
+ * 		ToolConfig.builder().name("myTool").description("A tool that accesses session state")
+ * 				.inputSchema(Input.class).outputSchema(Output.class).build(),
+ * 		(input, ctx) -> {
+ * 			// Access current session from within tool
+ * 			Session session = SessionContext.currentSession();
+ * 			MyState state = session.getState();
+ *
+ * 			// Update session state
+ * 			session.updateState(new MyState(state.getName(), state.getCount() + 1));
+ *
+ * 			return new Output("Updated");
+ * 		});
+ * }
+ */ +public final class SessionContext { + + private static final ThreadLocal> CURRENT_SESSION = new ThreadLocal<>(); + + private SessionContext() { + } + + /** + * Gets the current session. + * + * @param + * the session state type + * @return the current session + * @throws SessionException + * if not running within a session + */ + @SuppressWarnings("unchecked") + public static Session currentSession() { + Session session = CURRENT_SESSION.get(); + if (session == null) { + throw new SessionException("Not running within a session context"); + } + return (Session) session; + } + + /** + * Gets the current session if available. + * + * @param + * the session state type + * @return the current session, or null if not in a session context + */ + @SuppressWarnings("unchecked") + public static Session getCurrentSession() { + return (Session) CURRENT_SESSION.get(); + } + + /** + * Checks if currently running within a session context. + * + * @return true if in a session context + */ + public static boolean hasSession() { + return CURRENT_SESSION.get() != null; + } + + /** + * Runs a function within a session context. + * + * @param + * the session state type + * @param + * the return type + * @param session + * the session to use + * @param callable + * the function to run + * @return the result of the function + * @throws Exception + * if the function throws an exception + */ + public static T runWithSession(Session session, Callable callable) throws Exception { + Session previous = CURRENT_SESSION.get(); + try { + CURRENT_SESSION.set(session); + return callable.call(); + } finally { + if (previous != null) { + CURRENT_SESSION.set(previous); + } else { + CURRENT_SESSION.remove(); + } + } + } + + /** + * Runs a runnable within a session context. + * + * @param + * the session state type + * @param session + * the session to use + * @param runnable + * the runnable to execute + */ + public static void runWithSession(Session session, Runnable runnable) { + Session previous = CURRENT_SESSION.get(); + try { + CURRENT_SESSION.set(session); + runnable.run(); + } finally { + if (previous != null) { + CURRENT_SESSION.set(previous); + } else { + CURRENT_SESSION.remove(); + } + } + } + + /** + * Sets the current session. This is typically called internally by Chat. + * + * @param session + * the session to set + */ + public static void setSession(Session session) { + if (session != null) { + CURRENT_SESSION.set(session); + } else { + CURRENT_SESSION.remove(); + } + } + + /** Clears the current session. */ + public static void clearSession() { + CURRENT_SESSION.remove(); + } + + /** Exception thrown when session operations fail. */ + public static class SessionException extends GenkitException { + public SessionException(String message) { + super(message); + } + + public SessionException(String message, Throwable cause) { + super(message, cause); + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/SessionData.java b/java/ai/src/main/java/com/google/genkit/ai/session/SessionData.java new file mode 100644 index 0000000000..50a3638afc --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/SessionData.java @@ -0,0 +1,267 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.genkit.ai.Message; + +/** + * SessionData represents the persistent data structure for a session, including + * state and conversation threads. + * + * @param + * the type of the custom session state + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class SessionData { + + /** + * The unique identifier for this session. + */ + @JsonProperty("id") + private String id; + + /** + * Custom user-defined state associated with the session. + */ + @JsonProperty("state") + private S state; + + /** + * Named conversation threads. Each thread is identified by a string key + * (default is "main") and contains a list of messages. + */ + @JsonProperty("threads") + private Map> threads; + + /** + * Default constructor. + */ + public SessionData() { + this.threads = new HashMap<>(); + } + + /** + * Creates a new SessionData with the given ID. + * + * @param id + * the session ID + */ + public SessionData(String id) { + this.id = id; + this.threads = new HashMap<>(); + } + + /** + * Creates a new SessionData with the given ID and initial state. + * + * @param id + * the session ID + * @param state + * the initial state + */ + public SessionData(String id, S state) { + this.id = id; + this.state = state; + this.threads = new HashMap<>(); + } + + /** + * Gets the session ID. + * + * @return the session ID + */ + public String getId() { + return id; + } + + /** + * Sets the session ID. + * + * @param id + * the session ID + */ + public void setId(String id) { + this.id = id; + } + + /** + * Gets the session state. + * + * @return the session state + */ + public S getState() { + return state; + } + + /** + * Sets the session state. + * + * @param state + * the session state + */ + public void setState(S state) { + this.state = state; + } + + /** + * Gets all conversation threads. + * + * @return the threads map + */ + public Map> getThreads() { + return threads; + } + + /** + * Sets all conversation threads. + * + * @param threads + * the threads map + */ + public void setThreads(Map> threads) { + this.threads = threads; + } + + /** + * Gets a specific thread by name. + * + * @param threadName + * the thread name + * @return the list of messages in the thread, or null if not found + */ + public List getThread(String threadName) { + return threads.get(threadName); + } + + /** + * Gets or creates a thread by name. + * + * @param threadName + * the thread name + * @return the list of messages in the thread + */ + public List getOrCreateThread(String threadName) { + return threads.computeIfAbsent(threadName, k -> new ArrayList<>()); + } + + /** + * Sets messages for a specific thread. + * + * @param threadName + * the thread name + * @param messages + * the messages to set + */ + public void setThread(String threadName, List messages) { + threads.put(threadName, new ArrayList<>(messages)); + } + + /** + * Creates a builder for SessionData. + * + * @param + * the state type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder for SessionData. + * + * @param + * the state type + */ + public static class Builder { + private String id; + private S state; + private Map> threads = new HashMap<>(); + + /** + * Sets the session ID. + * + * @param id + * the session ID + * @return this builder + */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Sets the session state. + * + * @param state + * the session state + * @return this builder + */ + public Builder state(S state) { + this.state = state; + return this; + } + + /** + * Sets the conversation threads. + * + * @param threads + * the threads map + * @return this builder + */ + public Builder threads(Map> threads) { + this.threads = new HashMap<>(threads); + return this; + } + + /** + * Adds a thread. + * + * @param threadName + * the thread name + * @param messages + * the messages + * @return this builder + */ + public Builder thread(String threadName, List messages) { + this.threads.put(threadName, new ArrayList<>(messages)); + return this; + } + + /** + * Builds the SessionData. + * + * @return the built SessionData + */ + public SessionData build() { + SessionData data = new SessionData<>(); + data.setId(id); + data.setState(state); + data.setThreads(threads); + return data; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/SessionOptions.java b/java/ai/src/main/java/com/google/genkit/ai/session/SessionOptions.java new file mode 100644 index 0000000000..a9b856575e --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/SessionOptions.java @@ -0,0 +1,168 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +/** + * SessionOptions provides configuration options for creating or loading + * sessions. + * + * @param + * the type of the custom session state + */ +public class SessionOptions { + + private SessionStore store; + private S initialState; + private String sessionId; + + /** + * Default constructor. + */ + public SessionOptions() { + } + + /** + * Gets the session store. + * + * @return the session store + */ + public SessionStore getStore() { + return store; + } + + /** + * Sets the session store. + * + * @param store + * the session store + */ + public void setStore(SessionStore store) { + this.store = store; + } + + /** + * Gets the initial state. + * + * @return the initial state + */ + public S getInitialState() { + return initialState; + } + + /** + * Sets the initial state. + * + * @param initialState + * the initial state + */ + public void setInitialState(S initialState) { + this.initialState = initialState; + } + + /** + * Gets the session ID. + * + * @return the session ID + */ + public String getSessionId() { + return sessionId; + } + + /** + * Sets the session ID. + * + * @param sessionId + * the session ID + */ + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + /** + * Creates a builder for SessionOptions. + * + * @param + * the state type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder for SessionOptions. + * + * @param + * the state type + */ + public static class Builder { + private SessionStore store; + private S initialState; + private String sessionId; + + /** + * Sets the session store. + * + * @param store + * the session store + * @return this builder + */ + public Builder store(SessionStore store) { + this.store = store; + return this; + } + + /** + * Sets the initial state. + * + * @param initialState + * the initial state + * @return this builder + */ + public Builder initialState(S initialState) { + this.initialState = initialState; + return this; + } + + /** + * Sets the session ID. + * + * @param sessionId + * the session ID + * @return this builder + */ + public Builder sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } + + /** + * Builds the SessionOptions. + * + * @return the built SessionOptions + */ + public SessionOptions build() { + SessionOptions options = new SessionOptions<>(); + options.setStore(store); + options.setInitialState(initialState); + options.setSessionId(sessionId); + return options; + } + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/SessionStore.java b/java/ai/src/main/java/com/google/genkit/ai/session/SessionStore.java new file mode 100644 index 0000000000..356b5560bf --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/SessionStore.java @@ -0,0 +1,81 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import java.util.concurrent.CompletableFuture; + +/** + * SessionStore is an interface for persisting session data. + * + *

+ * Implementations can provide different storage backends such as: + *

    + *
  • In-memory storage (for development/testing)
  • + *
  • Database storage (for production)
  • + *
  • Redis or other distributed cache
  • + *
  • File-based storage
  • + *
+ * + * @param + * the type of the custom session state + */ +public interface SessionStore { + + /** + * Retrieves a session by its ID. + * + * @param sessionId + * the session ID + * @return a CompletableFuture containing the session data, or null if not found + */ + CompletableFuture> get(String sessionId); + + /** + * Saves session data. + * + * @param sessionId + * the session ID + * @param data + * the session data to save + * @return a CompletableFuture that completes when the save is done + */ + CompletableFuture save(String sessionId, SessionData data); + + /** + * Deletes a session by its ID. + * + * @param sessionId + * the session ID + * @return a CompletableFuture that completes when the deletion is done + */ + default CompletableFuture delete(String sessionId) { + return CompletableFuture.completedFuture(null); + } + + /** + * Checks if a session exists. + * + * @param sessionId + * the session ID + * @return a CompletableFuture containing true if the session exists + */ + default CompletableFuture exists(String sessionId) { + return get(sessionId).thenApply(data -> data != null); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/session/package-info.java b/java/ai/src/main/java/com/google/genkit/ai/session/package-info.java new file mode 100644 index 0000000000..ee0badcae1 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/session/package-info.java @@ -0,0 +1,94 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Provides session management for multi-turn agent conversations with + * persistence. + * + *

+ * The session package provides a stateful layer on top of Genkit's generation + * capabilities, enabling: + *

    + *
  • Persistent conversation history across multiple interactions
  • + *
  • Custom session state management
  • + *
  • Multiple named conversation threads within a session
  • + *
  • Pluggable storage backends via + * {@link com.google.genkit.ai.session.SessionStore}
  • + *
+ * + *

Key Components

+ *
    + *
  • {@link com.google.genkit.ai.session.Session} - The main entry point for + * session management
  • + *
  • {@link com.google.genkit.ai.session.Chat} - Manages conversations within + * a session thread
  • + *
  • {@link com.google.genkit.ai.session.SessionStore} - Interface for session + * persistence
  • + *
  • {@link com.google.genkit.ai.session.InMemorySessionStore} - Default + * in-memory implementation
  • + *
+ * + *

Example Usage

+ * + *
{@code
+ * // Create a session with custom state
+ * Session session = genkit
+ * 		.createSession(SessionOptions.builder().initialState(new MyState("John")).build());
+ *
+ * // Create a chat with system prompt
+ * Chat chat = session
+ * 		.chat(ChatOptions.builder().model("openai/gpt-4o").system("You are a helpful assistant.").build());
+ *
+ * // Multi-turn conversation (history is preserved automatically)
+ * chat.send("What is the capital of France?");
+ * chat.send("And what about Germany?");
+ *
+ * // Access session state
+ * MyState state = session.getState();
+ *
+ * // Load an existing session
+ * Session loadedSession = genkit.loadSession(sessionId, options).get();
+ * }
+ * + *

Custom Session Stores

+ *

+ * Implement {@link com.google.genkit.ai.session.SessionStore} to provide custom + * persistence backends (e.g., database, Redis, file system): + * + *

+ * {
+ * 	@code
+ * 	public class RedisSessionStore implements SessionStore {
+ * 		@Override
+ * 		public CompletableFuture> get(String sessionId) {
+ * 			// Load from Redis
+ * 		}
+ *
+ * 		@Override
+ * 		public CompletableFuture save(String sessionId, SessionData data) {
+ * 			// Save to Redis
+ * 		}
+ * 	}
+ * }
+ * 
+ * + * @see com.google.genkit.ai.session.Session + * @see com.google.genkit.ai.session.Chat + * @see com.google.genkit.ai.session.SessionStore + */ +package com.google.genkit.ai.session; diff --git a/java/ai/src/main/java/com/google/genkit/ai/telemetry/ActionTelemetry.java b/java/ai/src/main/java/com/google/genkit/ai/telemetry/ActionTelemetry.java new file mode 100644 index 0000000000..face4a5d31 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/telemetry/ActionTelemetry.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.telemetry; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.metrics.LongCounter; +import io.opentelemetry.api.metrics.LongHistogram; +import io.opentelemetry.api.metrics.Meter; + +/** + * ActionTelemetry provides metrics collection for general action execution. + * + *

+ * This class tracks: + *

    + *
  • Request counts per action type
  • + *
  • Latency histograms per action
  • + *
  • Failure counts
  • + *
+ */ +public class ActionTelemetry { + + private static final Logger logger = LoggerFactory.getLogger(ActionTelemetry.class); + private static final String METER_NAME = "genkit"; + private static final String SOURCE = "java"; + + // Metric names following JS/Go SDK conventions + private static final String METRIC_REQUESTS = "genkit/action/requests"; + private static final String METRIC_LATENCY = "genkit/action/latency"; + + private final LongCounter requestCounter; + private final LongHistogram latencyHistogram; + + private static ActionTelemetry instance; + + /** + * Gets the singleton instance of ActionTelemetry. + * + * @return the ActionTelemetry instance + */ + public static synchronized ActionTelemetry getInstance() { + if (instance == null) { + instance = new ActionTelemetry(); + } + return instance; + } + + private ActionTelemetry() { + Meter meter = GlobalOpenTelemetry.getMeter(METER_NAME); + + requestCounter = meter.counterBuilder(METRIC_REQUESTS).setDescription("Counts calls to genkit actions.") + .setUnit("1").build(); + + latencyHistogram = meter.histogramBuilder(METRIC_LATENCY) + .setDescription("Latencies when executing Genkit actions.").setUnit("ms").ofLongs().build(); + + logger.debug("ActionTelemetry initialized with OpenTelemetry metrics"); + } + + /** + * Records metrics for an action execution. + * + * @param actionName + * the action name + * @param actionType + * the action type (flow, model, tool, etc.) + * @param featureName + * the feature name (flow name or action name) + * @param path + * the span path + * @param latencyMs + * the latency in milliseconds + * @param error + * the error name if failed, null otherwise + */ + public void recordActionMetrics(String actionName, String actionType, String featureName, String path, + long latencyMs, String error) { + String status = error != null ? "failure" : "success"; + + Attributes baseAttrs = Attributes.builder().put("name", truncate(actionName, 1024)) + .put("type", actionType != null ? actionType : "unknown").put("featureName", truncate(featureName, 256)) + .put("path", truncate(path, 2048)).put("status", status).put("source", SOURCE).build(); + + // Record request count + Attributes requestAttrs = error != null + ? baseAttrs.toBuilder().put("error", truncate(error, 256)).build() + : baseAttrs; + requestCounter.add(1, requestAttrs); + + // Record latency + latencyHistogram.record(latencyMs, baseAttrs); + } + + /** + * Truncates a string to the specified maximum length. + * + * @param value + * the string to truncate + * @param maxLength + * the maximum length + * @return the truncated string + */ + private String truncate(String value, int maxLength) { + if (value == null) { + return ""; + } + if (value.length() <= maxLength) { + return value; + } + return value.substring(0, maxLength); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/telemetry/FeatureTelemetry.java b/java/ai/src/main/java/com/google/genkit/ai/telemetry/FeatureTelemetry.java new file mode 100644 index 0000000000..301e6f307e --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/telemetry/FeatureTelemetry.java @@ -0,0 +1,200 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.telemetry; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.metrics.LongCounter; +import io.opentelemetry.api.metrics.LongHistogram; +import io.opentelemetry.api.metrics.Meter; + +/** + * FeatureTelemetry provides metrics collection for top-level feature (flow) + * execution. + * + *

+ * This class tracks: + *

    + *
  • Feature request counts
  • + *
  • Feature latency histograms
  • + *
  • Path-level metrics for observability
  • + *
+ * + *

+ * Features in Genkit are the entry points to AI functionality, typically flows + * that users interact with directly. + */ +public class FeatureTelemetry { + + private static final Logger logger = LoggerFactory.getLogger(FeatureTelemetry.class); + private static final String METER_NAME = "genkit"; + private static final String SOURCE = "java"; + + // Feature-level metrics + private static final String METRIC_FEATURE_REQUESTS = "genkit/feature/requests"; + private static final String METRIC_FEATURE_LATENCY = "genkit/feature/latency"; + + // Path-level metrics + private static final String METRIC_PATH_REQUESTS = "genkit/feature/path/requests"; + private static final String METRIC_PATH_LATENCY = "genkit/feature/path/latency"; + + private final LongCounter featureRequestCounter; + private final LongHistogram featureLatencyHistogram; + private final LongCounter pathRequestCounter; + private final LongHistogram pathLatencyHistogram; + + private static FeatureTelemetry instance; + + /** + * Gets the singleton instance of FeatureTelemetry. + * + * @return the FeatureTelemetry instance + */ + public static synchronized FeatureTelemetry getInstance() { + if (instance == null) { + instance = new FeatureTelemetry(); + } + return instance; + } + + private FeatureTelemetry() { + Meter meter = GlobalOpenTelemetry.getMeter(METER_NAME); + + featureRequestCounter = meter.counterBuilder(METRIC_FEATURE_REQUESTS) + .setDescription("Counts calls to genkit features (flows).").setUnit("1").build(); + + featureLatencyHistogram = meter.histogramBuilder(METRIC_FEATURE_LATENCY) + .setDescription("Latencies when executing Genkit features.").setUnit("ms").ofLongs().build(); + + pathRequestCounter = meter.counterBuilder(METRIC_PATH_REQUESTS) + .setDescription("Tracks unique flow paths per flow.").setUnit("1").build(); + + pathLatencyHistogram = meter.histogramBuilder(METRIC_PATH_LATENCY).setDescription("Latencies per flow path.") + .setUnit("ms").ofLongs().build(); + + logger.debug("FeatureTelemetry initialized with OpenTelemetry metrics"); + } + + /** + * Records metrics for a feature (root flow) execution. + * + * @param featureName + * the feature name + * @param path + * the span path + * @param latencyMs + * the latency in milliseconds + * @param error + * the error name if failed, null otherwise + */ + public void recordFeatureMetrics(String featureName, String path, long latencyMs, String error) { + String status = error != null ? "failure" : "success"; + + Attributes attrs = Attributes.builder().put("featureName", truncate(featureName, 256)) + .put("path", truncate(path, 2048)).put("status", status).put("source", SOURCE).build(); + + featureRequestCounter.add(1, + error != null ? attrs.toBuilder().put("error", truncate(error, 256)).build() : attrs); + featureLatencyHistogram.record(latencyMs, attrs); + } + + /** + * Records metrics for a path within a flow. + * + * @param featureName + * the feature name + * @param path + * the full path including step types + * @param latencyMs + * the latency in milliseconds + * @param error + * the error name if failed, null otherwise + */ + public void recordPathMetrics(String featureName, String path, long latencyMs, String error) { + String status = error != null ? "failure" : "success"; + String simplePath = extractSimplePathFromQualified(path); + + Attributes attrs = Attributes.builder().put("featureName", truncate(featureName, 256)) + .put("path", truncate(simplePath, 2048)).put("status", status).put("source", SOURCE).build(); + + pathRequestCounter.add(1, error != null ? attrs.toBuilder().put("error", truncate(error, 256)).build() : attrs); + pathLatencyHistogram.record(latencyMs, attrs); + } + + /** + * Extracts a simple path name from a qualified path. For example: + * /{flow,t:flow}/{step,t:action} -> flow/step + * + * @param qualifiedPath + * the qualified path with type annotations + * @return the simple path + */ + private String extractSimplePathFromQualified(String qualifiedPath) { + if (qualifiedPath == null || qualifiedPath.isEmpty()) { + return ""; + } + + StringBuilder simplePath = new StringBuilder(); + String[] parts = qualifiedPath.split("/"); + + for (String part : parts) { + if (part.isEmpty()) + continue; + + // Extract name from {name,t:type} format + if (part.startsWith("{") && part.contains(",")) { + String name = part.substring(1, part.indexOf(',')); + if (simplePath.length() > 0) { + simplePath.append("/"); + } + simplePath.append(name); + } else if (part.startsWith("{") && part.endsWith("}")) { + String name = part.substring(1, part.length() - 1); + if (simplePath.length() > 0) { + simplePath.append("/"); + } + simplePath.append(name); + } + } + + return simplePath.toString(); + } + + /** + * Truncates a string to the specified maximum length. + * + * @param value + * the string to truncate + * @param maxLength + * the maximum length + * @return the truncated string + */ + private String truncate(String value, int maxLength) { + if (value == null) { + return ""; + } + if (value.length() <= maxLength) { + return value; + } + return value.substring(0, maxLength); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/telemetry/GenerateTelemetry.java b/java/ai/src/main/java/com/google/genkit/ai/telemetry/GenerateTelemetry.java new file mode 100644 index 0000000000..386ab1da7a --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/telemetry/GenerateTelemetry.java @@ -0,0 +1,219 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.telemetry; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Usage; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.metrics.LongCounter; +import io.opentelemetry.api.metrics.LongHistogram; +import io.opentelemetry.api.metrics.Meter; + +/** + * GenerateTelemetry provides metrics collection for model generate actions. + * + *

+ * This class tracks: + *

    + *
  • Request counts per model
  • + *
  • Latency histograms
  • + *
  • Input/output token counts
  • + *
  • Input/output character counts
  • + *
  • Input/output image counts
  • + *
+ * + *

+ * The metrics follow the same naming conventions as the JS and Go SDKs for + * consistency across the Genkit ecosystem. + */ +public class GenerateTelemetry { + + private static final Logger logger = LoggerFactory.getLogger(GenerateTelemetry.class); + private static final String METER_NAME = "genkit"; + private static final String SOURCE = "java"; + + // Metric names following JS/Go SDK conventions + private static final String METRIC_REQUESTS = "genkit/ai/generate/requests"; + private static final String METRIC_LATENCY = "genkit/ai/generate/latency"; + private static final String METRIC_INPUT_TOKENS = "genkit/ai/generate/input/tokens"; + private static final String METRIC_OUTPUT_TOKENS = "genkit/ai/generate/output/tokens"; + private static final String METRIC_INPUT_CHARS = "genkit/ai/generate/input/characters"; + private static final String METRIC_OUTPUT_CHARS = "genkit/ai/generate/output/characters"; + private static final String METRIC_INPUT_IMAGES = "genkit/ai/generate/input/images"; + private static final String METRIC_OUTPUT_IMAGES = "genkit/ai/generate/output/images"; + private static final String METRIC_THINKING_TOKENS = "genkit/ai/generate/thinking/tokens"; + + private final LongCounter requestCounter; + private final LongHistogram latencyHistogram; + private final LongCounter inputTokensCounter; + private final LongCounter outputTokensCounter; + private final LongCounter inputCharsCounter; + private final LongCounter outputCharsCounter; + private final LongCounter inputImagesCounter; + private final LongCounter outputImagesCounter; + private final LongCounter thinkingTokensCounter; + + private static GenerateTelemetry instance; + + /** + * Gets the singleton instance of GenerateTelemetry. + * + * @return the GenerateTelemetry instance + */ + public static synchronized GenerateTelemetry getInstance() { + if (instance == null) { + instance = new GenerateTelemetry(); + } + return instance; + } + + private GenerateTelemetry() { + Meter meter = GlobalOpenTelemetry.getMeter(METER_NAME); + + requestCounter = meter.counterBuilder(METRIC_REQUESTS) + .setDescription("Counts calls to genkit generate actions.").setUnit("1").build(); + + latencyHistogram = meter.histogramBuilder(METRIC_LATENCY) + .setDescription("Latencies when interacting with a Genkit model.").setUnit("ms").ofLongs().build(); + + inputTokensCounter = meter.counterBuilder(METRIC_INPUT_TOKENS) + .setDescription("Counts input tokens to a Genkit model.").setUnit("1").build(); + + outputTokensCounter = meter.counterBuilder(METRIC_OUTPUT_TOKENS) + .setDescription("Counts output tokens from a Genkit model.").setUnit("1").build(); + + inputCharsCounter = meter.counterBuilder(METRIC_INPUT_CHARS) + .setDescription("Counts input characters to any Genkit model.").setUnit("1").build(); + + outputCharsCounter = meter.counterBuilder(METRIC_OUTPUT_CHARS) + .setDescription("Counts output characters from a Genkit model.").setUnit("1").build(); + + inputImagesCounter = meter.counterBuilder(METRIC_INPUT_IMAGES) + .setDescription("Counts input images to a Genkit model.").setUnit("1").build(); + + outputImagesCounter = meter.counterBuilder(METRIC_OUTPUT_IMAGES) + .setDescription("Count output images from a Genkit model.").setUnit("1").build(); + + thinkingTokensCounter = meter.counterBuilder(METRIC_THINKING_TOKENS) + .setDescription("Counts thinking tokens from a Genkit model.").setUnit("1").build(); + + logger.debug("GenerateTelemetry initialized with OpenTelemetry metrics"); + } + + /** + * Records metrics for a generate action. + * + * @param modelName + * the model name + * @param featureName + * the feature name (flow name or "generate") + * @param path + * the span path + * @param response + * the model response (may be null) + * @param latencyMs + * the latency in milliseconds + * @param error + * the error name if failed, null otherwise + */ + public void recordGenerateMetrics(String modelName, String featureName, String path, ModelResponse response, + long latencyMs, String error) { + String status = error != null ? "failure" : "success"; + + Attributes baseAttrs = Attributes.builder().put("modelName", truncate(modelName, 1024)) + .put("featureName", truncate(featureName, 256)).put("path", truncate(path, 2048)).put("status", status) + .put("source", SOURCE).build(); + + // Record request count + Attributes requestAttrs = error != null + ? baseAttrs.toBuilder().put("error", truncate(error, 256)).build() + : baseAttrs; + requestCounter.add(1, requestAttrs); + + // Record latency + latencyHistogram.record(latencyMs, baseAttrs); + + // Record usage metrics if available + if (response != null && response.getUsage() != null) { + recordUsageMetrics(response.getUsage(), baseAttrs); + } + } + + /** + * Records usage metrics from a model response. + * + * @param usage + * the usage statistics + * @param attrs + * the base attributes + */ + private void recordUsageMetrics(Usage usage, Attributes attrs) { + if (usage.getInputTokens() != null && usage.getInputTokens() > 0) { + inputTokensCounter.add(usage.getInputTokens(), attrs); + } + + if (usage.getOutputTokens() != null && usage.getOutputTokens() > 0) { + outputTokensCounter.add(usage.getOutputTokens(), attrs); + } + + if (usage.getInputCharacters() != null && usage.getInputCharacters() > 0) { + inputCharsCounter.add(usage.getInputCharacters(), attrs); + } + + if (usage.getOutputCharacters() != null && usage.getOutputCharacters() > 0) { + outputCharsCounter.add(usage.getOutputCharacters(), attrs); + } + + if (usage.getInputImages() != null && usage.getInputImages() > 0) { + inputImagesCounter.add(usage.getInputImages(), attrs); + } + + if (usage.getOutputImages() != null && usage.getOutputImages() > 0) { + outputImagesCounter.add(usage.getOutputImages(), attrs); + } + + if (usage.getThoughtsTokens() != null && usage.getThoughtsTokens() > 0) { + thinkingTokensCounter.add(usage.getThoughtsTokens(), attrs); + } + } + + /** + * Truncates a string to the specified maximum length. + * + * @param value + * the string to truncate + * @param maxLength + * the maximum length + * @return the truncated string + */ + private String truncate(String value, int maxLength) { + if (value == null) { + return ""; + } + if (value.length() <= maxLength) { + return value; + } + return value.substring(0, maxLength); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/telemetry/ModelTelemetryHelper.java b/java/ai/src/main/java/com/google/genkit/ai/telemetry/ModelTelemetryHelper.java new file mode 100644 index 0000000000..1e4fa7d15e --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/telemetry/ModelTelemetryHelper.java @@ -0,0 +1,179 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.telemetry; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.ai.Message; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Usage; +import com.google.genkit.core.GenkitException; + +/** + * ModelTelemetryHelper provides utilities for recording model telemetry. + * + *

+ * This helper should be used when invoking models to automatically record + * metrics like latency, token counts, and error rates. + */ +public class ModelTelemetryHelper { + + private static final Logger logger = LoggerFactory.getLogger(ModelTelemetryHelper.class); + + /** + * Executes a model call with automatic telemetry recording. + * + * @param modelName + * the model name + * @param featureName + * the feature/flow name + * @param path + * the span path + * @param request + * the model request + * @param modelFn + * the function that executes the model + * @return the model response + * @throws GenkitException + * if model execution fails + */ + public static ModelResponse runWithTelemetry(String modelName, String featureName, String path, + ModelRequest request, ModelExecutor modelFn) throws GenkitException { + long startTime = System.currentTimeMillis(); + String error = null; + ModelResponse response = null; + + try { + response = modelFn.execute(request); + + // Calculate usage statistics if not provided by the model + if (response != null && response.getUsage() == null) { + Usage calculatedUsage = calculateBasicUsage(request, response); + response.setUsage(calculatedUsage); + } + + // Set latency if not already set + if (response != null && response.getLatencyMs() == null) { + response.setLatencyMs(System.currentTimeMillis() - startTime); + } + + return response; + } catch (GenkitException e) { + error = e.getClass().getSimpleName(); + throw e; + } catch (Exception e) { + error = e.getClass().getSimpleName(); + throw new GenkitException("Model execution failed: " + e.getMessage(), e); + } finally { + long latencyMs = System.currentTimeMillis() - startTime; + + // Record telemetry metrics + try { + GenerateTelemetry.getInstance().recordGenerateMetrics(modelName, + featureName != null ? featureName : "generate", path != null ? path : "", response, latencyMs, + error); + } catch (Exception e) { + logger.warn("Failed to record model telemetry: {}", e.getMessage()); + } + } + } + + /** + * Calculates basic usage statistics from request and response. + * + * @param request + * the model request + * @param response + * the model response + * @return calculated usage statistics + */ + public static Usage calculateBasicUsage(ModelRequest request, ModelResponse response) { + Usage usage = new Usage(); + + // Calculate input statistics + int inputChars = 0; + int inputImages = 0; + + if (request != null && request.getMessages() != null) { + for (Message message : request.getMessages()) { + if (message.getContent() != null) { + for (Part part : message.getContent()) { + if (part.getText() != null) { + inputChars += part.getText().length(); + } + if (part.getMedia() != null) { + String contentType = part.getMedia().getContentType(); + if (contentType != null && contentType.startsWith("image/")) { + inputImages++; + } + } + } + } + } + } + + // Calculate output statistics + int outputChars = 0; + int outputImages = 0; + + if (response != null && response.getMessage() != null) { + Message outputMessage = response.getMessage(); + if (outputMessage.getContent() != null) { + for (Part part : outputMessage.getContent()) { + if (part.getText() != null) { + outputChars += part.getText().length(); + } + if (part.getMedia() != null) { + String contentType = part.getMedia().getContentType(); + if (contentType != null && contentType.startsWith("image/")) { + outputImages++; + } + } + } + } + } + + usage.setInputCharacters(inputChars); + usage.setOutputCharacters(outputChars); + usage.setInputImages(inputImages > 0 ? inputImages : null); + usage.setOutputImages(outputImages > 0 ? outputImages : null); + + return usage; + } + + /** + * Functional interface for model execution. + */ + @FunctionalInterface + public interface ModelExecutor { + /** + * Executes the model with the given request. + * + * @param request + * the model request + * @return the model response + * @throws GenkitException + * if execution fails + */ + ModelResponse execute(ModelRequest request) throws GenkitException; + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/telemetry/ToolTelemetry.java b/java/ai/src/main/java/com/google/genkit/ai/telemetry/ToolTelemetry.java new file mode 100644 index 0000000000..582c89f112 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/telemetry/ToolTelemetry.java @@ -0,0 +1,129 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.telemetry; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.metrics.LongCounter; +import io.opentelemetry.api.metrics.LongHistogram; +import io.opentelemetry.api.metrics.Meter; + +/** + * ToolTelemetry provides metrics collection for tool execution. + * + *

+ * This class tracks: + *

    + *
  • Tool invocation counts
  • + *
  • Tool latency histograms
  • + *
  • Tool error rates
  • + *
+ */ +public class ToolTelemetry { + + private static final Logger logger = LoggerFactory.getLogger(ToolTelemetry.class); + private static final String METER_NAME = "genkit"; + private static final String SOURCE = "java"; + + // Metric names following conventions + private static final String METRIC_REQUESTS = "genkit/tool/requests"; + private static final String METRIC_LATENCY = "genkit/tool/latency"; + + private final LongCounter requestCounter; + private final LongHistogram latencyHistogram; + + private static ToolTelemetry instance; + + /** + * Gets the singleton instance of ToolTelemetry. + * + * @return the ToolTelemetry instance + */ + public static synchronized ToolTelemetry getInstance() { + if (instance == null) { + instance = new ToolTelemetry(); + } + return instance; + } + + private ToolTelemetry() { + Meter meter = GlobalOpenTelemetry.getMeter(METER_NAME); + + requestCounter = meter.counterBuilder(METRIC_REQUESTS).setDescription("Counts calls to genkit tools.") + .setUnit("1").build(); + + latencyHistogram = meter.histogramBuilder(METRIC_LATENCY) + .setDescription("Latencies when executing Genkit tools.").setUnit("ms").ofLongs().build(); + + logger.debug("ToolTelemetry initialized with OpenTelemetry metrics"); + } + + /** + * Records metrics for a tool execution. + * + * @param toolName + * the tool name + * @param featureName + * the feature/flow name + * @param path + * the span path + * @param latencyMs + * the latency in milliseconds + * @param error + * the error name if failed, null otherwise + */ + public void recordToolMetrics(String toolName, String featureName, String path, long latencyMs, String error) { + String status = error != null ? "failure" : "success"; + + Attributes baseAttrs = Attributes.builder().put("toolName", truncate(toolName, 1024)) + .put("featureName", truncate(featureName, 256)).put("path", truncate(path, 2048)).put("status", status) + .put("source", SOURCE).build(); + + // Record request count + Attributes requestAttrs = error != null + ? baseAttrs.toBuilder().put("error", truncate(error, 256)).build() + : baseAttrs; + requestCounter.add(1, requestAttrs); + + // Record latency + latencyHistogram.record(latencyMs, baseAttrs); + } + + /** + * Truncates a string to the specified maximum length. + * + * @param value + * the string to truncate + * @param maxLength + * the maximum length + * @return the truncated string + */ + private String truncate(String value, int maxLength) { + if (value == null) { + return ""; + } + if (value.length() <= maxLength) { + return value; + } + return value.substring(0, maxLength); + } +} diff --git a/java/ai/src/main/java/com/google/genkit/ai/telemetry/package-info.java b/java/ai/src/main/java/com/google/genkit/ai/telemetry/package-info.java new file mode 100644 index 0000000000..37d477d968 --- /dev/null +++ b/java/ai/src/main/java/com/google/genkit/ai/telemetry/package-info.java @@ -0,0 +1,53 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Telemetry module for Genkit Java SDK. + * + *

+ * This package provides observability utilities for tracking: + *

    + *
  • Model generation metrics (token counts, latency, etc.)
  • + *
  • Tool execution metrics
  • + *
  • Feature/flow-level metrics
  • + *
  • Action-level metrics
  • + *
+ * + *

+ * The telemetry classes integrate with OpenTelemetry for metrics export, + * allowing Genkit applications to be monitored using standard observability + * tools like Google Cloud Operations, Prometheus, etc. + * + *

+ * Key classes: + *

    + *
  • {@link com.google.genkit.ai.telemetry.GenerateTelemetry} - Tracks model + * generation metrics
  • + *
  • {@link com.google.genkit.ai.telemetry.ToolTelemetry} - Tracks tool + * execution metrics
  • + *
  • {@link com.google.genkit.ai.telemetry.FeatureTelemetry} - Tracks + * feature/flow metrics
  • + *
  • {@link com.google.genkit.ai.telemetry.ActionTelemetry} - Tracks general + * action metrics
  • + *
  • {@link com.google.genkit.ai.telemetry.ModelTelemetryHelper} - Helper for + * recording model telemetry
  • + *
+ * + * @see OpenTelemetry + */ +package com.google.genkit.ai.telemetry; diff --git a/java/ai/src/test/java/com/google/genkit/ai/AgentConfigTest.java b/java/ai/src/test/java/com/google/genkit/ai/AgentConfigTest.java new file mode 100644 index 0000000000..461652e01d --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/AgentConfigTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for AgentConfig. + */ +class AgentConfigTest { + + /** + * Helper to create a simple test tool. + */ + private Tool createTestTool(String name) { + Map schema = new HashMap<>(); + schema.put("type", "string"); + return new Tool<>(name, "Test tool " + name, schema, schema, String.class, (ctx, input) -> "result"); + } + + @Test + void testBuilderWithAllFields() { + Tool tool1 = createTestTool("tool1"); + Tool tool2 = createTestTool("tool2"); + + AgentConfig subAgent = AgentConfig.builder().name("subAgent").description("Sub agent").build(); + + GenerationConfig genConfig = GenerationConfig.builder().temperature(0.7).build(); + + OutputConfig outputConfig = new OutputConfig(); + outputConfig.setFormat(OutputFormat.JSON); + + AgentConfig config = AgentConfig.builder().name("mainAgent").description("Main agent description") + .system("You are a helpful assistant.").model("openai/gpt-4o").tools(List.of(tool1, tool2)) + .agents(List.of(subAgent)).config(genConfig).output(outputConfig).build(); + + assertEquals("mainAgent", config.getName()); + assertEquals("Main agent description", config.getDescription()); + assertEquals("You are a helpful assistant.", config.getSystem()); + assertEquals("openai/gpt-4o", config.getModel()); + assertEquals(2, config.getTools().size()); + assertEquals(1, config.getAgents().size()); + assertEquals("subAgent", config.getAgents().get(0).getName()); + assertEquals(genConfig, config.getConfig()); + assertEquals(outputConfig, config.getOutput()); + } + + @Test + void testBuilderWithMinimalFields() { + AgentConfig config = AgentConfig.builder().name("simpleAgent").description("A simple agent").build(); + + assertEquals("simpleAgent", config.getName()); + assertEquals("A simple agent", config.getDescription()); + assertNull(config.getSystem()); + assertNull(config.getModel()); + assertNull(config.getTools()); + assertNull(config.getAgents()); + assertNull(config.getConfig()); + assertNull(config.getOutput()); + } + + @Test + void testDefaultConstructor() { + AgentConfig config = new AgentConfig(); + + assertNull(config.getName()); + assertNull(config.getDescription()); + assertNull(config.getSystem()); + assertNull(config.getModel()); + assertNull(config.getTools()); + assertNull(config.getAgents()); + assertNull(config.getConfig()); + assertNull(config.getOutput()); + } + + @Test + void testSetters() { + AgentConfig config = new AgentConfig(); + Tool tool = createTestTool("testTool"); + AgentConfig subAgent = AgentConfig.builder().name("sub").build(); + GenerationConfig genConfig = GenerationConfig.builder().build(); + OutputConfig outputConfig = new OutputConfig(); + + config.setName("agent"); + config.setDescription("desc"); + config.setSystem("system"); + config.setModel("model"); + config.setTools(List.of(tool)); + config.setAgents(List.of(subAgent)); + config.setConfig(genConfig); + config.setOutput(outputConfig); + + assertEquals("agent", config.getName()); + assertEquals("desc", config.getDescription()); + assertEquals("system", config.getSystem()); + assertEquals("model", config.getModel()); + assertEquals(1, config.getTools().size()); + assertEquals(1, config.getAgents().size()); + assertEquals(genConfig, config.getConfig()); + assertEquals(outputConfig, config.getOutput()); + } + + @Test + void testNestedAgents() { + AgentConfig level3 = AgentConfig.builder().name("level3").description("Level 3 agent").build(); + + AgentConfig level2 = AgentConfig.builder().name("level2").description("Level 2 agent").agents(List.of(level3)) + .build(); + + AgentConfig level1 = AgentConfig.builder().name("level1").description("Level 1 agent").agents(List.of(level2)) + .build(); + + assertEquals("level1", level1.getName()); + assertEquals(1, level1.getAgents().size()); + assertEquals("level2", level1.getAgents().get(0).getName()); + assertEquals(1, level1.getAgents().get(0).getAgents().size()); + assertEquals("level3", level1.getAgents().get(0).getAgents().get(0).getName()); + } + + @Test + void testMultipleToolsAndAgents() { + Tool tool1 = createTestTool("tool1"); + Tool tool2 = createTestTool("tool2"); + Tool tool3 = createTestTool("tool3"); + + AgentConfig sub1 = AgentConfig.builder().name("sub1").build(); + AgentConfig sub2 = AgentConfig.builder().name("sub2").build(); + + AgentConfig config = AgentConfig.builder().name("main").tools(List.of(tool1, tool2, tool3)) + .agents(List.of(sub1, sub2)).build(); + + assertEquals(3, config.getTools().size()); + assertEquals(2, config.getAgents().size()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/AgentTest.java b/java/ai/src/test/java/com/google/genkit/ai/AgentTest.java new file mode 100644 index 0000000000..2f46d03eb7 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/AgentTest.java @@ -0,0 +1,179 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Agent. + */ +class AgentTest { + + /** + * Helper to create a simple test tool. + */ + private Tool createTestTool(String name) { + Map schema = new HashMap<>(); + schema.put("type", "string"); + return new Tool<>(name, "Test tool " + name, schema, schema, String.class, (ctx, input) -> "result"); + } + + @Test + void testAgentCreation() { + AgentConfig config = AgentConfig.builder().name("testAgent").description("Test agent description") + .system("You are a test agent.").model("test-model").build(); + + Agent agent = new Agent(config); + + assertEquals("testAgent", agent.getName()); + assertEquals("Test agent description", agent.getDescription()); + assertEquals("You are a test agent.", agent.getSystem()); + assertEquals("test-model", agent.getModel()); + assertEquals(config, agent.getConfig()); + } + + @Test + void testAsTool() { + AgentConfig config = AgentConfig.builder().name("delegateAgent").description("Handles delegated tasks").build(); + + Agent agent = new Agent(config); + Tool, Agent.AgentTransferResult> tool = agent.asTool(); + + assertNotNull(tool); + assertEquals("delegateAgent", tool.getDefinition().getName()); + assertEquals("Handles delegated tasks", tool.getDefinition().getDescription()); + assertNotNull(tool.getInputSchema()); + assertNotNull(tool.getOutputSchema()); + } + + @Test + void testGetToolDefinition() { + AgentConfig config = AgentConfig.builder().name("myAgent").description("My agent").build(); + + Agent agent = new Agent(config); + ToolDefinition def = agent.getToolDefinition(); + + assertEquals("myAgent", def.getName()); + assertEquals("My agent", def.getDescription()); + assertNotNull(def.getInputSchema()); + } + + @Test + void testGetTools() { + Tool tool1 = createTestTool("tool1"); + Tool tool2 = createTestTool("tool2"); + + AgentConfig config = AgentConfig.builder().name("agent").tools(List.of(tool1, tool2)).build(); + + Agent agent = new Agent(config); + + assertEquals(2, agent.getTools().size()); + assertTrue(agent.getTools().contains(tool1)); + assertTrue(agent.getTools().contains(tool2)); + } + + @Test + void testGetAgents() { + AgentConfig sub1 = AgentConfig.builder().name("sub1").build(); + AgentConfig sub2 = AgentConfig.builder().name("sub2").build(); + + AgentConfig config = AgentConfig.builder().name("parent").agents(List.of(sub1, sub2)).build(); + + Agent agent = new Agent(config); + + assertEquals(2, agent.getAgents().size()); + } + + @Test + void testGetAllToolsWithNoSubAgents() { + Tool tool1 = createTestTool("tool1"); + Tool tool2 = createTestTool("tool2"); + + AgentConfig config = AgentConfig.builder().name("agent").tools(List.of(tool1, tool2)).build(); + + Agent agent = new Agent(config); + Map registry = new HashMap<>(); + + List> allTools = agent.getAllTools(registry); + + assertEquals(2, allTools.size()); + } + + @Test + void testGetAllToolsWithSubAgents() { + Tool parentTool = createTestTool("parentTool"); + + AgentConfig subConfig = AgentConfig.builder().name("subAgent").description("Sub agent").build(); + + AgentConfig config = AgentConfig.builder().name("parent").tools(List.of(parentTool)).agents(List.of(subConfig)) + .build(); + + Agent parent = new Agent(config); + Agent subAgent = new Agent(subConfig); + + Map registry = new HashMap<>(); + registry.put("subAgent", subAgent); + + List> allTools = parent.getAllTools(registry); + + // Should have parent tool + sub-agent as tool + assertEquals(2, allTools.size()); + } + + @Test + void testAgentTransferResult() { + Agent.AgentTransferResult result = new Agent.AgentTransferResult("targetAgent"); + + assertEquals("targetAgent", result.getTransferredTo()); + assertTrue(result.isTransferred()); + assertTrue(result.toString().contains("targetAgent")); + } + + @Test + void testToString() { + AgentConfig config = AgentConfig.builder().name("myAgent").build(); + + Agent agent = new Agent(config); + String str = agent.toString(); + + assertTrue(str.contains("myAgent")); + assertTrue(str.contains("Agent")); + } + + @Test + void testNullToolsAndAgents() { + AgentConfig config = AgentConfig.builder().name("minimal").build(); + + Agent agent = new Agent(config); + + assertNull(agent.getTools()); + assertNull(agent.getAgents()); + + Map registry = new HashMap<>(); + List> allTools = agent.getAllTools(registry); + + assertTrue(allTools.isEmpty()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/DocumentTest.java b/java/ai/src/test/java/com/google/genkit/ai/DocumentTest.java new file mode 100644 index 0000000000..b236e4630a --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/DocumentTest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Document. + */ +class DocumentTest { + + @Test + void testDefaultConstructor() { + Document doc = new Document(); + + assertNotNull(doc.getContent()); + assertTrue(doc.getContent().isEmpty()); + assertNotNull(doc.getMetadata()); + assertTrue(doc.getMetadata().isEmpty()); + } + + @Test + void testConstructorWithText() { + Document doc = new Document("Hello, world!"); + + assertEquals(1, doc.getContent().size()); + assertEquals("Hello, world!", doc.text()); + assertNotNull(doc.getMetadata()); + } + + @Test + void testConstructorWithParts() { + List parts = Arrays.asList(Part.text("Part 1"), Part.text("Part 2")); + Document doc = new Document(parts); + + assertEquals(2, doc.getContent().size()); + assertEquals("Part 1Part 2", doc.text()); + } + + @Test + void testConstructorWithNullParts() { + Document doc = new Document((List) null); + + assertNotNull(doc.getContent()); + assertTrue(doc.getContent().isEmpty()); + } + + @Test + void testFromText() { + Document doc = Document.fromText("Test content"); + + assertEquals(1, doc.getContent().size()); + assertEquals("Test content", doc.text()); + } + + @Test + void testFromTextWithMetadata() { + Map metadata = new HashMap<>(); + metadata.put("source", "test"); + metadata.put("page", 1); + + Document doc = Document.fromText("Test content", metadata); + + assertEquals("Test content", doc.text()); + assertEquals("test", doc.getMetadata().get("source")); + assertEquals(1, doc.getMetadata().get("page")); + } + + @Test + void testFromTextWithNullMetadata() { + Document doc = Document.fromText("Test content", null); + + assertEquals("Test content", doc.text()); + assertNotNull(doc.getMetadata()); + assertTrue(doc.getMetadata().isEmpty()); + } + + @Test + void testGetText() { + Document doc = new Document(); + doc.setContent(Arrays.asList(Part.text("Hello, "), Part.text("world!"))); + + assertEquals("Hello, world!", doc.text()); + } + + @Test + void testGetTextWithEmptyContent() { + Document doc = new Document(); + + assertEquals("", doc.text()); + } + + @Test + void testGetTextWithNullContent() { + Document doc = new Document(); + doc.setContent(null); + + assertEquals("", doc.text()); + } + + @Test + void testGetTextSkipsNonTextParts() { + Document doc = new Document(); + doc.setContent(Arrays.asList(Part.text("Text"), Part.media("image/png", "http://example.com/img.png"), + Part.text(" content"))); + + assertEquals("Text content", doc.text()); + } + + @Test + void testSetContent() { + Document doc = new Document(); + List content = Collections.singletonList(Part.text("New content")); + doc.setContent(content); + + assertEquals(1, doc.getContent().size()); + assertEquals("New content", doc.text()); + } + + @Test + void testSetMetadata() { + Document doc = new Document("Test"); + Map metadata = Map.of("key", "value"); + doc.setMetadata(metadata); + + assertEquals(metadata, doc.getMetadata()); + } + + @Test + void testAddPart() { + Document doc = new Document("Initial"); + doc.getContent().add(Part.text(" Added")); + + assertEquals("Initial Added", doc.text()); + } + + @Test + void testMetadataOperations() { + Document doc = new Document("Test"); + + doc.getMetadata().put("author", "John"); + doc.getMetadata().put("date", "2025-01-01"); + + assertEquals("John", doc.getMetadata().get("author")); + assertEquals("2025-01-01", doc.getMetadata().get("date")); + } + + @Test + void testDocumentWithMedia() { + Document doc = new Document(); + doc.setContent( + Arrays.asList(Part.text("Description: "), Part.media("image/png", "http://example.com/image.png"))); + + assertEquals(2, doc.getContent().size()); + assertEquals("Description: ", doc.text()); + assertNotNull(doc.getContent().get(1).getMedia()); + } + + @Test + void testEmptyTextDocument() { + Document doc = new Document(""); + + assertEquals(1, doc.getContent().size()); + assertEquals("", doc.text()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/InterruptConfigTest.java b/java/ai/src/test/java/com/google/genkit/ai/InterruptConfigTest.java new file mode 100644 index 0000000000..4d3e896622 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/InterruptConfigTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for InterruptConfig. + */ +class InterruptConfigTest { + + @Test + void testBuilderWithAllFields() { + Map inputSchema = Map.of("type", "object", "properties", Map.of()); + Map outputSchema = Map.of("type", "string"); + + InterruptConfig config = InterruptConfig.builder().name("confirm") + .description("Asks for confirmation").inputType(TestInput.class).outputType(TestOutput.class) + .inputSchema(inputSchema).outputSchema(outputSchema) + .requestMetadata(input -> Map.of("action", input.action)).build(); + + assertEquals("confirm", config.getName()); + assertEquals("Asks for confirmation", config.getDescription()); + assertEquals(TestInput.class, config.getInputType()); + assertEquals(TestOutput.class, config.getOutputType()); + assertEquals(inputSchema, config.getInputSchema()); + assertEquals(outputSchema, config.getOutputSchema()); + assertNotNull(config.getRequestMetadata()); + } + + @Test + void testBuilderWithMinimalFields() { + InterruptConfig config = InterruptConfig.builder().name("simple") + .description("Simple interrupt").inputType(TestInput.class).outputType(TestOutput.class).build(); + + assertEquals("simple", config.getName()); + assertEquals("Simple interrupt", config.getDescription()); + assertEquals(TestInput.class, config.getInputType()); + assertEquals(TestOutput.class, config.getOutputType()); + assertNull(config.getInputSchema()); + assertNull(config.getOutputSchema()); + assertNull(config.getRequestMetadata()); + } + + @Test + void testRequestMetadataFunction() { + InterruptConfig config = InterruptConfig.builder().name("confirm") + .description("Confirm action").inputType(TestInput.class).outputType(TestOutput.class) + .requestMetadata(input -> { + Map metadata = new HashMap<>(); + metadata.put("action", input.action); + metadata.put("amount", input.amount); + return metadata; + }).build(); + + TestInput input = new TestInput(); + input.action = "purchase"; + input.amount = 100; + + Map metadata = config.getRequestMetadata().apply(input); + + assertEquals("purchase", metadata.get("action")); + assertEquals(100, metadata.get("amount")); + } + + @Test + void testDefaultConstructor() { + InterruptConfig config = new InterruptConfig<>(); + + assertNull(config.getName()); + assertNull(config.getDescription()); + assertNull(config.getInputType()); + assertNull(config.getOutputType()); + assertNull(config.getInputSchema()); + assertNull(config.getOutputSchema()); + assertNull(config.getRequestMetadata()); + } + + @Test + void testSetters() { + InterruptConfig config = new InterruptConfig<>(); + + config.setName("test"); + config.setDescription("Test description"); + config.setInputType(TestInput.class); + config.setOutputType(TestOutput.class); + + assertEquals("test", config.getName()); + assertEquals("Test description", config.getDescription()); + assertEquals(TestInput.class, config.getInputType()); + assertEquals(TestOutput.class, config.getOutputType()); + } + + // Test helper classes + static class TestInput { + String action; + int amount; + } + + static class TestOutput { + boolean confirmed; + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/InterruptRequestTest.java b/java/ai/src/test/java/com/google/genkit/ai/InterruptRequestTest.java new file mode 100644 index 0000000000..24ab378ef3 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/InterruptRequestTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for InterruptRequest. + */ +class InterruptRequestTest { + + @Test + void testConstructor() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("confirm"); + toolRequest.setRef("ref-123"); + toolRequest.setInput(Map.of("action", "purchase")); + + Map metadata = new HashMap<>(); + metadata.put("amount", 100); + + InterruptRequest request = new InterruptRequest(toolRequest, metadata); + + assertEquals(toolRequest, request.getToolRequest()); + assertNotNull(request.getMetadata()); + assertEquals(100, request.getMetadata().get("amount")); + assertTrue(request.isInterrupt()); + } + + @Test + void testConstructorWithNullMetadata() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("confirm"); + + InterruptRequest request = new InterruptRequest(toolRequest, null); + + assertNotNull(request.getMetadata()); + assertTrue(request.isInterrupt()); + } + + @Test + void testRespond() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("confirm"); + toolRequest.setRef("ref-123"); + + InterruptRequest request = new InterruptRequest(toolRequest, new HashMap<>()); + + // Respond with confirmation + Map response = Map.of("confirmed", true); + ToolResponse toolResponse = request.respond(response); + + assertEquals("ref-123", toolResponse.getRef()); + assertEquals("confirm", toolResponse.getName()); + assertEquals(response, toolResponse.getOutput()); + } + + @Test + void testRestart() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("confirm"); + toolRequest.setRef("ref-123"); + toolRequest.setInput(Map.of("action", "purchase")); + + InterruptRequest request = new InterruptRequest(toolRequest, new HashMap<>()); + + // Restart returns the original tool request + ToolRequest restartRequest = request.restart(); + + assertEquals("confirm", restartRequest.getName()); + assertEquals("ref-123", restartRequest.getRef()); + assertEquals(Map.of("action", "purchase"), restartRequest.getInput()); + } + + @Test + void testRestartWithModifiedInput() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("confirm"); + toolRequest.setRef("ref-123"); + toolRequest.setInput(Map.of("action", "purchase", "amount", 100)); + + InterruptRequest request = new InterruptRequest(toolRequest, new HashMap<>()); + + // Restart with modified input + Map newMetadata = Map.of("reason", "retry"); + Map newInput = Map.of("action", "purchase", "amount", 50); + + ToolRequest restartRequest = request.restart(newMetadata, newInput); + + assertEquals("confirm", restartRequest.getName()); + // Should have same ref + assertEquals("ref-123", restartRequest.getRef()); + // But new input + assertEquals(newInput, restartRequest.getInput()); + } + + @Test + void testRestartWithNullNewInput() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("confirm"); + toolRequest.setRef("ref-123"); + Object originalInput = Map.of("action", "purchase"); + toolRequest.setInput(originalInput); + + InterruptRequest request = new InterruptRequest(toolRequest, new HashMap<>()); + + // Restart with null new input should keep original + ToolRequest restartRequest = request.restart(null, null); + + assertEquals(originalInput, restartRequest.getInput()); + } + + @Test + void testMetadataContainsInterruptFlag() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("test"); + + Map metadata = new HashMap<>(); + metadata.put("custom", "value"); + + InterruptRequest request = new InterruptRequest(toolRequest, metadata); + + assertTrue((Boolean) request.getMetadata().get("interrupt")); + assertEquals("value", request.getMetadata().get("custom")); + } + + @Test + void testGetToolName() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("myInterrupt"); + + InterruptRequest request = new InterruptRequest(toolRequest, new HashMap<>()); + + assertEquals("myInterrupt", request.getToolRequest().getName()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/JsonSerializationTest.java b/java/ai/src/test/java/com/google/genkit/ai/JsonSerializationTest.java new file mode 100644 index 0000000000..d979c3554b --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/JsonSerializationTest.java @@ -0,0 +1,225 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Arrays; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Unit tests for JSON serialization and deserialization of AI types. + */ +class JsonSerializationTest { + + private final ObjectMapper objectMapper; + + JsonSerializationTest() { + objectMapper = new ObjectMapper(); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + } + + @Test + void testMessageSerialization() throws Exception { + Message message = Message.user("Hello, world!"); + + String json = objectMapper.writeValueAsString(message); + + assertNotNull(json); + assertTrue(json.contains("\"role\":\"user\"")); + assertTrue(json.contains("\"content\"")); + } + + @Test + void testMessageDeserialization() throws Exception { + String json = "{\"role\":\"user\",\"content\":[{\"text\":\"Hello\"}]}"; + + Message message = objectMapper.readValue(json, Message.class); + + assertEquals(Role.USER, message.getRole()); + assertEquals("Hello", message.getText()); + } + + @Test + void testMessageRoundTrip() throws Exception { + Message original = Message.system("You are a helpful assistant."); + + String json = objectMapper.writeValueAsString(original); + Message deserialized = objectMapper.readValue(json, Message.class); + + assertEquals(original.getRole(), deserialized.getRole()); + assertEquals(original.getText(), deserialized.getText()); + } + + @Test + void testPartTextSerialization() throws Exception { + Part part = Part.text("Some text content"); + + String json = objectMapper.writeValueAsString(part); + + assertTrue(json.contains("\"text\":\"Some text content\"")); + } + + @Test + void testPartMediaSerialization() throws Exception { + Part part = Part.media("image/png", "http://example.com/img.png"); + + String json = objectMapper.writeValueAsString(part); + + assertTrue(json.contains("\"media\"")); + assertTrue(json.contains("\"contentType\":\"image/png\"")); + assertTrue(json.contains("\"url\":\"http://example.com/img.png\"")); + } + + @Test + void testPartDeserialization() throws Exception { + String json = "{\"text\":\"Hello\"}"; + + Part part = objectMapper.readValue(json, Part.class); + + assertEquals("Hello", part.getText()); + } + + @Test + void testDocumentSerialization() throws Exception { + Document doc = Document.fromText("Document content", Map.of("source", "test")); + + String json = objectMapper.writeValueAsString(doc); + + assertTrue(json.contains("\"content\"")); + assertTrue(json.contains("\"metadata\"")); + assertTrue(json.contains("\"source\":\"test\"")); + } + + @Test + void testDocumentDeserialization() throws Exception { + String json = "{\"content\":[{\"text\":\"Doc text\"}],\"metadata\":{\"key\":\"value\"}}"; + + Document doc = objectMapper.readValue(json, Document.class); + + assertEquals("Doc text", doc.text()); + assertEquals("value", doc.getMetadata().get("key")); + } + + @Test + void testRoleSerialization() throws Exception { + String userJson = objectMapper.writeValueAsString(Role.USER); + String modelJson = objectMapper.writeValueAsString(Role.MODEL); + String systemJson = objectMapper.writeValueAsString(Role.SYSTEM); + String toolJson = objectMapper.writeValueAsString(Role.TOOL); + + assertEquals("\"user\"", userJson); + assertEquals("\"model\"", modelJson); + assertEquals("\"system\"", systemJson); + assertEquals("\"tool\"", toolJson); + } + + @Test + void testRoleDeserialization() throws Exception { + assertEquals(Role.USER, objectMapper.readValue("\"user\"", Role.class)); + assertEquals(Role.MODEL, objectMapper.readValue("\"model\"", Role.class)); + assertEquals(Role.SYSTEM, objectMapper.readValue("\"system\"", Role.class)); + assertEquals(Role.TOOL, objectMapper.readValue("\"tool\"", Role.class)); + } + + @Test + void testRoleAssistantDeserialization() throws Exception { + // "assistant" should map to MODEL for compatibility + assertEquals(Role.MODEL, objectMapper.readValue("\"assistant\"", Role.class)); + } + + @Test + void testMediaSerialization() throws Exception { + Media media = new Media("video/mp4", "http://example.com/video.mp4"); + + String json = objectMapper.writeValueAsString(media); + + assertTrue(json.contains("\"contentType\":\"video/mp4\"")); + assertTrue(json.contains("\"url\":\"http://example.com/video.mp4\"")); + } + + @Test + void testMediaDeserialization() throws Exception { + String json = "{\"contentType\":\"audio/mp3\",\"url\":\"http://example.com/audio.mp3\"}"; + + Media media = objectMapper.readValue(json, Media.class); + + assertEquals("audio/mp3", media.getContentType()); + assertEquals("http://example.com/audio.mp3", media.getUrl()); + } + + @Test + void testToolRequestSerialization() throws Exception { + ToolRequest request = new ToolRequest(); + request.setName("calculator"); + request.setRef("calc-001"); + request.setInput(Map.of("a", 5, "b", 3)); + + String json = objectMapper.writeValueAsString(request); + + assertTrue(json.contains("\"name\":\"calculator\"")); + assertTrue(json.contains("\"ref\":\"calc-001\"")); + assertTrue(json.contains("\"input\"")); + } + + @Test + void testToolResponseSerialization() throws Exception { + ToolResponse response = new ToolResponse(); + response.setName("calculator"); + response.setRef("calc-001"); + response.setOutput(Map.of("result", 8)); + + String json = objectMapper.writeValueAsString(response); + + assertTrue(json.contains("\"name\":\"calculator\"")); + assertTrue(json.contains("\"output\"")); + assertTrue(json.contains("\"result\":8")); + } + + @Test + void testComplexMessageSerialization() throws Exception { + Message message = new Message(Role.USER, Arrays.asList(Part.text("Look at this image: "), + Part.media("image/png", "http://example.com/img.png"))); + message.setMetadata(Map.of("timestamp", "2025-01-01T12:00:00Z")); + + String json = objectMapper.writeValueAsString(message); + Message deserialized = objectMapper.readValue(json, Message.class); + + assertEquals(Role.USER, deserialized.getRole()); + assertEquals(2, deserialized.getContent().size()); + assertNotNull(deserialized.getMetadata()); + } + + @Test + void testNullValuesExcluded() throws Exception { + Part part = new Part(); + part.setText("Only text"); + + String json = objectMapper.writeValueAsString(part); + + assertTrue(json.contains("\"text\":\"Only text\"")); + // Null values should be excluded with @JsonInclude(NON_NULL) + assertFalse(json.contains("\"media\":null")); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/MediaTest.java b/java/ai/src/test/java/com/google/genkit/ai/MediaTest.java new file mode 100644 index 0000000000..0b81beb7db --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/MediaTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Media. + */ +class MediaTest { + + @Test + void testDefaultConstructor() { + Media media = new Media(); + + assertNull(media.getContentType()); + assertNull(media.getUrl()); + } + + @Test + void testConstructorWithParameters() { + Media media = new Media("image/png", "http://example.com/image.png"); + + assertEquals("image/png", media.getContentType()); + assertEquals("http://example.com/image.png", media.getUrl()); + } + + @Test + void testSetContentType() { + Media media = new Media(); + media.setContentType("video/mp4"); + + assertEquals("video/mp4", media.getContentType()); + } + + @Test + void testSetUrl() { + Media media = new Media(); + media.setUrl("http://example.com/video.mp4"); + + assertEquals("http://example.com/video.mp4", media.getUrl()); + } + + @Test + void testCommonMediaTypes() { + Media png = new Media("image/png", "http://example.com/img.png"); + Media jpeg = new Media("image/jpeg", "http://example.com/img.jpg"); + Media gif = new Media("image/gif", "http://example.com/img.gif"); + Media webp = new Media("image/webp", "http://example.com/img.webp"); + Media pdf = new Media("application/pdf", "http://example.com/doc.pdf"); + Media mp3 = new Media("audio/mpeg", "http://example.com/audio.mp3"); + Media mp4 = new Media("video/mp4", "http://example.com/video.mp4"); + + assertEquals("image/png", png.getContentType()); + assertEquals("image/jpeg", jpeg.getContentType()); + assertEquals("image/gif", gif.getContentType()); + assertEquals("image/webp", webp.getContentType()); + assertEquals("application/pdf", pdf.getContentType()); + assertEquals("audio/mpeg", mp3.getContentType()); + assertEquals("video/mp4", mp4.getContentType()); + } + + @Test + void testDataUrl() { + String dataUrl = ""; + Media media = new Media("image/png", dataUrl); + + assertEquals("image/png", media.getContentType()); + assertTrue(media.getUrl().startsWith("data:")); + } + + @Test + void testHttpsUrl() { + Media media = new Media("image/png", "https://secure.example.com/image.png"); + + assertTrue(media.getUrl().startsWith("https://")); + } + + @Test + void testRelativeUrl() { + Media media = new Media("image/png", "/images/photo.png"); + + assertEquals("/images/photo.png", media.getUrl()); + } + + @Test + void testNullValues() { + Media media = new Media(null, null); + + assertNull(media.getContentType()); + assertNull(media.getUrl()); + } + + @Test + void testEmptyValues() { + Media media = new Media("", ""); + + assertEquals("", media.getContentType()); + assertEquals("", media.getUrl()); + } + + @Test + void testMutableProperties() { + Media media = new Media("image/png", "http://old.url.com/img.png"); + + media.setContentType("image/jpeg"); + media.setUrl("http://new.url.com/img.jpg"); + + assertEquals("image/jpeg", media.getContentType()); + assertEquals("http://new.url.com/img.jpg", media.getUrl()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/MessageTest.java b/java/ai/src/test/java/com/google/genkit/ai/MessageTest.java new file mode 100644 index 0000000000..81cc503960 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/MessageTest.java @@ -0,0 +1,167 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Message. + */ +class MessageTest { + + @Test + void testDefaultConstructor() { + Message message = new Message(); + + assertNull(message.getRole()); + assertNotNull(message.getContent()); + assertTrue(message.getContent().isEmpty()); + } + + @Test + void testConstructorWithRoleAndContent() { + List content = Collections.singletonList(Part.text("Hello")); + Message message = new Message(Role.USER, content); + + assertEquals(Role.USER, message.getRole()); + assertEquals(1, message.getContent().size()); + assertEquals("Hello", message.getContent().get(0).getText()); + } + + @Test + void testUserMessage() { + Message message = Message.user("Hello, world!"); + + assertEquals(Role.USER, message.getRole()); + assertEquals(1, message.getContent().size()); + assertEquals("Hello, world!", message.getText()); + } + + @Test + void testSystemMessage() { + Message message = Message.system("You are a helpful assistant."); + + assertEquals(Role.SYSTEM, message.getRole()); + assertEquals(1, message.getContent().size()); + assertEquals("You are a helpful assistant.", message.getText()); + } + + @Test + void testModelMessage() { + Message message = Message.model("I can help you with that."); + + assertEquals(Role.MODEL, message.getRole()); + assertEquals(1, message.getContent().size()); + assertEquals("I can help you with that.", message.getText()); + } + + @Test + void testToolMessage() { + List content = Arrays.asList(Part.text("Tool response 1"), Part.text("Tool response 2")); + Message message = Message.tool(content); + + assertEquals(Role.TOOL, message.getRole()); + assertEquals(2, message.getContent().size()); + } + + @Test + void testGetText() { + List content = Arrays.asList(Part.text("Hello, "), Part.text("world!")); + Message message = new Message(Role.USER, content); + + assertEquals("Hello, world!", message.getText()); + } + + @Test + void testGetTextWithNullContent() { + Message message = new Message(); + message.setContent(null); + + assertEquals("", message.getText()); + } + + @Test + void testGetTextWithEmptyContent() { + Message message = new Message(Role.USER, Collections.emptyList()); + + assertEquals("", message.getText()); + } + + @Test + void testSetRole() { + Message message = new Message(); + message.setRole(Role.MODEL); + + assertEquals(Role.MODEL, message.getRole()); + } + + @Test + void testSetContent() { + Message message = new Message(); + List content = Collections.singletonList(Part.text("New content")); + message.setContent(content); + + assertEquals(1, message.getContent().size()); + assertEquals("New content", message.getContent().get(0).getText()); + } + + @Test + void testSetMetadata() { + Message message = new Message(); + Map metadata = new HashMap<>(); + metadata.put("key", "value"); + message.setMetadata(metadata); + + assertEquals(metadata, message.getMetadata()); + } + + @Test + void testConstructorCopiesContentList() { + List original = Arrays.asList(Part.text("Hello")); + Message message = new Message(Role.USER, original); + + // Modifying original list should not affect message + assertNotSame(original, message.getContent()); + } + + @Test + void testNullContentInConstructor() { + Message message = new Message(Role.USER, null); + + assertNotNull(message.getContent()); + assertTrue(message.getContent().isEmpty()); + } + + @Test + void testGetTextSkipsNonTextParts() { + List content = Arrays.asList(Part.text("Hello"), Part.media("image/png", "http://example.com/image.png"), + Part.text(" World")); + Message message = new Message(Role.USER, content); + + assertEquals("Hello World", message.getText()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/PartTest.java b/java/ai/src/test/java/com/google/genkit/ai/PartTest.java new file mode 100644 index 0000000000..e4c824a91a --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/PartTest.java @@ -0,0 +1,199 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Part. + */ +class PartTest { + + @Test + void testDefaultConstructor() { + Part part = new Part(); + + assertNull(part.getText()); + assertNull(part.getMedia()); + assertNull(part.getToolRequest()); + assertNull(part.getToolResponse()); + assertNull(part.getData()); + assertNull(part.getMetadata()); + } + + @Test + void testTextPart() { + Part part = Part.text("Hello, world!"); + + assertEquals("Hello, world!", part.getText()); + assertNull(part.getMedia()); + assertNull(part.getToolRequest()); + assertNull(part.getToolResponse()); + } + + @Test + void testTextPartWithEmptyString() { + Part part = Part.text(""); + + assertEquals("", part.getText()); + } + + @Test + void testTextPartWithNull() { + Part part = Part.text(null); + + assertNull(part.getText()); + } + + @Test + void testMediaPart() { + Part part = Part.media("image/png", "http://example.com/image.png"); + + assertNull(part.getText()); + assertNotNull(part.getMedia()); + assertEquals("image/png", part.getMedia().getContentType()); + assertEquals("http://example.com/image.png", part.getMedia().getUrl()); + } + + @Test + void testMediaPartWithDifferentTypes() { + Part jpegPart = Part.media("image/jpeg", "http://example.com/photo.jpg"); + Part pdfPart = Part.media("application/pdf", "http://example.com/doc.pdf"); + Part audioPart = Part.media("audio/mp3", "http://example.com/sound.mp3"); + + assertEquals("image/jpeg", jpegPart.getMedia().getContentType()); + assertEquals("application/pdf", pdfPart.getMedia().getContentType()); + assertEquals("audio/mp3", audioPart.getMedia().getContentType()); + } + + @Test + void testToolRequestPart() { + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("calculator"); + toolRequest.setRef("ref-123"); + + Part part = Part.toolRequest(toolRequest); + + assertNull(part.getText()); + assertNull(part.getMedia()); + assertNotNull(part.getToolRequest()); + assertEquals("calculator", part.getToolRequest().getName()); + assertEquals("ref-123", part.getToolRequest().getRef()); + } + + @Test + void testToolResponsePart() { + ToolResponse toolResponse = new ToolResponse(); + toolResponse.setName("calculator"); + toolResponse.setRef("ref-123"); + toolResponse.setOutput(Map.of("result", 42)); + + Part part = Part.toolResponse(toolResponse); + + assertNull(part.getText()); + assertNull(part.getMedia()); + assertNull(part.getToolRequest()); + assertNotNull(part.getToolResponse()); + assertEquals("calculator", part.getToolResponse().getName()); + } + + @Test + void testDataPart() { + Map data = Map.of("key", "value", "number", 42); + Part part = Part.data(data); + + assertNull(part.getText()); + assertNotNull(part.getData()); + assertEquals(data, part.getData()); + } + + @Test + void testSetText() { + Part part = new Part(); + part.setText("New text"); + + assertEquals("New text", part.getText()); + } + + @Test + void testSetMedia() { + Part part = new Part(); + Media media = new Media("video/mp4", "http://example.com/video.mp4"); + part.setMedia(media); + + assertNotNull(part.getMedia()); + assertEquals("video/mp4", part.getMedia().getContentType()); + } + + @Test + void testSetToolRequest() { + Part part = new Part(); + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName("search"); + part.setToolRequest(toolRequest); + + assertNotNull(part.getToolRequest()); + assertEquals("search", part.getToolRequest().getName()); + } + + @Test + void testSetToolResponse() { + Part part = new Part(); + ToolResponse toolResponse = new ToolResponse(); + toolResponse.setName("search"); + part.setToolResponse(toolResponse); + + assertNotNull(part.getToolResponse()); + assertEquals("search", part.getToolResponse().getName()); + } + + @Test + void testSetData() { + Part part = new Part(); + Object data = Map.of("field", "value"); + part.setData(data); + + assertEquals(data, part.getData()); + } + + @Test + void testSetMetadata() { + Part part = new Part(); + Map metadata = Map.of("timestamp", "2025-01-01"); + part.setMetadata(metadata); + + assertEquals(metadata, part.getMetadata()); + } + + @Test + void testPartWithMultipleTypes() { + // A part should be able to have multiple types set, though typically only one + // is used + Part part = new Part(); + part.setText("text content"); + part.setMedia(new Media("image/png", "http://example.com/img.png")); + + assertEquals("text content", part.getText()); + assertNotNull(part.getMedia()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/ResumeOptionsTest.java b/java/ai/src/test/java/com/google/genkit/ai/ResumeOptionsTest.java new file mode 100644 index 0000000000..94bed270f3 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/ResumeOptionsTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for ResumeOptions. + */ +class ResumeOptionsTest { + + @Test + void testBuilderWithRespond() { + ToolResponse response1 = new ToolResponse("ref-1", "tool1", Map.of("result", "ok")); + ToolResponse response2 = new ToolResponse("ref-2", "tool2", Map.of("result", "done")); + + ResumeOptions options = ResumeOptions.builder().respond(List.of(response1, response2)).build(); + + assertNotNull(options.getRespond()); + assertEquals(2, options.getRespond().size()); + assertEquals("ref-1", options.getRespond().get(0).getRef()); + assertEquals("ref-2", options.getRespond().get(1).getRef()); + assertNull(options.getRestart()); + } + + @Test + void testBuilderWithRestart() { + ToolRequest request1 = new ToolRequest(); + request1.setName("tool1"); + request1.setRef("ref-1"); + + ToolRequest request2 = new ToolRequest(); + request2.setName("tool2"); + request2.setRef("ref-2"); + + ResumeOptions options = ResumeOptions.builder().restart(List.of(request1, request2)).build(); + + assertNull(options.getRespond()); + assertNotNull(options.getRestart()); + assertEquals(2, options.getRestart().size()); + assertEquals("tool1", options.getRestart().get(0).getName()); + assertEquals("tool2", options.getRestart().get(1).getName()); + } + + @Test + void testBuilderWithBothRespondAndRestart() { + ToolResponse response = new ToolResponse("ref-1", "tool1", "result"); + ToolRequest request = new ToolRequest(); + request.setName("tool2"); + request.setRef("ref-2"); + + ResumeOptions options = ResumeOptions.builder().respond(List.of(response)).restart(List.of(request)).build(); + + assertNotNull(options.getRespond()); + assertEquals(1, options.getRespond().size()); + assertNotNull(options.getRestart()); + assertEquals(1, options.getRestart().size()); + } + + @Test + void testDefaultConstructor() { + ResumeOptions options = new ResumeOptions(); + + assertNull(options.getRespond()); + assertNull(options.getRestart()); + } + + @Test + void testSetters() { + ResumeOptions options = new ResumeOptions(); + + ToolResponse response = new ToolResponse("ref", "tool", "output"); + ToolRequest request = new ToolRequest(); + request.setName("tool"); + + options.setRespond(List.of(response)); + options.setRestart(List.of(request)); + + assertEquals(1, options.getRespond().size()); + assertEquals(1, options.getRestart().size()); + } + + @Test + void testEmptyBuilder() { + ResumeOptions options = ResumeOptions.builder().build(); + + assertNull(options.getRespond()); + assertNull(options.getRestart()); + } + + @Test + void testRespondWithEmptyList() { + ResumeOptions options = ResumeOptions.builder().respond(new ArrayList<>()).build(); + + assertNotNull(options.getRespond()); + assertTrue(options.getRespond().isEmpty()); + } + + @Test + void testRestartWithEmptyList() { + ResumeOptions options = ResumeOptions.builder().restart(new ArrayList<>()).build(); + + assertNotNull(options.getRestart()); + assertTrue(options.getRestart().isEmpty()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/RoleTest.java b/java/ai/src/test/java/com/google/genkit/ai/RoleTest.java new file mode 100644 index 0000000000..f7aad14157 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/RoleTest.java @@ -0,0 +1,141 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Unit tests for Role. + */ +class RoleTest { + + @Test + void testUserRole() { + assertEquals("user", Role.USER.getValue()); + assertEquals("user", Role.USER.toString()); + } + + @Test + void testModelRole() { + assertEquals("model", Role.MODEL.getValue()); + assertEquals("model", Role.MODEL.toString()); + } + + @Test + void testSystemRole() { + assertEquals("system", Role.SYSTEM.getValue()); + assertEquals("system", Role.SYSTEM.toString()); + } + + @Test + void testToolRole() { + assertEquals("tool", Role.TOOL.getValue()); + assertEquals("tool", Role.TOOL.toString()); + } + + @Test + void testFromValueUser() { + assertEquals(Role.USER, Role.fromValue("user")); + } + + @Test + void testFromValueModel() { + assertEquals(Role.MODEL, Role.fromValue("model")); + } + + @Test + void testFromValueSystem() { + assertEquals(Role.SYSTEM, Role.fromValue("system")); + } + + @Test + void testFromValueTool() { + assertEquals(Role.TOOL, Role.fromValue("tool")); + } + + @Test + void testFromValueCaseInsensitive() { + assertEquals(Role.USER, Role.fromValue("USER")); + assertEquals(Role.USER, Role.fromValue("User")); + assertEquals(Role.MODEL, Role.fromValue("MODEL")); + assertEquals(Role.MODEL, Role.fromValue("Model")); + assertEquals(Role.SYSTEM, Role.fromValue("SYSTEM")); + assertEquals(Role.SYSTEM, Role.fromValue("System")); + assertEquals(Role.TOOL, Role.fromValue("TOOL")); + assertEquals(Role.TOOL, Role.fromValue("Tool")); + } + + @Test + void testFromValueAssistantMapsToModel() { + // For compatibility with other APIs that use "assistant" role + assertEquals(Role.MODEL, Role.fromValue("assistant")); + assertEquals(Role.MODEL, Role.fromValue("ASSISTANT")); + assertEquals(Role.MODEL, Role.fromValue("Assistant")); + } + + @Test + void testFromValueUnknown() { + assertThrows(IllegalArgumentException.class, () -> Role.fromValue("unknown")); + } + + @Test + void testFromValueNull() { + assertThrows(IllegalArgumentException.class, () -> Role.fromValue(null)); + } + + @Test + void testFromValueEmpty() { + assertThrows(IllegalArgumentException.class, () -> Role.fromValue("")); + } + + @ParameterizedTest + @ValueSource(strings = {"bot", "ai", "human", "admin", "moderator"}) + void testFromValueInvalidValues(String value) { + assertThrows(IllegalArgumentException.class, () -> Role.fromValue(value)); + } + + @Test + void testEnumValues() { + Role[] roles = Role.values(); + assertEquals(4, roles.length); + } + + @Test + void testEnumValueOf() { + assertEquals(Role.USER, Role.valueOf("USER")); + assertEquals(Role.MODEL, Role.valueOf("MODEL")); + assertEquals(Role.SYSTEM, Role.valueOf("SYSTEM")); + assertEquals(Role.TOOL, Role.valueOf("TOOL")); + } + + @Test + void testAllRolesHaveUniqueValues() { + Role[] roles = Role.values(); + for (int i = 0; i < roles.length; i++) { + for (int j = i + 1; j < roles.length; j++) { + assertNotEquals(roles[i].getValue(), roles[j].getValue(), + String.format("Roles %s and %s have same value", roles[i].name(), roles[j].name())); + } + } + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/ToolInterruptExceptionTest.java b/java/ai/src/test/java/com/google/genkit/ai/ToolInterruptExceptionTest.java new file mode 100644 index 0000000000..0b03951f4e --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/ToolInterruptExceptionTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for ToolInterruptException. + */ +class ToolInterruptExceptionTest { + + @Test + void testConstructorWithMetadata() { + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + metadata.put("key2", 42); + + ToolInterruptException exception = new ToolInterruptException(metadata); + + assertNotNull(exception.getMetadata()); + assertEquals(2, exception.getMetadata().size()); + assertEquals("value1", exception.getMetadata().get("key1")); + assertEquals(42, exception.getMetadata().get("key2")); + assertEquals("Tool execution interrupted", exception.getMessage()); + } + + @Test + void testConstructorWithNullMetadata() { + ToolInterruptException exception = new ToolInterruptException(null); + + assertNotNull(exception.getMetadata()); + assertTrue(exception.getMetadata().isEmpty()); + } + + @Test + void testConstructorWithEmptyMetadata() { + ToolInterruptException exception = new ToolInterruptException(new HashMap<>()); + + assertNotNull(exception.getMetadata()); + assertTrue(exception.getMetadata().isEmpty()); + } + + @Test + void testMetadataIsImmutableCopy() { + Map metadata = new HashMap<>(); + metadata.put("key", "original"); + + ToolInterruptException exception = new ToolInterruptException(metadata); + + // Modify original + metadata.put("key", "modified"); + metadata.put("newKey", "newValue"); + + // Exception's metadata should not change + assertEquals("original", exception.getMetadata().get("key")); + assertFalse(exception.getMetadata().containsKey("newKey")); + } + + @Test + void testCanCatchAsException() { + Map metadata = Map.of("action", "confirm"); + + try { + throw new ToolInterruptException(metadata); + } catch (ToolInterruptException e) { + assertEquals("confirm", e.getMetadata().get("action")); + } + } + + @Test + void testIsRuntimeException() { + ToolInterruptException exception = new ToolInterruptException(new HashMap<>()); + assertTrue(exception instanceof RuntimeException); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/ToolTest.java b/java/ai/src/test/java/com/google/genkit/ai/ToolTest.java new file mode 100644 index 0000000000..86123790c8 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/ToolTest.java @@ -0,0 +1,181 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for ToolRequest and ToolResponse. + */ +class ToolTest { + + @Test + void testToolRequestDefaultConstructor() { + ToolRequest request = new ToolRequest(); + + assertNull(request.getName()); + assertNull(request.getRef()); + assertNull(request.getInput()); + } + + @Test + void testToolRequestSetName() { + ToolRequest request = new ToolRequest(); + request.setName("calculator"); + + assertEquals("calculator", request.getName()); + } + + @Test + void testToolRequestSetRef() { + ToolRequest request = new ToolRequest(); + request.setRef("ref-abc-123"); + + assertEquals("ref-abc-123", request.getRef()); + } + + @Test + void testToolRequestSetInput() { + ToolRequest request = new ToolRequest(); + Map input = Map.of("a", 5, "b", 3, "operation", "add"); + request.setInput(input); + + assertEquals(input, request.getInput()); + @SuppressWarnings("unchecked") + Map inputMap = (Map) request.getInput(); + assertEquals(5, inputMap.get("a")); + assertEquals(3, inputMap.get("b")); + assertEquals("add", inputMap.get("operation")); + } + + @Test + void testToolRequestCompleteSetup() { + ToolRequest request = new ToolRequest(); + request.setName("search"); + request.setRef("search-001"); + request.setInput(Map.of("query", "Genkit documentation")); + + assertEquals("search", request.getName()); + assertEquals("search-001", request.getRef()); + @SuppressWarnings("unchecked") + Map inputMap = (Map) request.getInput(); + assertEquals("Genkit documentation", inputMap.get("query")); + } + + @Test + void testToolResponseDefaultConstructor() { + ToolResponse response = new ToolResponse(); + + assertNull(response.getName()); + assertNull(response.getRef()); + assertNull(response.getOutput()); + } + + @Test + void testToolResponseSetName() { + ToolResponse response = new ToolResponse(); + response.setName("calculator"); + + assertEquals("calculator", response.getName()); + } + + @Test + void testToolResponseSetRef() { + ToolResponse response = new ToolResponse(); + response.setRef("ref-abc-123"); + + assertEquals("ref-abc-123", response.getRef()); + } + + @Test + void testToolResponseSetOutput() { + ToolResponse response = new ToolResponse(); + Map output = Map.of("result", 8, "success", true); + response.setOutput(output); + + assertEquals(output, response.getOutput()); + @SuppressWarnings("unchecked") + Map outputMap = (Map) response.getOutput(); + assertEquals(8, outputMap.get("result")); + assertEquals(true, outputMap.get("success")); + } + + @Test + void testToolResponseCompleteSetup() { + ToolResponse response = new ToolResponse(); + response.setName("calculator"); + response.setRef("calc-001"); + response.setOutput(Map.of("result", 42)); + + assertEquals("calculator", response.getName()); + assertEquals("calc-001", response.getRef()); + @SuppressWarnings("unchecked") + Map outputMap = (Map) response.getOutput(); + assertEquals(42, outputMap.get("result")); + } + + @Test + void testToolRequestAndResponseMatching() { + String toolName = "weather"; + String ref = "weather-req-001"; + + ToolRequest request = new ToolRequest(); + request.setName(toolName); + request.setRef(ref); + request.setInput(Map.of("location", "San Francisco")); + + ToolResponse response = new ToolResponse(); + response.setName(toolName); + response.setRef(ref); + response.setOutput(Map.of("temperature", 72, "condition", "sunny")); + + // Request and response should have matching name and ref + assertEquals(request.getName(), response.getName()); + assertEquals(request.getRef(), response.getRef()); + } + + @Test + void testToolRequestWithComplexInput() { + ToolRequest request = new ToolRequest(); + request.setName("database_query"); + request.setInput(Map.of("table", "users", "fields", new String[]{"id", "name", "email"}, "limit", 10, "where", + Map.of("active", true))); + + assertEquals("database_query", request.getName()); + @SuppressWarnings("unchecked") + Map inputMap = (Map) request.getInput(); + assertEquals("users", inputMap.get("table")); + assertEquals(10, inputMap.get("limit")); + } + + @Test + void testToolResponseWithNullOutput() { + ToolResponse response = new ToolResponse(); + response.setName("void_operation"); + response.setRef("void-001"); + response.setOutput(null); + + assertEquals("void_operation", response.getName()); + assertNull(response.getOutput()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/session/ChatOptionsTest.java b/java/ai/src/test/java/com/google/genkit/ai/session/ChatOptionsTest.java new file mode 100644 index 0000000000..eb3f08224e --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/session/ChatOptionsTest.java @@ -0,0 +1,272 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.OutputConfig; + +/** Unit tests for ChatOptions. */ +class ChatOptionsTest { + + @Test + void testDefaultConstructor() { + ChatOptions options = new ChatOptions<>(); + + assertNull(options.getModel()); + assertNull(options.getSystem()); + assertNull(options.getConfig()); + assertNull(options.getTools()); + assertNull(options.getOutput()); + assertNull(options.getContext()); + assertNull(options.getMaxTurns()); + } + + @Test + void testSetAndGetModel() { + ChatOptions options = new ChatOptions<>(); + options.setModel("gpt-4"); + + assertEquals("gpt-4", options.getModel()); + } + + @Test + void testSetAndGetSystem() { + ChatOptions options = new ChatOptions<>(); + options.setSystem("You are a helpful assistant."); + + assertEquals("You are a helpful assistant.", options.getSystem()); + } + + @Test + void testSetAndGetConfig() { + ChatOptions options = new ChatOptions<>(); + GenerationConfig config = GenerationConfig.builder().temperature(0.7).build(); + options.setConfig(config); + + assertSame(config, options.getConfig()); + } + + @Test + void testSetAndGetOutput() { + ChatOptions options = new ChatOptions<>(); + OutputConfig output = new OutputConfig(); + options.setOutput(output); + + assertSame(output, options.getOutput()); + } + + @Test + void testSetAndGetContext() { + ChatOptions options = new ChatOptions<>(); + Map context = new HashMap<>(); + context.put("key1", "value1"); + context.put("key2", 42); + options.setContext(context); + + assertNotNull(options.getContext()); + assertEquals("value1", options.getContext().get("key1")); + assertEquals(42, options.getContext().get("key2")); + } + + @Test + void testSetAndGetMaxTurns() { + ChatOptions options = new ChatOptions<>(); + options.setMaxTurns(10); + + assertEquals(10, options.getMaxTurns()); + } + + @Test + void testBuilderEmpty() { + ChatOptions options = ChatOptions.builder().build(); + + assertNull(options.getModel()); + assertNull(options.getSystem()); + assertNull(options.getConfig()); + assertNull(options.getTools()); + assertNull(options.getOutput()); + assertNull(options.getContext()); + assertNull(options.getMaxTurns()); + } + + @Test + void testBuilderWithModel() { + ChatOptions options = ChatOptions.builder().model("claude-3").build(); + + assertEquals("claude-3", options.getModel()); + } + + @Test + void testBuilderWithSystem() { + ChatOptions options = ChatOptions.builder().system("You are a coding assistant.").build(); + + assertEquals("You are a coding assistant.", options.getSystem()); + } + + @Test + void testBuilderWithConfig() { + GenerationConfig config = GenerationConfig.builder().temperature(0.5).maxOutputTokens(100).build(); + + ChatOptions options = ChatOptions.builder().config(config).build(); + + assertSame(config, options.getConfig()); + } + + @Test + void testBuilderWithOutput() { + OutputConfig output = new OutputConfig(); + + ChatOptions options = ChatOptions.builder().output(output).build(); + + assertSame(output, options.getOutput()); + } + + @Test + void testBuilderWithContext() { + Map context = new HashMap<>(); + context.put("userId", "user123"); + + ChatOptions options = ChatOptions.builder().context(context).build(); + + assertNotNull(options.getContext()); + assertEquals("user123", options.getContext().get("userId")); + } + + @Test + void testBuilderWithMaxTurns() { + ChatOptions options = ChatOptions.builder().maxTurns(5).build(); + + assertEquals(5, options.getMaxTurns()); + } + + @Test + void testBuilderWithAllOptions() { + GenerationConfig config = GenerationConfig.builder().temperature(0.7).build(); + OutputConfig output = new OutputConfig(); + Map context = new HashMap<>(); + context.put("key", "value"); + + ChatOptions options = ChatOptions.builder().model("gemini-pro") + .system("You are an expert programmer.").config(config).output(output).context(context).maxTurns(20) + .build(); + + assertEquals("gemini-pro", options.getModel()); + assertEquals("You are an expert programmer.", options.getSystem()); + assertSame(config, options.getConfig()); + assertSame(output, options.getOutput()); + assertEquals("value", options.getContext().get("key")); + assertEquals(20, options.getMaxTurns()); + } + + @Test + void testBuilderChaining() { + ChatOptions.Builder builder = ChatOptions.builder(); + + // Test that builder methods return the builder for chaining + assertSame(builder, builder.model("model")); + assertSame(builder, builder.system("system")); + assertSame(builder, builder.config(GenerationConfig.builder().build())); + assertSame(builder, builder.output(new OutputConfig())); + assertSame(builder, builder.context(new HashMap<>())); + assertSame(builder, builder.maxTurns(10)); + } + + @Test + void testLongSystemPrompt() { + String longPrompt = "You are a helpful assistant. " + "You should always be polite and professional. " + + "Never provide harmful or misleading information. " + "If you don't know something, say so. " + + "Always cite your sources when possible."; + + ChatOptions options = ChatOptions.builder().system(longPrompt).build(); + + assertEquals(longPrompt, options.getSystem()); + } + + @Test + void testMultipleBuilds() { + ChatOptions.Builder builder = ChatOptions.builder().model("model1").maxTurns(5); + + ChatOptions options1 = builder.build(); + assertEquals("model1", options1.getModel()); + assertEquals(5, options1.getMaxTurns()); + + // Modify and build again + builder.model("model2").maxTurns(10); + ChatOptions options2 = builder.build(); + + assertEquals("model2", options2.getModel()); + assertEquals(10, options2.getMaxTurns()); + } + + @Test + void testWithComplexState() { + // Test with a custom state type + ChatOptions options = ChatOptions.builder().model("test-model") + .system("Test system prompt").maxTurns(15).build(); + + assertEquals("test-model", options.getModel()); + assertEquals("Test system prompt", options.getSystem()); + assertEquals(15, options.getMaxTurns()); + } + + @Test + void testEmptyContext() { + ChatOptions options = ChatOptions.builder().context(new HashMap<>()).build(); + + assertNotNull(options.getContext()); + assertTrue(options.getContext().isEmpty()); + } + + @Test + void testZeroMaxTurns() { + ChatOptions options = ChatOptions.builder().maxTurns(0).build(); + + assertEquals(0, options.getMaxTurns()); + } + + /** Simple test state class. */ + static class TestState { + private final String name; + private final int value; + + TestState(String name, int value) { + this.name = name; + this.value = value; + } + + String getName() { + return name; + } + + int getValue() { + return value; + } + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/session/InMemorySessionStoreTest.java b/java/ai/src/test/java/com/google/genkit/ai/session/InMemorySessionStoreTest.java new file mode 100644 index 0000000000..3056464382 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/session/InMemorySessionStoreTest.java @@ -0,0 +1,211 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.google.genkit.ai.Message; + +/** Unit tests for InMemorySessionStore. */ +class InMemorySessionStoreTest { + + private InMemorySessionStore store; + + @BeforeEach + void setUp() { + store = new InMemorySessionStore<>(); + } + + @Test + void testSaveAndGet() throws ExecutionException, InterruptedException { + SessionData data = new SessionData<>("session-1", "test-state"); + data.setThread("main", List.of(Message.user("Hello"))); + + store.save("session-1", data).get(); + + SessionData retrieved = store.get("session-1").get(); + + assertNotNull(retrieved); + assertEquals("session-1", retrieved.getId()); + assertEquals("test-state", retrieved.getState()); + assertEquals(1, retrieved.getThread("main").size()); + } + + @Test + void testGetNonexistentSession() throws ExecutionException, InterruptedException { + SessionData result = store.get("nonexistent").get(); + + assertNull(result); + } + + @Test + void testDelete() throws ExecutionException, InterruptedException { + SessionData data = new SessionData<>("to-delete", "state"); + store.save("to-delete", data).get(); + + assertTrue(store.exists("to-delete").get()); + + store.delete("to-delete").get(); + + assertFalse(store.exists("to-delete").get()); + assertNull(store.get("to-delete").get()); + } + + @Test + void testDeleteNonexistentSession() throws ExecutionException, InterruptedException { + // Should not throw + assertDoesNotThrow(() -> store.delete("nonexistent").get()); + } + + @Test + void testExists() throws ExecutionException, InterruptedException { + assertFalse(store.exists("new-session").get()); + + store.save("new-session", new SessionData<>("new-session")).get(); + + assertTrue(store.exists("new-session").get()); + } + + @Test + void testSize() throws ExecutionException, InterruptedException { + assertEquals(0, store.size()); + + store.save("session-1", new SessionData<>("session-1")).get(); + assertEquals(1, store.size()); + + store.save("session-2", new SessionData<>("session-2")).get(); + assertEquals(2, store.size()); + + store.delete("session-1").get(); + assertEquals(1, store.size()); + } + + @Test + void testClear() throws ExecutionException, InterruptedException { + store.save("session-1", new SessionData<>("session-1")).get(); + store.save("session-2", new SessionData<>("session-2")).get(); + store.save("session-3", new SessionData<>("session-3")).get(); + + assertEquals(3, store.size()); + + store.clear(); + + assertEquals(0, store.size()); + assertNull(store.get("session-1").get()); + } + + @Test + void testOverwriteSession() throws ExecutionException, InterruptedException { + SessionData original = new SessionData<>("session", "original-state"); + store.save("session", original).get(); + + SessionData updated = new SessionData<>("session", "updated-state"); + store.save("session", updated).get(); + + SessionData retrieved = store.get("session").get(); + assertEquals("updated-state", retrieved.getState()); + assertEquals(1, store.size()); + } + + @Test + void testMultipleSessions() throws ExecutionException, InterruptedException { + for (int i = 0; i < 10; i++) { + SessionData data = new SessionData<>("session-" + i, "state-" + i); + store.save("session-" + i, data).get(); + } + + assertEquals(10, store.size()); + + for (int i = 0; i < 10; i++) { + SessionData retrieved = store.get("session-" + i).get(); + assertNotNull(retrieved); + assertEquals("state-" + i, retrieved.getState()); + } + } + + @Test + void testWithComplexState() throws ExecutionException, InterruptedException { + // Create a store with complex state type + InMemorySessionStore> complexStore = new InMemorySessionStore<>(); + + List state = new ArrayList<>(); + state.add(1); + state.add(2); + state.add(3); + + SessionData> data = new SessionData<>("complex", state); + complexStore.save("complex", data).get(); + + SessionData> retrieved = complexStore.get("complex").get(); + assertNotNull(retrieved); + assertEquals(3, retrieved.getState().size()); + assertEquals(List.of(1, 2, 3), retrieved.getState()); + } + + @Test + void testAsyncOperations() throws ExecutionException, InterruptedException { + // Test that operations return proper CompletableFutures + CompletableFuture saveFuture = store.save("async-session", new SessionData<>("async-session")); + assertNotNull(saveFuture); + saveFuture.get(); // Should complete without exception + + CompletableFuture> getFuture = store.get("async-session"); + assertNotNull(getFuture); + SessionData result = getFuture.get(); + assertNotNull(result); + + CompletableFuture existsFuture = store.exists("async-session"); + assertNotNull(existsFuture); + assertTrue(existsFuture.get()); + + CompletableFuture deleteFuture = store.delete("async-session"); + assertNotNull(deleteFuture); + deleteFuture.get(); // Should complete without exception + } + + @Test + void testSessionDataWithThreads() throws ExecutionException, InterruptedException { + SessionData data = new SessionData<>("threaded-session", "state"); + + List mainThread = new ArrayList<>(); + mainThread.add(Message.user("Hello")); + mainThread.add(Message.model("Hi there!")); + data.setThread("main", mainThread); + + List sideThread = new ArrayList<>(); + sideThread.add(Message.user("Different conversation")); + data.setThread("side", sideThread); + + store.save("threaded-session", data).get(); + + SessionData retrieved = store.get("threaded-session").get(); + assertNotNull(retrieved); + assertEquals(2, retrieved.getThread("main").size()); + assertEquals(1, retrieved.getThread("side").size()); + assertEquals("Hello", retrieved.getThread("main").get(0).getText()); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/session/SessionContextTest.java b/java/ai/src/test/java/com/google/genkit/ai/session/SessionContextTest.java new file mode 100644 index 0000000000..0666e2d3a0 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/session/SessionContextTest.java @@ -0,0 +1,243 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.google.genkit.core.Registry; + +/** + * Unit tests for SessionContext. + */ +class SessionContextTest { + + private Registry mockRegistry; + private InMemorySessionStore store; + + @BeforeEach + void setUp() { + // Ensure clean state before each test + SessionContext.clearSession(); + mockRegistry = mock(Registry.class); + store = new InMemorySessionStore<>(); + } + + @AfterEach + void tearDown() { + // Clean up after each test + SessionContext.clearSession(); + } + + /** + * Helper to create a test session. + */ + private Session createTestSession(String id) { + SessionData sessionData = new SessionData<>(id, "test-state"); + return new Session(mockRegistry, store, sessionData, () -> null, // We don't need actual chat for + // context + // tests + null // No agent registry needed for these tests + ); + } + + @Test + void testCurrentSessionThrowsWhenNotSet() { + assertThrows(SessionContext.SessionException.class, () -> SessionContext.currentSession()); + } + + @Test + void testGetCurrentSessionReturnsNullWhenNotSet() { + assertNull(SessionContext.getCurrentSession()); + } + + @Test + void testHasSessionReturnsFalseWhenNotSet() { + assertFalse(SessionContext.hasSession()); + } + + @Test + void testSetAndGetSession() { + Session session = createTestSession("test-id"); + + SessionContext.setSession(session); + + assertTrue(SessionContext.hasSession()); + assertSame(session, SessionContext.currentSession()); + assertSame(session, SessionContext.getCurrentSession()); + + SessionContext.clearSession(); + } + + @Test + void testRunWithSession() throws Exception { + Session session = createTestSession("test-session-id"); + + AtomicReference> capturedSession = new AtomicReference<>(); + + String result = SessionContext.runWithSession(session, () -> { + capturedSession.set(SessionContext.currentSession()); + return "test-result"; + }); + + assertEquals("test-result", result); + assertSame(session, capturedSession.get()); + // Session should be cleared after runWithSession + assertFalse(SessionContext.hasSession()); + } + + @Test + void testRunWithSessionRestoresPreviousSession() throws Exception { + Session outerSession = createTestSession("outer"); + Session innerSession = createTestSession("inner"); + + SessionContext.setSession(outerSession); + + String result = SessionContext.runWithSession(innerSession, () -> { + assertSame(innerSession, SessionContext.currentSession()); + return "done"; + }); + + // Outer session should be restored + assertSame(outerSession, SessionContext.currentSession()); + + SessionContext.clearSession(); + } + + @Test + void testRunWithSessionHandlesException() { + Session session = createTestSession("error-session"); + + assertThrows(RuntimeException.class, () -> { + SessionContext.runWithSession(session, () -> { + throw new RuntimeException("Test exception"); + }); + }); + + // Session should be cleared even after exception + assertFalse(SessionContext.hasSession()); + } + + @Test + void testThreadIsolation() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(2); + AtomicReference thread1SessionId = new AtomicReference<>(); + AtomicReference thread2SessionId = new AtomicReference<>(); + + Session session1 = createTestSession("session-1"); + Session session2 = createTestSession("session-2"); + + ExecutorService executor = Executors.newFixedThreadPool(2); + + executor.submit(() -> { + SessionContext.setSession(session1); + try { + Thread.sleep(50); // Allow time for overlap + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + thread1SessionId.set(SessionContext.currentSession().getId()); + latch.countDown(); + }); + + executor.submit(() -> { + SessionContext.setSession(session2); + try { + Thread.sleep(50); // Allow time for overlap + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + thread2SessionId.set(SessionContext.currentSession().getId()); + latch.countDown(); + }); + + assertTrue(latch.await(1, TimeUnit.SECONDS)); + executor.shutdown(); + + // Each thread should have its own session + assertEquals("session-1", thread1SessionId.get()); + assertEquals("session-2", thread2SessionId.get()); + } + + @Test + void testClearSession() { + Session session = createTestSession("clear-test"); + + SessionContext.setSession(session); + assertTrue(SessionContext.hasSession()); + + SessionContext.clearSession(); + assertFalse(SessionContext.hasSession()); + } + + @Test + void testNestedRunWithSession() throws Exception { + Session session1 = createTestSession("session-1"); + Session session2 = createTestSession("session-2"); + Session session3 = createTestSession("session-3"); + + AtomicReference level1 = new AtomicReference<>(); + AtomicReference level2 = new AtomicReference<>(); + AtomicReference level3 = new AtomicReference<>(); + AtomicReference afterLevel2 = new AtomicReference<>(); + + SessionContext.runWithSession(session1, () -> { + level1.set(SessionContext.currentSession().getId()); + + SessionContext.runWithSession(session2, () -> { + level2.set(SessionContext.currentSession().getId()); + + SessionContext.runWithSession(session3, () -> { + level3.set(SessionContext.currentSession().getId()); + return null; + }); + + afterLevel2.set(SessionContext.currentSession().getId()); + return null; + }); + + return null; + }); + + assertEquals("session-1", level1.get()); + assertEquals("session-2", level2.get()); + assertEquals("session-3", level3.get()); + assertEquals("session-2", afterLevel2.get()); + } + + @Test + void testRunWithSessionWithNullSession() throws Exception { + String result = SessionContext.runWithSession(null, () -> { + assertFalse(SessionContext.hasSession()); + return "result"; + }); + + assertEquals("result", result); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/session/SessionDataTest.java b/java/ai/src/test/java/com/google/genkit/ai/session/SessionDataTest.java new file mode 100644 index 0000000000..e911d4542d --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/session/SessionDataTest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.google.genkit.ai.Message; + +/** Unit tests for SessionData. */ +class SessionDataTest { + + @Test + void testDefaultConstructor() { + SessionData data = new SessionData<>(); + + assertNull(data.getId()); + assertNull(data.getState()); + assertNotNull(data.getThreads()); + assertTrue(data.getThreads().isEmpty()); + } + + @Test + void testConstructorWithId() { + SessionData data = new SessionData<>("test-session-123"); + + assertEquals("test-session-123", data.getId()); + assertNull(data.getState()); + assertNotNull(data.getThreads()); + assertTrue(data.getThreads().isEmpty()); + } + + @Test + void testConstructorWithIdAndState() { + SessionData data = new SessionData<>("test-session", "initial-state"); + + assertEquals("test-session", data.getId()); + assertEquals("initial-state", data.getState()); + assertNotNull(data.getThreads()); + assertTrue(data.getThreads().isEmpty()); + } + + @Test + void testSetAndGetId() { + SessionData data = new SessionData<>(); + data.setId("my-session-id"); + + assertEquals("my-session-id", data.getId()); + } + + @Test + void testSetAndGetState() { + SessionData data = new SessionData<>(); + data.setState(42); + + assertEquals(42, data.getState()); + } + + @Test + void testSetAndGetThreads() { + SessionData data = new SessionData<>(); + + Map> threads = new HashMap<>(); + List messages = new ArrayList<>(); + messages.add(Message.user("Hello")); + threads.put("main", messages); + + data.setThreads(threads); + + assertEquals(1, data.getThreads().size()); + assertTrue(data.getThreads().containsKey("main")); + assertEquals(1, data.getThreads().get("main").size()); + } + + @Test + void testGetThread() { + SessionData data = new SessionData<>(); + + List messages = new ArrayList<>(); + messages.add(Message.user("Test message")); + data.setThread("test-thread", messages); + + List retrieved = data.getThread("test-thread"); + + assertNotNull(retrieved); + assertEquals(1, retrieved.size()); + assertEquals("Test message", retrieved.get(0).getText()); + } + + @Test + void testGetThreadReturnsNullForNonexistent() { + SessionData data = new SessionData<>(); + + assertNull(data.getThread("nonexistent")); + } + + @Test + void testGetOrCreateThread() { + SessionData data = new SessionData<>(); + + // First call should create the thread + List thread1 = data.getOrCreateThread("new-thread"); + assertNotNull(thread1); + assertTrue(thread1.isEmpty()); + + // Add a message + thread1.add(Message.user("Hello")); + + // Second call should return the same thread + List thread2 = data.getOrCreateThread("new-thread"); + assertEquals(1, thread2.size()); + assertEquals("Hello", thread2.get(0).getText()); + } + + @Test + void testSetThread() { + SessionData data = new SessionData<>(); + + List messages = new ArrayList<>(); + messages.add(Message.user("First")); + messages.add(Message.model("Second")); + + data.setThread("conversation", messages); + + List retrieved = data.getThread("conversation"); + assertEquals(2, retrieved.size()); + assertEquals("First", retrieved.get(0).getText()); + assertEquals("Second", retrieved.get(1).getText()); + } + + @Test + void testSetThreadCreatesDefensiveCopy() { + SessionData data = new SessionData<>(); + + List messages = new ArrayList<>(); + messages.add(Message.user("Original")); + data.setThread("thread", messages); + + // Modify original list + messages.add(Message.model("Added after")); + + // The stored thread should not be affected + List retrieved = data.getThread("thread"); + assertEquals(1, retrieved.size()); + } + + @Test + void testBuilder() { + SessionData data = SessionData.builder().id("builder-session").state("builder-state").build(); + + assertEquals("builder-session", data.getId()); + assertEquals("builder-state", data.getState()); + assertNotNull(data.getThreads()); + } + + @Test + void testBuilderWithThreads() { + Map> threads = new HashMap<>(); + List mainThread = new ArrayList<>(); + mainThread.add(Message.user("Hello")); + threads.put("main", mainThread); + + SessionData data = SessionData.builder().id("session").threads(threads).build(); + + assertEquals(1, data.getThreads().size()); + assertTrue(data.getThreads().containsKey("main")); + } + + @Test + void testBuilderAddThread() { + List messages = new ArrayList<>(); + messages.add(Message.system("System prompt")); + + SessionData data = SessionData.builder().id("session").thread("custom", messages).build(); + + assertNotNull(data.getThread("custom")); + assertEquals(1, data.getThread("custom").size()); + } + + @Test + void testWithComplexState() { + // Test with a complex state object + Map complexState = new HashMap<>(); + complexState.put("userName", "Alice"); + complexState.put("preferences", Map.of("theme", "dark", "language", "en")); + complexState.put("messageCount", 5); + + SessionData> data = new SessionData<>("complex-session", complexState); + + assertEquals("complex-session", data.getId()); + assertEquals("Alice", data.getState().get("userName")); + assertEquals(5, data.getState().get("messageCount")); + } +} diff --git a/java/ai/src/test/java/com/google/genkit/ai/session/SessionOptionsTest.java b/java/ai/src/test/java/com/google/genkit/ai/session/SessionOptionsTest.java new file mode 100644 index 0000000000..2f458fd084 --- /dev/null +++ b/java/ai/src/test/java/com/google/genkit/ai/session/SessionOptionsTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.session; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** Unit tests for SessionOptions. */ +class SessionOptionsTest { + + @Test + void testDefaultConstructor() { + SessionOptions options = new SessionOptions<>(); + + assertNull(options.getStore()); + assertNull(options.getInitialState()); + assertNull(options.getSessionId()); + } + + @Test + void testSetAndGetStore() { + SessionOptions options = new SessionOptions<>(); + InMemorySessionStore store = new InMemorySessionStore<>(); + + options.setStore(store); + + assertSame(store, options.getStore()); + } + + @Test + void testSetAndGetInitialState() { + SessionOptions options = new SessionOptions<>(); + options.setInitialState(42); + + assertEquals(42, options.getInitialState()); + } + + @Test + void testSetAndGetSessionId() { + SessionOptions options = new SessionOptions<>(); + options.setSessionId("custom-session-id"); + + assertEquals("custom-session-id", options.getSessionId()); + } + + @Test + void testBuilderEmpty() { + SessionOptions options = SessionOptions.builder().build(); + + assertNull(options.getStore()); + assertNull(options.getInitialState()); + assertNull(options.getSessionId()); + } + + @Test + void testBuilderWithStore() { + InMemorySessionStore store = new InMemorySessionStore<>(); + + SessionOptions options = SessionOptions.builder().store(store).build(); + + assertSame(store, options.getStore()); + } + + @Test + void testBuilderWithInitialState() { + SessionOptions options = SessionOptions.builder().initialState("initial-value").build(); + + assertEquals("initial-value", options.getInitialState()); + } + + @Test + void testBuilderWithSessionId() { + SessionOptions options = SessionOptions.builder().sessionId("my-session-123").build(); + + assertEquals("my-session-123", options.getSessionId()); + } + + @Test + void testBuilderWithAllOptions() { + InMemorySessionStore store = new InMemorySessionStore<>(); + + SessionOptions options = SessionOptions.builder().store(store).initialState("test-state") + .sessionId("session-456").build(); + + assertSame(store, options.getStore()); + assertEquals("test-state", options.getInitialState()); + assertEquals("session-456", options.getSessionId()); + } + + @Test + void testBuilderChaining() { + SessionOptions.Builder builder = SessionOptions.builder(); + + // Test that builder methods return the builder for chaining + assertSame(builder, builder.store(new InMemorySessionStore<>())); + assertSame(builder, builder.initialState("state")); + assertSame(builder, builder.sessionId("id")); + } + + @Test + void testWithComplexState() { + // Test with a custom state class + TestState state = new TestState("Alice", 25); + + SessionOptions options = SessionOptions.builder().initialState(state) + .sessionId("complex-state-session").build(); + + assertEquals("Alice", options.getInitialState().getName()); + assertEquals(25, options.getInitialState().getAge()); + } + + /** Simple test state class. */ + static class TestState { + private final String name; + private final int age; + + TestState(String name, int age) { + this.name = name; + this.age = age; + } + + String getName() { + return name; + } + + int getAge() { + return age; + } + } +} diff --git a/java/core/pom.xml b/java/core/pom.xml new file mode 100644 index 0000000000..9972259a7a --- /dev/null +++ b/java/core/pom.xml @@ -0,0 +1,99 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + + + genkit-core + jar + + Genkit Core + Core functionality for Genkit including actions, flows, tracing, and registry + + + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + org.slf4j + slf4j-api + + + + + io.opentelemetry + opentelemetry-api + + + io.opentelemetry + opentelemetry-sdk + + + io.opentelemetry + opentelemetry-sdk-trace + + + + + com.github.victools + jsonschema-generator + + + + + org.junit.jupiter + junit-jupiter + test + + + org.mockito + mockito-core + test + + + org.mockito + mockito-junit-jupiter + test + + + ch.qos.logback + logback-classic + test + + + diff --git a/java/core/src/main/java/com/google/genkit/core/Action.java b/java/core/src/main/java/com/google/genkit/core/Action.java new file mode 100644 index 0000000000..5a75cfb45d --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/Action.java @@ -0,0 +1,151 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.databind.JsonNode; + +/** + * Action is the interface that all Genkit primitives (e.g., flows, models, + * tools) have in common. An Action represents a named, observable operation + * that can be executed and traced. + * + *

+ * Actions are the fundamental building blocks of Genkit applications. They + * provide: + *

    + *
  • Named operations that can be discovered and invoked
  • + *
  • Input/output schema validation
  • + *
  • Automatic tracing and observability
  • + *
  • Registry integration for reflection API support
  • + *
+ * + * @param + * The input type for the action + * @param + * The output type for the action + * @param + * The streaming chunk type (use Void for non-streaming actions) + */ +public interface Action extends Registerable { + + /** + * Returns the name of the action. + * + * @return the action name + */ + String getName(); + + /** + * Returns the type of the action. + * + * @return the action type + */ + ActionType getType(); + + /** + * Returns the descriptor of the action containing metadata, schemas, etc. + * + * @return the action descriptor + */ + ActionDesc getDesc(); + + /** + * Runs the action with the given input. + * + * @param ctx + * the action context + * @param input + * the input to the action + * @return the output of the action + * @throws GenkitException + * if the action fails + */ + O run(ActionContext ctx, I input) throws GenkitException; + + /** + * Runs the action with the given input and streaming callback. + * + * @param ctx + * the action context + * @param input + * the input to the action + * @param streamCallback + * callback for receiving streaming chunks, may be null + * @return the output of the action + * @throws GenkitException + * if the action fails + */ + O run(ActionContext ctx, I input, Consumer streamCallback) throws GenkitException; + + /** + * Runs the action with JSON input and returns JSON output. + * + * @param ctx + * the action context + * @param input + * the JSON input + * @param streamCallback + * callback for receiving streaming JSON chunks, may be null + * @return the JSON output + * @throws GenkitException + * if the action fails + */ + JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) throws GenkitException; + + /** + * Runs the action with JSON input and returns the result with telemetry + * information. + * + * @param ctx + * the action context + * @param input + * the JSON input + * @param streamCallback + * callback for receiving streaming JSON chunks, may be null + * @return the action result including telemetry data + * @throws GenkitException + * if the action fails + */ + ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException; + + /** + * Returns the JSON schema for the action's input type. + * + * @return the input schema as a map, or null if not defined + */ + Map getInputSchema(); + + /** + * Returns the JSON schema for the action's output type. + * + * @return the output schema as a map, or null if not defined + */ + Map getOutputSchema(); + + /** + * Returns additional metadata for the action. + * + * @return the metadata map + */ + Map getMetadata(); +} diff --git a/java/core/src/main/java/com/google/genkit/core/ActionContext.java b/java/core/src/main/java/com/google/genkit/core/ActionContext.java new file mode 100644 index 0000000000..8edb858f06 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/ActionContext.java @@ -0,0 +1,274 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import com.google.genkit.core.tracing.SpanContext; + +/** + * ActionContext provides context for action execution including tracing and + * flow information. It is passed to all action executions and carries + * request-scoped state. + */ +public class ActionContext { + + private final SpanContext spanContext; + private final String flowName; + private final String spanPath; + private final Registry registry; + private final String sessionId; + private final String threadName; + + /** + * Creates a new ActionContext. + * + * @param spanContext + * the tracing span context, may be null + * @param flowName + * the name of the enclosing flow, may be null + * @param spanPath + * the current span path for tracing + * @param registry + * the Genkit registry + * @param sessionId + * the session ID for multi-turn conversations + * @param threadName + * the thread name for grouping related requests + */ + public ActionContext(SpanContext spanContext, String flowName, String spanPath, Registry registry, String sessionId, + String threadName) { + this.spanContext = spanContext; + this.flowName = flowName; + this.spanPath = spanPath; + this.registry = registry; + this.sessionId = sessionId; + this.threadName = threadName; + } + + /** + * Creates a new ActionContext. + * + * @param spanContext + * the tracing span context, may be null + * @param flowName + * the name of the enclosing flow, may be null + * @param spanPath + * the current span path for tracing + * @param registry + * the Genkit registry + */ + public ActionContext(SpanContext spanContext, String flowName, String spanPath, Registry registry) { + this(spanContext, flowName, spanPath, registry, null, null); + } + + /** + * Creates a new ActionContext. + * + * @param spanContext + * the tracing span context, may be null + * @param flowName + * the name of the enclosing flow, may be null + * @param registry + * the Genkit registry + */ + public ActionContext(SpanContext spanContext, String flowName, Registry registry) { + this(spanContext, flowName, null, registry); + } + + /** + * Creates a new ActionContext with default values. + * + * @param registry + * the Genkit registry + */ + public ActionContext(Registry registry) { + this(null, null, null, registry); + } + + /** + * Returns the tracing span context. + * + * @return the span context, or null if tracing is not active + */ + public SpanContext getSpanContext() { + return spanContext; + } + + /** + * Returns the name of the enclosing flow, if any. + * + * @return the flow name, or null if not in a flow context + */ + public String getFlowName() { + return flowName; + } + + /** + * Returns the current span path for tracing. + * + * @return the span path, or null if not in a traced context + */ + public String getSpanPath() { + return spanPath; + } + + /** + * Returns the Genkit registry. + * + * @return the registry + */ + public Registry getRegistry() { + return registry; + } + + /** + * Returns the session ID for multi-turn conversations. + * + * @return the session ID, or null if not set + */ + public String getSessionId() { + return sessionId; + } + + /** + * Returns the thread name for grouping related requests. + * + * @return the thread name, or null if not set + */ + public String getThreadName() { + return threadName; + } + + /** + * Creates a new ActionContext with a different flow name. + * + * @param flowName + * the new flow name + * @return a new ActionContext with the updated flow name + */ + public ActionContext withFlowName(String flowName) { + return new ActionContext(this.spanContext, flowName, this.spanPath, this.registry, this.sessionId, + this.threadName); + } + + /** + * Creates a new ActionContext with a different span context. + * + * @param spanContext + * the new span context + * @return a new ActionContext with the updated span context + */ + public ActionContext withSpanContext(SpanContext spanContext) { + return new ActionContext(spanContext, this.flowName, this.spanPath, this.registry, this.sessionId, + this.threadName); + } + + /** + * Creates a new ActionContext with a different span path. + * + * @param spanPath + * the new span path + * @return a new ActionContext with the updated span path + */ + public ActionContext withSpanPath(String spanPath) { + return new ActionContext(this.spanContext, this.flowName, spanPath, this.registry, this.sessionId, + this.threadName); + } + + /** + * Creates a new ActionContext with a session ID. + * + * @param sessionId + * the session ID + * @return a new ActionContext with the session ID + */ + public ActionContext withSessionId(String sessionId) { + return new ActionContext(this.spanContext, this.flowName, this.spanPath, this.registry, sessionId, + this.threadName); + } + + /** + * Creates a new ActionContext with a thread name. + * + * @param threadName + * the thread name + * @return a new ActionContext with the thread name + */ + public ActionContext withThreadName(String threadName) { + return new ActionContext(this.spanContext, this.flowName, this.spanPath, this.registry, this.sessionId, + threadName); + } + + /** + * Creates a builder for ActionContext. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for ActionContext. + */ + public static class Builder { + private SpanContext spanContext; + private String flowName; + private String spanPath; + private Registry registry; + private String sessionId; + private String threadName; + + public Builder spanContext(SpanContext spanContext) { + this.spanContext = spanContext; + return this; + } + + public Builder flowName(String flowName) { + this.flowName = flowName; + return this; + } + + public Builder spanPath(String spanPath) { + this.spanPath = spanPath; + return this; + } + + public Builder registry(Registry registry) { + this.registry = registry; + return this; + } + + public Builder sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } + + public Builder threadName(String threadName) { + this.threadName = threadName; + return this; + } + + public ActionContext build() { + if (registry == null) { + throw new IllegalStateException("registry is required"); + } + return new ActionContext(spanContext, flowName, spanPath, registry, sessionId, threadName); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/ActionDef.java b/java/core/src/main/java/com/google/genkit/core/ActionDef.java new file mode 100644 index 0000000000..a275280608 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/ActionDef.java @@ -0,0 +1,367 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.core.tracing.SpanContext; +import com.google.genkit.core.tracing.SpanMetadata; +import com.google.genkit.core.tracing.Tracer; + +/** + * ActionDef is the default implementation of an Action. It provides a named, + * observable operation that can be executed and traced. + * + * @param + * The input type for the action + * @param + * The output type for the action + * @param + * The streaming chunk type (use Void for non-streaming actions) + */ +public class ActionDef implements Action { + + private static final Logger logger = LoggerFactory.getLogger(ActionDef.class); + private static final ObjectMapper objectMapper = JsonUtils.getObjectMapper(); + + private final ActionDesc desc; + private final StreamingFunction fn; + private final Class inputClass; + private final Class outputClass; + private Registry registry; + + /** + * Function interface for streaming actions. + * + * @param + * input type + * @param + * output type + * @param + * stream chunk type + */ + @FunctionalInterface + public interface StreamingFunction { + O apply(ActionContext ctx, I input, Consumer streamCallback) throws GenkitException; + } + + /** + * Function interface for non-streaming actions. + * + * @param + * input type + * @param + * output type + */ + @FunctionalInterface + public interface ActionFunction { + O apply(ActionContext ctx, I input) throws GenkitException; + } + + /** + * Creates a new ActionDef. + * + * @param name + * the action name + * @param type + * the action type + * @param metadata + * additional metadata + * @param inputSchema + * the input JSON schema + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param fn + * the action function + */ + public ActionDef(String name, ActionType type, Map metadata, Map inputSchema, + Class inputClass, Class outputClass, StreamingFunction fn) { + if (name == null || name.isEmpty()) { + throw new IllegalArgumentException("Action name is required"); + } + if (type == null) { + throw new IllegalArgumentException("Action type is required"); + } + if (fn == null) { + throw new IllegalArgumentException("Action function is required"); + } + + String description = null; + if (metadata != null && metadata.get("description") instanceof String) { + description = (String) metadata.get("description"); + } + + // Generate schemas if not provided + Map actualInputSchema = inputSchema; + if (actualInputSchema == null && inputClass != null && inputClass != Void.class) { + actualInputSchema = SchemaUtils.inferSchema(inputClass); + } + + Map outputSchema = null; + if (outputClass != null && outputClass != Void.class) { + outputSchema = SchemaUtils.inferSchema(outputClass); + } + + this.desc = ActionDesc.builder().type(type).name(name).description(description).inputSchema(actualInputSchema) + .outputSchema(outputSchema).metadata(metadata != null ? metadata : new HashMap<>()).build(); + + this.fn = fn; + this.inputClass = inputClass; + this.outputClass = outputClass; + } + + /** + * Creates a non-streaming action. + * + * @param name + * the action name + * @param type + * the action type + * @param metadata + * additional metadata + * @param inputSchema + * the input JSON schema + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param fn + * the action function + * @param + * input type + * @param + * output type + * @return a new ActionDef + */ + public static ActionDef create(String name, ActionType type, Map metadata, + Map inputSchema, Class inputClass, Class outputClass, ActionFunction fn) { + return new ActionDef<>(name, type, metadata, inputSchema, inputClass, outputClass, + (ctx, input, cb) -> fn.apply(ctx, input)); + } + + /** + * Creates a streaming action. + * + * @param name + * the action name + * @param type + * the action type + * @param metadata + * additional metadata + * @param inputSchema + * the input JSON schema + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param fn + * the streaming function + * @param + * input type + * @param + * output type + * @param + * stream chunk type + * @return a new ActionDef + */ + public static ActionDef createStreaming(String name, ActionType type, + Map metadata, Map inputSchema, Class inputClass, Class outputClass, + StreamingFunction fn) { + return new ActionDef<>(name, type, metadata, inputSchema, inputClass, outputClass, fn); + } + + @Override + public String getName() { + return desc.getName(); + } + + @Override + public ActionType getType() { + return desc.getType(); + } + + @Override + public ActionDesc getDesc() { + return desc; + } + + @Override + public O run(ActionContext ctx, I input) throws GenkitException { + return run(ctx, input, null); + } + + @Override + public O run(ActionContext ctx, I input, Consumer streamCallback) throws GenkitException { + logger.debug("Action.run: name={}, input={}", getName(), input); + + // Determine the subtype based on action type for proper telemetry + // categorization + String subtype = getSubtypeForTelemetry(desc.getType()); + + SpanMetadata spanMetadata = SpanMetadata.builder().name(desc.getName()).type(desc.getType().getValue()) + .subtype(subtype).build(); + + String flowName = ctx.getFlowName(); + if (flowName != null) { + spanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + return Tracer.runInNewSpan(ctx, spanMetadata, input, (spanCtx, in) -> { + try { + O result = fn.apply(ctx.withSpanContext(spanCtx), in, streamCallback); + logger.debug("Action.run complete: name={}, result={}", getName(), result); + return result; + } catch (Exception e) { + logger.error("Action.run failed: name={}, error={}", getName(), e.getMessage(), e); + if (e instanceof GenkitException) { + throw (GenkitException) e; + } + throw new GenkitException("Action execution failed: " + e.getMessage(), e); + } + }); + } + + @Override + @SuppressWarnings("unchecked") + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + try { + I typedInput = null; + if (inputClass != null && inputClass != Void.class && input != null) { + typedInput = objectMapper.treeToValue(input, inputClass); + } + + Consumer typedCallback = null; + if (streamCallback != null) { + typedCallback = chunk -> { + try { + JsonNode jsonChunk = objectMapper.valueToTree(chunk); + streamCallback.accept(jsonChunk); + } catch (Exception e) { + throw new RuntimeException("Failed to serialize stream chunk", e); + } + }; + } + + O result = run(ctx, typedInput, typedCallback); + + if (result == null) { + return null; + } + return objectMapper.valueToTree(result); + } catch (Exception e) { + if (e instanceof GenkitException) { + throw (GenkitException) e; + } + throw new GenkitException("JSON action execution failed: " + e.getMessage(), e); + } + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + SpanContext spanContext = ctx.getSpanContext(); + String traceId = spanContext != null ? spanContext.getTraceId() : null; + String spanId = spanContext != null ? spanContext.getSpanId() : null; + + JsonNode result = runJson(ctx, input, streamCallback); + + // Get updated span info after execution + SpanContext currentSpan = ctx.getSpanContext(); + if (currentSpan != null) { + traceId = currentSpan.getTraceId(); + spanId = currentSpan.getSpanId(); + } + + return new ActionRunResult<>(result, traceId, spanId); + } + + @Override + public Map getInputSchema() { + return desc.getInputSchema(); + } + + @Override + public Map getOutputSchema() { + return desc.getOutputSchema(); + } + + @Override + public Map getMetadata() { + return desc.getMetadata(); + } + + @Override + public void register(Registry registry) { + this.registry = registry; + registry.registerAction(desc.getKey(), this); + } + + /** + * Returns the registry this action is registered with. + * + * @return the registry, or null if not registered + */ + public Registry getRegistry() { + return registry; + } + + /** + * Returns the subtype for telemetry based on the action type. This matches the + * JS/Go SDK format for proper trace categorization. + * + * @param type + * the action type + * @return the subtype string for telemetry + */ + private static String getSubtypeForTelemetry(ActionType type) { + if (type == null) { + return null; + } + switch (type) { + case MODEL : + return "model"; + case TOOL : + return "tool"; + case FLOW : + return "flow"; + case EMBEDDER : + return "embedder"; + case RETRIEVER : + return "retriever"; + case INDEXER : + return "indexer"; + case EVALUATOR : + return "evaluator"; + case PROMPT : + return "prompt"; + default : + return type.getValue(); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/ActionDesc.java b/java/core/src/main/java/com/google/genkit/core/ActionDesc.java new file mode 100644 index 0000000000..74e4498bbb --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/ActionDesc.java @@ -0,0 +1,205 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * ActionDesc is a descriptor of an action containing its metadata and schemas. + * This is used for reflection and discovery of actions. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ActionDesc { + + @JsonProperty("type") + private ActionType type; + + @JsonProperty("key") + private String key; + + @JsonProperty("name") + private String name; + + @JsonProperty("description") + private String description; + + @JsonProperty("inputSchema") + private Map inputSchema; + + @JsonProperty("outputSchema") + private Map outputSchema; + + @JsonProperty("metadata") + private Map metadata; + + /** + * Default constructor for Jackson deserialization. + */ + public ActionDesc() { + } + + /** + * Creates a new ActionDesc with the specified parameters. + * + * @param type + * the action type + * @param name + * the action name + * @param description + * optional description + * @param inputSchema + * optional input JSON schema + * @param outputSchema + * optional output JSON schema + * @param metadata + * optional metadata + */ + public ActionDesc(ActionType type, String name, String description, Map inputSchema, + Map outputSchema, Map metadata) { + this.type = type; + this.key = type.keyFromName(name); + this.name = name; + this.description = description; + this.inputSchema = inputSchema; + this.outputSchema = outputSchema; + this.metadata = metadata; + } + + /** + * Creates a builder for ActionDesc. + * + * @return a new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + // Getters and setters + + public ActionType getType() { + return type; + } + + public void setType(ActionType type) { + this.type = type; + } + + public String getKey() { + return key; + } + + public void setKey(String key) { + this.key = key; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public Map getInputSchema() { + return inputSchema; + } + + public void setInputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + } + + public Map getOutputSchema() { + return outputSchema; + } + + public void setOutputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + + /** + * Builder for ActionDesc. + */ + public static class Builder { + private ActionType type; + private String name; + private String description; + private Map inputSchema; + private Map outputSchema; + private Map metadata; + + public Builder type(ActionType type) { + this.type = type; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder outputSchema(Map outputSchema) { + this.outputSchema = outputSchema; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public ActionDesc build() { + if (type == null) { + throw new IllegalStateException("type is required"); + } + if (name == null || name.isEmpty()) { + throw new IllegalStateException("name is required"); + } + return new ActionDesc(type, name, description, inputSchema, outputSchema, metadata); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/ActionRunResult.java b/java/core/src/main/java/com/google/genkit/core/ActionRunResult.java new file mode 100644 index 0000000000..017eb2816b --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/ActionRunResult.java @@ -0,0 +1,118 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +/** + * ActionRunResult contains the result of an action execution along with + * telemetry information. + * + * @param + * the type of the result + */ +public class ActionRunResult { + + private final T result; + private final String traceId; + private final String spanId; + + /** + * Creates a new ActionRunResult. + * + * @param result + * the action result + * @param traceId + * the trace ID for this execution + * @param spanId + * the span ID for this execution + */ + public ActionRunResult(T result, String traceId, String spanId) { + this.result = result; + this.traceId = traceId; + this.spanId = spanId; + } + + /** + * Returns the action result. + * + * @return the result + */ + public T getResult() { + return result; + } + + /** + * Returns the trace ID for this execution. + * + * @return the trace ID + */ + public String getTraceId() { + return traceId; + } + + /** + * Returns the span ID for this execution. + * + * @return the span ID + */ + public String getSpanId() { + return spanId; + } + + /** + * Creates a builder for ActionRunResult. + * + * @param + * the result type + * @return a new builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder for ActionRunResult. + * + * @param + * the result type + */ + public static class Builder { + private T result; + private String traceId; + private String spanId; + + public Builder result(T result) { + this.result = result; + return this; + } + + public Builder traceId(String traceId) { + this.traceId = traceId; + return this; + } + + public Builder spanId(String spanId) { + this.spanId = spanId; + return this; + } + + public ActionRunResult build() { + return new ActionRunResult<>(result, traceId, spanId); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/ActionType.java b/java/core/src/main/java/com/google/genkit/core/ActionType.java new file mode 100644 index 0000000000..d922c4b0c2 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/ActionType.java @@ -0,0 +1,158 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +/** + * ActionType represents the kind of an action. Each type corresponds to a + * different Genkit primitive or capability. + */ +public enum ActionType { + /** + * A retriever action that fetches documents from a vector store or other + * source. + */ + RETRIEVER("retriever"), + + /** + * An indexer action that indexes documents into a vector store. + */ + INDEXER("indexer"), + + /** + * An embedder action that converts content to vector embeddings. + */ + EMBEDDER("embedder"), + + /** + * An evaluator action that assesses the quality of generated content. + */ + EVALUATOR("evaluator"), + + /** + * A flow action representing a user-defined workflow. + */ + FLOW("flow"), + + /** + * A model action for AI model inference. + */ + MODEL("model"), + + /** + * A background model action for long-running inference operations. + */ + BACKGROUND_MODEL("background-model"), + + /** + * An executable prompt action that can generate content directly. Uses the key + * format "/executable-prompt/{name}" to match Go SDK. This is the primary + * prompt type used by the Genkit Developer UI. + */ + EXECUTABLE_PROMPT("executable-prompt"), + + /** + * A prompt action that renders templates to generate model requests. Uses the + * key format "/prompt/{name}" to match the JS SDK. + */ + PROMPT("prompt"), + + /** + * A resource action for managing external resources. + */ + RESOURCE("resource"), + + /** + * A tool action that can be called by AI models. + */ + TOOL("tool"), + + /** + * A tool action using the v2 multipart format. + */ + TOOL_V2("tool.v2"), + + /** + * A utility action for internal operations. + */ + UTIL("util"), + + /** + * A custom action type for user-defined action types. + */ + CUSTOM("custom"), + + /** + * An action for checking operation status. + */ + CHECK_OPERATION("check-operation"), + + /** + * An action for cancelling operations. + */ + CANCEL_OPERATION("cancel-operation"); + + private final String value; + + ActionType(String value) { + this.value = value; + } + + /** + * Returns the string value of the action type. + * + * @return the action type string value + */ + public String getValue() { + return value; + } + + /** + * Creates an ActionType from a string value. + * + * @param value + * the string value + * @return the corresponding ActionType + * @throws IllegalArgumentException + * if the value doesn't match any ActionType + */ + public static ActionType fromValue(String value) { + for (ActionType type : values()) { + if (type.value.equals(value)) { + return type; + } + } + throw new IllegalArgumentException("Unknown action type: " + value); + } + + /** + * Creates the registry key for an action of this type with the given name. + * + * @param name + * the action name + * @return the registry key + */ + public String keyFromName(String name) { + return "/" + value + "/" + name; + } + + @Override + public String toString() { + return value; + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/DefaultRegistry.java b/java/core/src/main/java/com/google/genkit/core/DefaultRegistry.java new file mode 100644 index 0000000000..82be3116df --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/DefaultRegistry.java @@ -0,0 +1,288 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DefaultRegistry is the default implementation of the Registry interface. It + * provides thread-safe storage and lookup of Genkit primitives. + */ +public class DefaultRegistry implements Registry { + + private static final Logger logger = LoggerFactory.getLogger(DefaultRegistry.class); + + private final Registry parent; + private final Map> actions = new ConcurrentHashMap<>(); + private final Map plugins = new ConcurrentHashMap<>(); + private final Map values = new ConcurrentHashMap<>(); + private final Map> schemas = new ConcurrentHashMap<>(); + private final Map partials = new ConcurrentHashMap<>(); + private final Map helpers = new ConcurrentHashMap<>(); + + /** + * Creates a new root registry. + */ + public DefaultRegistry() { + this(null); + } + + /** + * Creates a new child registry with the given parent. + * + * @param parent + * the parent registry, or null for a root registry + */ + public DefaultRegistry(Registry parent) { + this.parent = parent; + } + + @Override + public Registry newChild() { + return new DefaultRegistry(this); + } + + @Override + public boolean isChild() { + return parent != null; + } + + @Override + public void registerPlugin(String name, Plugin plugin) { + if (plugins.containsKey(name)) { + throw new IllegalStateException("Plugin already registered: " + name); + } + plugins.put(name, plugin); + logger.debug("Registered plugin: {}", name); + } + + @Override + public void registerAction(String key, Action action) { + if (actions.containsKey(key)) { + throw new IllegalStateException("Action already registered: " + key); + } + actions.put(key, action); + logger.debug("Registered action: {}", key); + } + + @Override + public void registerValue(String name, Object value) { + if (values.containsKey(name)) { + throw new IllegalStateException("Value already registered: " + name); + } + values.put(name, value); + logger.debug("Registered value: {}", name); + } + + @Override + public void registerSchema(String name, Map schema) { + if (schemas.containsKey(name)) { + throw new IllegalStateException("Schema already registered: " + name); + } + schemas.put(name, schema); + logger.debug("Registered schema: {}", name); + } + + @Override + public Plugin lookupPlugin(String name) { + Plugin plugin = plugins.get(name); + if (plugin == null && parent != null) { + plugin = parent.lookupPlugin(name); + } + return plugin; + } + + @Override + public Action lookupAction(String key) { + Action action = actions.get(key); + if (action == null && parent != null) { + action = parent.lookupAction(key); + } + return action; + } + + @Override + public Object lookupValue(String name) { + Object value = values.get(name); + if (value == null && parent != null) { + value = parent.lookupValue(name); + } + return value; + } + + @Override + public Map lookupSchema(String name) { + Map schema = schemas.get(name); + if (schema == null && parent != null) { + schema = parent.lookupSchema(name); + } + return schema; + } + + @Override + public Action resolveAction(String key) { + // First try direct lookup + Action action = lookupAction(key); + if (action != null) { + return action; + } + + // Try dynamic resolution through plugins + for (Plugin plugin : listPlugins()) { + if (plugin instanceof DynamicPlugin) { + DynamicPlugin dynamicPlugin = (DynamicPlugin) plugin; + // Parse the key to get type and name + String[] parts = key.split("/"); + if (parts.length >= 3) { + ActionType type = ActionType.fromValue(parts[1]); + String name = parts[2]; + action = dynamicPlugin.resolveAction(type, name); + if (action != null) { + // Register for future lookups + registerAction(key, action); + return action; + } + } + } + } + + return null; + } + + @Override + public List> listActions() { + Map> allActions = new LinkedHashMap<>(); + + // First add parent actions + if (parent != null) { + for (Action action : parent.listActions()) { + allActions.put(action.getDesc().getKey(), action); + } + } + + // Then add/override with local actions + allActions.putAll(actions); + + // Also include dynamic actions from plugins + for (Plugin plugin : listPlugins()) { + if (plugin instanceof DynamicPlugin) { + DynamicPlugin dynamicPlugin = (DynamicPlugin) plugin; + for (ActionDesc desc : dynamicPlugin.listActions()) { + if (!allActions.containsKey(desc.getKey())) { + Action action = dynamicPlugin.resolveAction(desc.getType(), desc.getName()); + if (action != null) { + allActions.put(desc.getKey(), action); + } + } + } + } + } + + return new ArrayList<>(allActions.values()); + } + + @Override + public List> listActions(ActionType type) { + String prefix = "/" + type.toString().toLowerCase() + "/"; + List> result = new ArrayList<>(); + + for (Action action : listActions()) { + String key = action.getDesc() != null ? action.getDesc().getKey() : null; + if (key != null && key.contains(prefix)) { + result.add(action); + } else { + // Check by iterating through the local actions map + for (Map.Entry> entry : actions.entrySet()) { + if (entry.getKey().contains(prefix) && entry.getValue() == action) { + result.add(action); + break; + } + } + } + } + + return result; + } + + @Override + public List listPlugins() { + Map allPlugins = new LinkedHashMap<>(); + + // First add parent plugins + if (parent != null) { + for (Plugin plugin : parent.listPlugins()) { + allPlugins.put(plugin.getName(), plugin); + } + } + + // Then add/override with local plugins + allPlugins.putAll(plugins); + + return new ArrayList<>(allPlugins.values()); + } + + @Override + public Map listValues() { + Map allValues = new LinkedHashMap<>(); + + // First add parent values + if (parent != null) { + allValues.putAll(parent.listValues()); + } + + // Then add/override with local values + allValues.putAll(values); + + return allValues; + } + + @Override + public void registerPartial(String name, String source) { + partials.put(name, source); + logger.debug("Registered partial: {}", name); + } + + @Override + public void registerHelper(String name, Object helper) { + helpers.put(name, helper); + logger.debug("Registered helper: {}", name); + } + + @Override + public String lookupPartial(String name) { + String partial = partials.get(name); + if (partial == null && parent != null) { + partial = parent.lookupPartial(name); + } + return partial; + } + + @Override + public Object lookupHelper(String name) { + Object helper = helpers.get(name); + if (helper == null && parent != null) { + helper = parent.lookupHelper(name); + } + return helper; + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/DynamicPlugin.java b/java/core/src/main/java/com/google/genkit/core/DynamicPlugin.java new file mode 100644 index 0000000000..5f05a72a3f --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/DynamicPlugin.java @@ -0,0 +1,47 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.List; + +/** + * DynamicPlugin is a Plugin that can dynamically resolve actions. This is + * useful for plugins that provide a large number of actions or actions that are + * determined at runtime. + */ +public interface DynamicPlugin extends Plugin { + + /** + * Returns a list of action descriptors that the plugin is capable of resolving. + * + * @return list of action descriptors + */ + List listActions(); + + /** + * Resolves an action by type and name. + * + * @param type + * the action type + * @param name + * the action name + * @return the resolved action, or null if not resolvable + */ + Action resolveAction(ActionType type, String name); +} diff --git a/java/core/src/main/java/com/google/genkit/core/Flow.java b/java/core/src/main/java/com/google/genkit/core/Flow.java new file mode 100644 index 0000000000..00882471e4 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/Flow.java @@ -0,0 +1,391 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.core.middleware.Middleware; +import com.google.genkit.core.middleware.MiddlewareChain; +import com.google.genkit.core.tracing.SpanMetadata; +import com.google.genkit.core.tracing.Tracer; + +/** + * A Flow is a user-defined Action. It represents a function from input I to + * output O. The Stream parameter S is for flows that support streaming their + * results incrementally. + * + *

+ * Flows are the primary way to organize AI application logic in Genkit. They + * provide: + *

    + *
  • Observability through automatic tracing
  • + *
  • Integration with Genkit developer tools
  • + *
  • Easy deployment as API endpoints
  • + *
  • Built-in streaming support
  • + *
+ * + * @param + * The input type for the flow + * @param + * The output type for the flow + * @param + * The streaming chunk type (use Void for non-streaming flows) + */ +public class Flow implements Action { + + private final ActionDef actionDef; + private final MiddlewareChain middlewareChain; + + /** + * Creates a new Flow wrapping an ActionDef. + * + * @param actionDef + * the underlying action definition + */ + private Flow(ActionDef actionDef) { + this(actionDef, new MiddlewareChain<>()); + } + + /** + * Creates a new Flow wrapping an ActionDef with middleware. + * + * @param actionDef + * the underlying action definition + * @param middlewareChain + * the middleware chain to use + */ + private Flow(ActionDef actionDef, MiddlewareChain middlewareChain) { + this.actionDef = actionDef; + this.middlewareChain = middlewareChain; + } + + /** + * Defines a new non-streaming flow and registers it. + * + * @param registry + * the registry to register with + * @param name + * the flow name + * @param inputClass + * the input type class + * @param outputClass + * the output type class + * @param fn + * the flow function + * @param + * input type + * @param + * output type + * @return the created flow + */ + public static Flow define(Registry registry, String name, Class inputClass, + Class outputClass, BiFunction fn) { + return define(registry, name, inputClass, outputClass, fn, null); + } + + /** + * Defines a new non-streaming flow with middleware and registers it. + * + * @param registry + * the registry to register with + * @param name + * the flow name + * @param inputClass + * the input type class + * @param outputClass + * the output type class + * @param fn + * the flow function + * @param middleware + * the middleware to apply + * @param + * input type + * @param + * output type + * @return the created flow + */ + public static Flow define(Registry registry, String name, Class inputClass, + Class outputClass, BiFunction fn, List> middleware) { + MiddlewareChain chain = new MiddlewareChain<>(); + if (middleware != null) { + chain.useAll(middleware); + } + + ActionDef actionDef = ActionDef.create(name, ActionType.FLOW, null, null, inputClass, outputClass, + (ctx, input) -> { + ActionContext flowCtx = ctx.withFlowName(name); + if (chain.isEmpty()) { + return fn.apply(flowCtx, input); + } + return chain.execute(input, flowCtx, (c, i) -> fn.apply(c, i)); + }); + + Flow flow = new Flow<>(actionDef, chain); + flow.register(registry); + return flow; + } + + /** + * Defines a new streaming flow and registers it. + * + * @param registry + * the registry to register with + * @param name + * the flow name + * @param inputClass + * the input type class + * @param outputClass + * the output type class + * @param fn + * the streaming flow function + * @param + * input type + * @param + * output type + * @param + * stream chunk type + * @return the created flow + */ + public static Flow defineStreaming(Registry registry, String name, Class inputClass, + Class outputClass, ActionDef.StreamingFunction fn) { + ActionDef actionDef = ActionDef.createStreaming(name, ActionType.FLOW, null, null, inputClass, + outputClass, (ctx, input, cb) -> { + ActionContext flowCtx = ctx.withFlowName(name); + return fn.apply(flowCtx, input, cb); + }); + + Flow flow = new Flow<>(actionDef); + flow.register(registry); + return flow; + } + + /** + * Runs a named step within the current flow. Each call to run results in a new + * step with its own trace span. + * + * @param ctx + * the action context (must be a flow context) + * @param name + * the step name + * @param fn + * the step function + * @param + * the step output type + * @return the step result + * @throws GenkitException + * if not called from within a flow + */ + public static T run(ActionContext ctx, String name, Function fn) throws GenkitException { + if (ctx.getFlowName() == null) { + throw new GenkitException("Flow.run(\"" + name + "\"): must be called from within a flow"); + } + + SpanMetadata spanMetadata = SpanMetadata.builder().name(name).type("flowStep").subtype("flowStep").build(); + + return Tracer.runInNewSpan(ctx, spanMetadata, null, (spanCtx, input) -> { + try { + return fn.apply(null); + } catch (Exception e) { + if (e instanceof GenkitException) { + throw (GenkitException) e; + } + throw new GenkitException("Flow step failed: " + e.getMessage(), e); + } + }); + } + + @Override + public String getName() { + return actionDef.getName(); + } + + @Override + public ActionType getType() { + return actionDef.getType(); + } + + @Override + public ActionDesc getDesc() { + return actionDef.getDesc(); + } + + @Override + public O run(ActionContext ctx, I input) throws GenkitException { + return actionDef.run(ctx, input); + } + + @Override + public O run(ActionContext ctx, I input, Consumer streamCallback) throws GenkitException { + return actionDef.run(ctx, input, streamCallback); + } + + @Override + public JsonNode runJson(ActionContext ctx, JsonNode input, Consumer streamCallback) + throws GenkitException { + return actionDef.runJson(ctx, input, streamCallback); + } + + @Override + public ActionRunResult runJsonWithTelemetry(ActionContext ctx, JsonNode input, + Consumer streamCallback) throws GenkitException { + return actionDef.runJsonWithTelemetry(ctx, input, streamCallback); + } + + @Override + public Map getInputSchema() { + return actionDef.getInputSchema(); + } + + @Override + public Map getOutputSchema() { + return actionDef.getOutputSchema(); + } + + @Override + public Map getMetadata() { + return actionDef.getMetadata(); + } + + @Override + public void register(Registry registry) { + actionDef.register(registry); + } + + /** + * Returns the middleware chain for this flow. + * + * @return the middleware chain + */ + public MiddlewareChain getMiddlewareChain() { + return middlewareChain; + } + + /** + * Creates a copy of this flow with additional middleware. + * + * @param middleware + * the middleware to add + * @return a new flow with the middleware added + */ + public Flow withMiddleware(Middleware middleware) { + MiddlewareChain newChain = middlewareChain.copy(); + newChain.use(middleware); + return new Flow<>(actionDef, newChain); + } + + /** + * Creates a copy of this flow with additional middleware. + * + * @param middlewareList + * the middleware to add + * @return a new flow with the middleware added + */ + public Flow withMiddleware(List> middlewareList) { + MiddlewareChain newChain = middlewareChain.copy(); + newChain.useAll(middlewareList); + return new Flow<>(actionDef, newChain); + } + + /** + * Streams the flow output with the given input. Returns a consumer that can be + * used with a yield-style iteration pattern. + * + * @param ctx + * the action context + * @param input + * the flow input + * @param consumer + * the consumer for streaming values + */ + public void stream(ActionContext ctx, I input, Consumer> consumer) { + Consumer streamCallback = chunk -> { + consumer.accept(new StreamingFlowValue<>(false, null, chunk)); + }; + + try { + O output = run(ctx, input, streamCallback); + consumer.accept(new StreamingFlowValue<>(true, output, null)); + } catch (GenkitException e) { + throw e; + } + } + + /** + * StreamingFlowValue represents either a streamed chunk or the final output of + * a flow. + * + * @param + * the output type + * @param + * the stream chunk type + */ + public static class StreamingFlowValue { + private final boolean done; + private final O output; + private final S stream; + + /** + * Creates a new StreamingFlowValue. + * + * @param done + * true if this is the final output + * @param output + * the final output (valid if done is true) + * @param stream + * the stream chunk (valid if done is false) + */ + public StreamingFlowValue(boolean done, O output, S stream) { + this.done = done; + this.output = output; + this.stream = stream; + } + + /** + * Returns true if this is the final output. + * + * @return true if done + */ + public boolean isDone() { + return done; + } + + /** + * Returns the final output. Valid only if isDone() returns true. + * + * @return the output + */ + public O getOutput() { + return output; + } + + /** + * Returns the stream chunk. Valid only if isDone() returns false. + * + * @return the stream chunk + */ + public S getStream() { + return stream; + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/GenkitException.java b/java/core/src/main/java/com/google/genkit/core/GenkitException.java new file mode 100644 index 0000000000..a0a864c6ef --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/GenkitException.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +/** + * GenkitException is the base exception for all Genkit errors. It provides + * structured error information including error codes and details. + */ +public class GenkitException extends RuntimeException { + + private final String errorCode; + private final Object details; + private final String traceId; + + /** + * Creates a new GenkitException. + * + * @param message + * the error message + */ + public GenkitException(String message) { + this(message, null, null, null, null); + } + + /** + * Creates a new GenkitException with a cause. + * + * @param message + * the error message + * @param cause + * the underlying cause + */ + public GenkitException(String message, Throwable cause) { + this(message, cause, null, null, null); + } + + /** + * Creates a new GenkitException with full details. + * + * @param message + * the error message + * @param cause + * the underlying cause + * @param errorCode + * the error code + * @param details + * additional error details + * @param traceId + * the trace ID for debugging + */ + public GenkitException(String message, Throwable cause, String errorCode, Object details, String traceId) { + super(message, cause); + this.errorCode = errorCode; + this.details = details; + this.traceId = traceId; + } + + /** + * Returns the error code. + * + * @return the error code, or null if not set + */ + public String getErrorCode() { + return errorCode; + } + + /** + * Returns additional error details. + * + * @return the error details, or null if not set + */ + public Object getDetails() { + return details; + } + + /** + * Returns the trace ID for this error. + * + * @return the trace ID, or null if not set + */ + public String getTraceId() { + return traceId; + } + + /** + * Creates a builder for GenkitException. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for GenkitException. + */ + public static class Builder { + private String message; + private Throwable cause; + private String errorCode; + private Object details; + private String traceId; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder cause(Throwable cause) { + this.cause = cause; + return this; + } + + public Builder errorCode(String errorCode) { + this.errorCode = errorCode; + return this; + } + + public Builder details(Object details) { + this.details = details; + return this; + } + + public Builder traceId(String traceId) { + this.traceId = traceId; + return this; + } + + public GenkitException build() { + if (message == null || message.isEmpty()) { + throw new IllegalStateException("message is required"); + } + return new GenkitException(message, cause, errorCode, details, traceId); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/JsonUtils.java b/java/core/src/main/java/com/google/genkit/core/JsonUtils.java new file mode 100644 index 0000000000..d213a2d906 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/JsonUtils.java @@ -0,0 +1,184 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; + +/** + * JsonUtils provides JSON serialization and deserialization utilities for + * Genkit. + */ +public final class JsonUtils { + + private static final ObjectMapper objectMapper; + + static { + objectMapper = new ObjectMapper(); + objectMapper.registerModule(new JavaTimeModule()); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + objectMapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false); + objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + } + + private JsonUtils() { + // Utility class + } + + /** + * Returns the shared ObjectMapper instance. + * + * @return the ObjectMapper + */ + public static ObjectMapper getObjectMapper() { + return objectMapper; + } + + /** + * Converts an object to JSON string. + * + * @param value + * the object to convert + * @return the JSON string + * @throws GenkitException + * if serialization fails + */ + public static String toJson(Object value) throws GenkitException { + try { + return objectMapper.writeValueAsString(value); + } catch (JsonProcessingException e) { + throw new GenkitException("Failed to serialize to JSON: " + e.getMessage(), e); + } + } + + /** + * Converts an object to a JsonNode. + * + * @param value + * the object to convert + * @return the JsonNode + */ + public static JsonNode toJsonNode(Object value) { + return objectMapper.valueToTree(value); + } + + /** + * Parses a JSON string to the specified type. + * + * @param json + * the JSON string + * @param clazz + * the target class + * @param + * the target type + * @return the parsed object + * @throws GenkitException + * if parsing fails + */ + public static T fromJson(String json, Class clazz) throws GenkitException { + try { + return objectMapper.readValue(json, clazz); + } catch (JsonProcessingException e) { + throw new GenkitException("Failed to parse JSON: " + e.getMessage(), e); + } + } + + /** + * Converts a JsonNode to the specified type. + * + * @param node + * the JsonNode + * @param clazz + * the target class + * @param + * the target type + * @return the converted object + * @throws GenkitException + * if conversion fails + */ + public static T fromJsonNode(JsonNode node, Class clazz) throws GenkitException { + try { + return objectMapper.treeToValue(node, clazz); + } catch (JsonProcessingException e) { + throw new GenkitException("Failed to convert JsonNode: " + e.getMessage(), e); + } + } + + /** + * Parses a JSON string to a JsonNode. + * + * @param json + * the JSON string + * @return the JsonNode + * @throws GenkitException + * if parsing fails + */ + public static JsonNode parseJson(String json) throws GenkitException { + try { + return objectMapper.readTree(json); + } catch (JsonProcessingException e) { + throw new GenkitException("Failed to parse JSON: " + e.getMessage(), e); + } + } + + /** + * Converts an object to the specified type. + * + *

+ * This is useful for converting Maps (from JSON parsing) to typed objects. + * + * @param value + * the object to convert (typically a Map from JSON parsing) + * @param clazz + * the target class + * @param + * the target type + * @return the converted object + * @throws GenkitException + * if conversion fails + */ + public static T convert(Object value, Class clazz) throws GenkitException { + try { + return objectMapper.convertValue(value, clazz); + } catch (IllegalArgumentException e) { + throw new GenkitException("Failed to convert object to " + clazz.getName() + ": " + e.getMessage(), e); + } + } + + /** + * Pretty prints a JSON object. + * + * @param value + * the object to print + * @return the pretty-printed JSON string + * @throws GenkitException + * if serialization fails + */ + public static String toPrettyJson(Object value) throws GenkitException { + try { + return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value); + } catch (JsonProcessingException e) { + throw new GenkitException("Failed to serialize to JSON: " + e.getMessage(), e); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/Plugin.java b/java/core/src/main/java/com/google/genkit/core/Plugin.java new file mode 100644 index 0000000000..25704f03d2 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/Plugin.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.List; + +/** + * Plugin is the interface implemented by types that extend Genkit's + * functionality. Plugins are typically used to integrate external services like + * model providers, vector databases, or monitoring tools. + * + *

+ * Plugins are registered and initialized via the Genkit builder during + * initialization. + */ +public interface Plugin { + + /** + * Returns the unique identifier for the plugin. This name is used for + * registration and lookup. + * + * @return the plugin name + */ + String getName(); + + /** + * Initializes the plugin. This method is called once during Genkit + * initialization. The plugin should return a list of actions that it provides. + * + * @return list of actions provided by this plugin + */ + List> init(); + + /** + * Initializes the plugin with access to the registry. This method is called + * once during Genkit initialization. The plugin should return a list of actions + * that it provides. + * + *

+ * Override this method instead of {@link #init()} when your plugin needs to + * resolve dependencies from the registry (e.g., embedders, models). + * + * @param registry + * the Genkit registry for resolving dependencies + * @return list of actions provided by this plugin + */ + default List> init(Registry registry) { + return init(); + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/Registerable.java b/java/core/src/main/java/com/google/genkit/core/Registerable.java new file mode 100644 index 0000000000..e944ef3b25 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/Registerable.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +/** + * Registerable allows a primitive to be registered with a registry. All Genkit + * primitives (actions, flows, models, tools, etc.) implement this interface. + */ +public interface Registerable { + + /** + * Registers this primitive with the given registry. + * + * @param registry + * the registry to register with + */ + void register(Registry registry); +} diff --git a/java/core/src/main/java/com/google/genkit/core/Registry.java b/java/core/src/main/java/com/google/genkit/core/Registry.java new file mode 100644 index 0000000000..c078bd5f8b --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/Registry.java @@ -0,0 +1,261 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.List; +import java.util.Map; + +/** + * Registry holds all registered actions and associated types, and provides + * methods to register, query, and look up actions. + * + *

+ * The Registry is the central component for managing Genkit primitives. It + * provides: + *

    + *
  • Storage and lookup of actions by key
  • + *
  • Plugin management
  • + *
  • Value storage for configuration
  • + *
  • Schema registration for JSON validation
  • + *
  • Hierarchical registry support for scoped operations
  • + *
+ */ +public interface Registry { + + /** + * Creates a new child registry that inherits from this registry. Child + * registries are useful for scoped operations and will fall back to the parent + * for lookups if a value is not found in the child. + * + * @return a new child registry + */ + Registry newChild(); + + /** + * Returns true if this registry is a child of another registry. + * + * @return true if this is a child registry + */ + boolean isChild(); + + /** + * Records the plugin in the registry. + * + * @param name + * the plugin name + * @param plugin + * the plugin to register + * @throws IllegalStateException + * if a plugin with the same name is already registered + */ + void registerPlugin(String name, Plugin plugin); + + /** + * Records the action in the registry. + * + * @param key + * the action key (type + name) + * @param action + * the action to register + * @throws IllegalStateException + * if an action with the same key is already registered + */ + void registerAction(String key, Action action); + + /** + * Records an arbitrary value in the registry. + * + * @param name + * the value name + * @param value + * the value to register + * @throws IllegalStateException + * if a value with the same name is already registered + */ + void registerValue(String name, Object value); + + /** + * Records a JSON schema in the registry. + * + * @param name + * the schema name + * @param schema + * the schema as a map + * @throws IllegalStateException + * if a schema with the same name is already registered + */ + void registerSchema(String name, Map schema); + + /** + * Returns the plugin for the given name. It first checks the current registry, + * then falls back to the parent if not found. + * + * @param name + * the plugin name + * @return the plugin, or null if not found + */ + Plugin lookupPlugin(String name); + + /** + * Returns the action for the given key. It first checks the current registry, + * then falls back to the parent if not found. + * + * @param key + * the action key + * @return the action, or null if not found + */ + Action lookupAction(String key); + + /** + * Returns the action for the given type and name. + * + * @param type + * the action type + * @param name + * the action name + * @return the action, or null if not found + */ + default Action lookupAction(ActionType type, String name) { + return lookupAction(type.keyFromName(name)); + } + + /** + * Returns the value for the given name. It first checks the current registry, + * then falls back to the parent if not found. + * + * @param name + * the value name + * @return the value, or null if not found + */ + Object lookupValue(String name); + + /** + * Returns a JSON schema for the given name. It first checks the current + * registry, then falls back to the parent if not found. + * + * @param name + * the schema name + * @return the schema as a map, or null if not found + */ + Map lookupSchema(String name); + + /** + * Looks up an action by key. If the action is not found, it attempts dynamic + * resolution through registered dynamic plugins. + * + * @param key + * the action key + * @return the action if found, or null if not found + */ + Action resolveAction(String key); + + /** + * Looks up an action by type and name with dynamic resolution support. + * + * @param type + * the action type + * @param name + * the action name + * @return the action if found, or null if not found + */ + default Action resolveAction(ActionType type, String name) { + return resolveAction(type.keyFromName(name)); + } + + /** + * Returns a list of all registered actions. This includes actions from both the + * current registry and its parent hierarchy. + * + * @return list of all registered actions + */ + List> listActions(); + + /** + * Returns a list of all registered actions of the specified type. + * + * @param type + * the action type to filter by + * @return list of actions of the specified type + */ + List> listActions(ActionType type); + + /** + * Registers an action by type and action name. + * + * @param type + * the action type + * @param action + * the action to register + */ + default void registerAction(ActionType type, Action action) { + registerAction(type.keyFromName(action.getName()), action); + } + + /** + * Returns a list of all registered plugins. + * + * @return list of all registered plugins + */ + List listPlugins(); + + /** + * Returns a map of all registered values. + * + * @return map of all registered values + */ + Map listValues(); + + /** + * Registers a partial template for use with prompts. + * + * @param name + * the partial name + * @param source + * the partial template source + */ + void registerPartial(String name, String source); + + /** + * Registers a helper function for use with prompts. + * + * @param name + * the helper name + * @param helper + * the helper function + */ + void registerHelper(String name, Object helper); + + /** + * Returns a registered partial by name. + * + * @param name + * the partial name + * @return the partial source, or null if not found + */ + String lookupPartial(String name); + + /** + * Returns a registered helper by name. + * + * @param name + * the helper name + * @return the helper function, or null if not found + */ + Object lookupHelper(String name); +} diff --git a/java/core/src/main/java/com/google/genkit/core/SchemaUtils.java b/java/core/src/main/java/com/google/genkit/core/SchemaUtils.java new file mode 100644 index 0000000000..c7dda33804 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/SchemaUtils.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import java.util.Map; + +import com.fasterxml.jackson.databind.JsonNode; +import com.github.victools.jsonschema.generator.*; + +/** + * SchemaUtils provides utilities for JSON Schema generation and validation. + */ +public final class SchemaUtils { + + private static final SchemaGenerator schemaGenerator; + + static { + SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_7, + OptionPreset.PLAIN_JSON); + + configBuilder.with(Option.EXTRA_OPEN_API_FORMAT_VALUES); + configBuilder.with(Option.FLATTENED_ENUMS); + // Note: Removed NULLABLE_FIELDS_BY_DEFAULT as it generates "type": ["string", + // "null"] + // for all fields, which causes issues with the Genkit UI input form generation. + // Fields should be explicitly marked as nullable using @Nullable annotation if + // needed. + + SchemaGeneratorConfig config = configBuilder.build(); + schemaGenerator = new SchemaGenerator(config); + } + + private SchemaUtils() { + // Utility class + } + + /** + * Generates a JSON Schema for the given class. + * + * @param clazz + * the class to generate schema for + * @return the JSON schema as a map + */ + @SuppressWarnings("unchecked") + public static Map inferSchema(Class clazz) { + if (clazz == null || clazz == Void.class || clazz == void.class) { + return null; + } + + try { + JsonNode schemaNode = schemaGenerator.generateSchema(clazz); + return JsonUtils.getObjectMapper().convertValue(schemaNode, Map.class); + } catch (Exception e) { + // If schema generation fails, return a simple object schema + return Map.of("type", "object"); + } + } + + /** + * Generates a JSON Schema for a primitive type. + * + * @param typeName + * the type name (string, number, integer, boolean, array, object) + * @return the JSON schema as a map + */ + public static Map simpleSchema(String typeName) { + return Map.of("type", typeName); + } + + /** + * Creates a schema for a string type. + * + * @return the string schema + */ + public static Map stringSchema() { + return simpleSchema("string"); + } + + /** + * Creates a schema for an integer type. + * + * @return the integer schema + */ + public static Map integerSchema() { + return simpleSchema("integer"); + } + + /** + * Creates a schema for a number type. + * + * @return the number schema + */ + public static Map numberSchema() { + return simpleSchema("number"); + } + + /** + * Creates a schema for a boolean type. + * + * @return the boolean schema + */ + public static Map booleanSchema() { + return simpleSchema("boolean"); + } + + /** + * Creates a schema for an array type with the given items schema. + * + * @param itemsSchema + * the schema for array items + * @return the array schema + */ + public static Map arraySchema(Map itemsSchema) { + return Map.of("type", "array", "items", itemsSchema); + } + + /** + * Creates a schema for an object type with the given properties. + * + * @param properties + * the property schemas + * @return the object schema + */ + public static Map objectSchema(Map properties) { + return Map.of("type", "object", "properties", properties); + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/ServerPlugin.java b/java/core/src/main/java/com/google/genkit/core/ServerPlugin.java new file mode 100644 index 0000000000..5688762c06 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/ServerPlugin.java @@ -0,0 +1,80 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +/** + * ServerPlugin is an extended Plugin interface for plugins that provide HTTP + * server functionality. + * + *

+ * This interface adds lifecycle methods for starting and stopping servers. The + * {@link #start()} method blocks until the server is stopped, similar to + * Express's app.listen() in JavaScript. + * + *

+ * Example usage: + * + *

{@code
+ * JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build());
+ * 
+ * Genkit genkit = Genkit.builder().plugin(jetty).build();
+ * 
+ * // Define your flows here...
+ * 
+ * // Start the server and block - this replaces Thread.currentThread().join()
+ * jetty.start();
+ * }
+ */ +public interface ServerPlugin extends Plugin { + + /** + * Starts the HTTP server and blocks until it is stopped. + * + *

+ * This is the recommended way to start a server in a main() method. Similar to + * Express's app.listen() in JavaScript, this method will keep your application + * running until the server is explicitly stopped. + * + * @throws Exception + * if the server cannot be started or if interrupted while waiting + */ + void start() throws Exception; + + /** + * Stops the HTTP server. + * + * @throws Exception + * if the server cannot be stopped + */ + void stop() throws Exception; + + /** + * Returns the port the server is listening on. + * + * @return the server port + */ + int getPort(); + + /** + * Returns true if the server is currently running. + * + * @return true if running, false otherwise + */ + boolean isRunning(); +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/CommonMiddleware.java b/java/core/src/main/java/com/google/genkit/core/middleware/CommonMiddleware.java new file mode 100644 index 0000000000..16d23b98e9 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/CommonMiddleware.java @@ -0,0 +1,425 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +import java.time.Duration; +import java.time.Instant; +import java.util.function.BiConsumer; +import java.util.function.BiPredicate; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * CommonMiddleware provides factory methods for creating commonly-used + * middleware functions. + */ +public final class CommonMiddleware { + + private static final Logger logger = LoggerFactory.getLogger(CommonMiddleware.class); + + private CommonMiddleware() { + // Utility class + } + + /** + * Creates a logging middleware that logs requests and responses. + * + * @param name + * the name to use in log messages + * @param + * input type + * @param + * output type + * @return a logging middleware + */ + public static Middleware logging(String name) { + return (request, context, next) -> { + logger.info("[{}] Request: {}", name, request); + Instant start = Instant.now(); + try { + O result = next.apply(request, context); + Duration duration = Duration.between(start, Instant.now()); + logger.info("[{}] Response ({}ms): {}", name, duration.toMillis(), result); + return result; + } catch (GenkitException e) { + Duration duration = Duration.between(start, Instant.now()); + logger.error("[{}] Error ({}ms): {}", name, duration.toMillis(), e.getMessage()); + throw e; + } + }; + } + + /** + * Creates a logging middleware with a custom logger. + * + * @param name + * the name to use in log messages + * @param customLogger + * the logger to use + * @param + * input type + * @param + * output type + * @return a logging middleware + */ + public static Middleware logging(String name, Logger customLogger) { + return (request, context, next) -> { + customLogger.info("[{}] Request: {}", name, request); + Instant start = Instant.now(); + try { + O result = next.apply(request, context); + Duration duration = Duration.between(start, Instant.now()); + customLogger.info("[{}] Response ({}ms): {}", name, duration.toMillis(), result); + return result; + } catch (GenkitException e) { + Duration duration = Duration.between(start, Instant.now()); + customLogger.error("[{}] Error ({}ms): {}", name, duration.toMillis(), e.getMessage()); + throw e; + } + }; + } + + /** + * Creates a timing middleware that measures execution time. + * + * @param callback + * callback to receive timing information (duration in milliseconds) + * @param + * input type + * @param + * output type + * @return a timing middleware + */ + public static Middleware timing(Consumer callback) { + return (request, context, next) -> { + Instant start = Instant.now(); + try { + return next.apply(request, context); + } finally { + Duration duration = Duration.between(start, Instant.now()); + callback.accept(duration.toMillis()); + } + }; + } + + /** + * Creates a retry middleware with exponential backoff. + * + * @param maxRetries + * maximum number of retry attempts + * @param initialDelayMs + * initial delay between retries in milliseconds + * @param + * input type + * @param + * output type + * @return a retry middleware + */ + public static Middleware retry(int maxRetries, long initialDelayMs) { + return retry(maxRetries, initialDelayMs, e -> true); + } + + /** + * Creates a retry middleware with exponential backoff and custom retry + * predicate. + * + * @param maxRetries + * maximum number of retry attempts + * @param initialDelayMs + * initial delay between retries in milliseconds + * @param shouldRetry + * predicate to determine if an exception should trigger a retry + * @param + * input type + * @param + * output type + * @return a retry middleware + */ + public static Middleware retry(int maxRetries, long initialDelayMs, + Function shouldRetry) { + return (request, context, next) -> { + int attempt = 0; + GenkitException lastException = null; + long delay = initialDelayMs; + + while (attempt <= maxRetries) { + try { + return next.apply(request, context); + } catch (GenkitException e) { + lastException = e; + if (attempt >= maxRetries || !shouldRetry.apply(e)) { + throw e; + } + attempt++; + logger.warn("Retry attempt {} after error: {}", attempt, e.getMessage()); + try { + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new GenkitException("Retry interrupted", ie); + } + delay *= 2; // Exponential backoff + } + } + throw lastException; + }; + } + + /** + * Creates a validation middleware that validates the request before processing. + * + * @param validator + * the validation function (throws GenkitException on invalid input) + * @param + * input type + * @param + * output type + * @return a validation middleware + */ + public static Middleware validate(Consumer validator) { + return (request, context, next) -> { + validator.accept(request); + return next.apply(request, context); + }; + } + + /** + * Creates a transformation middleware that transforms the request before + * processing. + * + * @param transformer + * the transformation function + * @param + * input type + * @param + * output type + * @return a transformation middleware + */ + public static Middleware transformRequest(Function transformer) { + return (request, context, next) -> { + I transformed = transformer.apply(request); + return next.apply(transformed, context); + }; + } + + /** + * Creates a transformation middleware that transforms the response after + * processing. + * + * @param transformer + * the transformation function + * @param + * input type + * @param + * output type + * @return a transformation middleware + */ + public static Middleware transformResponse(Function transformer) { + return (request, context, next) -> { + O result = next.apply(request, context); + return transformer.apply(result); + }; + } + + /** + * Creates a caching middleware that caches results based on a key. + * + * @param cache + * the cache implementation + * @param keyExtractor + * function to extract cache key from request + * @param + * input type + * @param + * output type + * @return a caching middleware + */ + public static Middleware cache(MiddlewareCache cache, Function keyExtractor) { + return (request, context, next) -> { + String key = keyExtractor.apply(request); + O cached = cache.get(key); + if (cached != null) { + logger.debug("Cache hit for key: {}", key); + return cached; + } + O result = next.apply(request, context); + cache.put(key, result); + return result; + }; + } + + /** + * Creates an error handling middleware that catches and transforms exceptions. + * + * @param errorHandler + * the error handler function + * @param + * input type + * @param + * output type + * @return an error handling middleware + */ + public static Middleware errorHandler(Function errorHandler) { + return (request, context, next) -> { + try { + return next.apply(request, context); + } catch (GenkitException e) { + return errorHandler.apply(e); + } + }; + } + + /** + * Creates a conditional middleware that only applies if the predicate is true. + * + * @param predicate + * the condition to check + * @param middleware + * the middleware to apply if condition is true + * @param + * input type + * @param + * output type + * @return a conditional middleware + */ + public static Middleware conditional(BiPredicate predicate, + Middleware middleware) { + return (request, context, next) -> { + if (predicate.test(request, context)) { + return middleware.handle(request, context, next); + } + return next.apply(request, context); + }; + } + + /** + * Creates a before/after middleware that runs callbacks before and after + * execution. + * + * @param before + * callback to run before execution + * @param after + * callback to run after execution + * @param + * input type + * @param + * output type + * @return a before/after middleware + */ + public static Middleware beforeAfter(BiConsumer before, + BiConsumer after) { + return (request, context, next) -> { + if (before != null) { + before.accept(request, context); + } + O result = next.apply(request, context); + if (after != null) { + after.accept(result, context); + } + return result; + }; + } + + /** + * Creates a rate limiting middleware (simple token bucket implementation). + * + * @param maxRequests + * maximum requests allowed in the time window + * @param windowMs + * time window in milliseconds + * @param + * input type + * @param + * output type + * @return a rate limiting middleware + */ + public static Middleware rateLimit(int maxRequests, long windowMs) { + return new RateLimitMiddleware<>(maxRequests, windowMs); + } + + /** + * Creates a timeout middleware that throws an exception if execution takes too + * long. + * + * @param timeoutMs + * timeout in milliseconds + * @param + * input type + * @param + * output type + * @return a timeout middleware + */ + public static Middleware timeout(long timeoutMs) { + return (request, context, next) -> { + // Note: This is a simple implementation. For true timeout support, + // you would need to use CompletableFuture or similar async patterns. + Instant start = Instant.now(); + O result = next.apply(request, context); + Duration duration = Duration.between(start, Instant.now()); + if (duration.toMillis() > timeoutMs) { + logger.warn("Execution exceeded timeout: {}ms > {}ms", duration.toMillis(), timeoutMs); + } + return result; + }; + } + + /** + * Simple rate limiting middleware implementation. + */ + private static class RateLimitMiddleware implements Middleware { + + private final int maxRequests; + private final long windowMs; + private int requestCount; + private long windowStart; + + RateLimitMiddleware(int maxRequests, long windowMs) { + this.maxRequests = maxRequests; + this.windowMs = windowMs; + this.requestCount = 0; + this.windowStart = System.currentTimeMillis(); + } + + @Override + public synchronized O handle(I request, ActionContext context, MiddlewareNext next) + throws GenkitException { + long now = System.currentTimeMillis(); + + // Reset window if expired + if (now - windowStart >= windowMs) { + windowStart = now; + requestCount = 0; + } + + // Check rate limit + if (requestCount >= maxRequests) { + throw new GenkitException("Rate limit exceeded: " + maxRequests + " requests per " + windowMs + "ms"); + } + + requestCount++; + return next.apply(request, context); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/Middleware.java b/java/core/src/main/java/com/google/genkit/core/middleware/Middleware.java new file mode 100644 index 0000000000..70da6cc945 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/Middleware.java @@ -0,0 +1,69 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Middleware is a function that wraps action execution, allowing pre-processing + * and post-processing of requests and responses. + * + *

+ * Middleware functions receive the request, action context, and a "next" + * function to call the next middleware in the chain (or the actual action if at + * the end of the chain). + * + *

+ * Example usage: + * + *

+ * {@code
+ * Middleware loggingMiddleware = (request, context, next) -> {
+ * 	System.out.println("Before: " + request);
+ * 	String result = next.apply(request, context);
+ * 	System.out.println("After: " + result);
+ * 	return result;
+ * };
+ * }
+ * 
+ * + * @param + * The input type + * @param + * The output type + */ +@FunctionalInterface +public interface Middleware { + + /** + * Processes the request through this middleware. + * + * @param request + * the input request + * @param context + * the action context + * @param next + * the next function in the middleware chain + * @return the output response + * @throws GenkitException + * if processing fails + */ + O handle(I request, ActionContext context, MiddlewareNext next) throws GenkitException; +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareCache.java b/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareCache.java new file mode 100644 index 0000000000..f365db3259 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareCache.java @@ -0,0 +1,60 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +/** + * MiddlewareCache is a simple cache interface for use with caching middleware. + * + * @param + * the value type + */ +public interface MiddlewareCache { + + /** + * Gets a value from the cache. + * + * @param key + * the cache key + * @return the cached value, or null if not found + */ + V get(String key); + + /** + * Puts a value in the cache. + * + * @param key + * the cache key + * @param value + * the value to cache + */ + void put(String key, V value); + + /** + * Removes a value from the cache. + * + * @param key + * the cache key + */ + void remove(String key); + + /** + * Clears all values from the cache. + */ + void clear(); +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareChain.java b/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareChain.java new file mode 100644 index 0000000000..7f19a99cf3 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareChain.java @@ -0,0 +1,215 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; + +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * MiddlewareChain manages a list of middleware and provides execution of the + * complete chain. It implements the chain of responsibility pattern where each + * middleware can process or modify the request/response. + * + * @param + * The input type + * @param + * The output type + */ +public class MiddlewareChain { + + private final List> middlewareList; + + /** + * Creates a new MiddlewareChain. + */ + public MiddlewareChain() { + this.middlewareList = new ArrayList<>(); + } + + /** + * Creates a new MiddlewareChain with the given middleware. + * + * @param middlewareList + * the initial list of middleware + */ + public MiddlewareChain(List> middlewareList) { + this.middlewareList = new ArrayList<>(middlewareList); + } + + /** + * Creates a copy of this MiddlewareChain. + * + * @return a new MiddlewareChain with the same middleware + */ + public MiddlewareChain copy() { + return new MiddlewareChain<>(this.middlewareList); + } + + /** + * Adds a middleware to the chain. + * + * @param middleware + * the middleware to add + * @return this chain for fluent chaining + */ + public MiddlewareChain use(Middleware middleware) { + if (middleware != null) { + middlewareList.add(middleware); + } + return this; + } + + /** + * Adds multiple middleware to the chain. + * + * @param middlewareList + * the middleware to add + * @return this chain for fluent chaining + */ + public MiddlewareChain useAll(List> middlewareList) { + if (middlewareList != null) { + this.middlewareList.addAll(middlewareList); + } + return this; + } + + /** + * Inserts a middleware at the beginning of the chain. + * + * @param middleware + * the middleware to insert + * @return this chain for fluent chaining + */ + public MiddlewareChain useFirst(Middleware middleware) { + if (middleware != null) { + middlewareList.add(0, middleware); + } + return this; + } + + /** + * Returns an unmodifiable view of the middleware list. + * + * @return the middleware list + */ + public List> getMiddlewareList() { + return Collections.unmodifiableList(middlewareList); + } + + /** + * Returns the number of middleware in the chain. + * + * @return the middleware count + */ + public int size() { + return middlewareList.size(); + } + + /** + * Checks if the chain is empty. + * + * @return true if no middleware is registered + */ + public boolean isEmpty() { + return middlewareList.isEmpty(); + } + + /** + * Clears all middleware from the chain. + */ + public void clear() { + middlewareList.clear(); + } + + /** + * Executes the middleware chain with the given request, context, and final + * action. + * + * @param request + * the input request + * @param context + * the action context + * @param finalAction + * the final action to execute after all middleware + * @return the output response + * @throws GenkitException + * if execution fails + */ + public O execute(I request, ActionContext context, BiFunction finalAction) + throws GenkitException { + return dispatch(0, request, context, finalAction); + } + + /** + * Dispatches to the next middleware in the chain or the final action. + * + * @param index + * the current middleware index + * @param request + * the input request + * @param context + * the action context + * @param finalAction + * the final action to execute + * @return the output response + * @throws GenkitException + * if execution fails + */ + private O dispatch(int index, I request, ActionContext context, BiFunction finalAction) + throws GenkitException { + if (index >= middlewareList.size()) { + // End of middleware chain, execute the final action + return finalAction.apply(context, request); + } + + Middleware currentMiddleware = middlewareList.get(index); + + // Create the next function that will dispatch to the next middleware + MiddlewareNext next = (modifiedRequest, modifiedContext) -> dispatch(index + 1, + modifiedRequest != null ? modifiedRequest : request, + modifiedContext != null ? modifiedContext : context, finalAction); + + return currentMiddleware.handle(request, context, next); + } + + /** + * Creates a new MiddlewareChain with the specified middleware. + * + * @param middleware + * the middleware to include + * @param + * input type + * @param + * output type + * @return a new MiddlewareChain + */ + @SafeVarargs + public static MiddlewareChain of(Middleware... middleware) { + MiddlewareChain chain = new MiddlewareChain<>(); + for (Middleware m : middleware) { + chain.use(m); + } + return chain; + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareNext.java b/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareNext.java new file mode 100644 index 0000000000..6f3d7fa4ee --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/MiddlewareNext.java @@ -0,0 +1,49 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * MiddlewareNext represents the next function in the middleware chain. It is + * used by middleware to pass control to the next middleware or the actual + * action. + * + * @param + * The input type + * @param + * The output type + */ +@FunctionalInterface +public interface MiddlewareNext { + + /** + * Calls the next middleware in the chain or the actual action. + * + * @param request + * the input request (may be modified by the middleware) + * @param context + * the action context (may be modified by the middleware) + * @return the output response + * @throws GenkitException + * if processing fails + */ + O apply(I request, ActionContext context) throws GenkitException; +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/SimpleCache.java b/java/core/src/main/java/com/google/genkit/core/middleware/SimpleCache.java new file mode 100644 index 0000000000..1809271ece --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/SimpleCache.java @@ -0,0 +1,101 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * SimpleCache is a thread-safe in-memory cache implementation for use with + * caching middleware. + * + * @param + * the value type + */ +public class SimpleCache implements MiddlewareCache { + + private final Map> cache; + private final long ttlMs; + + /** + * Creates a SimpleCache with no TTL (entries never expire). + */ + public SimpleCache() { + this(0); + } + + /** + * Creates a SimpleCache with the specified TTL. + * + * @param ttlMs + * time-to-live in milliseconds (0 for no expiration) + */ + public SimpleCache(long ttlMs) { + this.cache = new ConcurrentHashMap<>(); + this.ttlMs = ttlMs; + } + + @Override + public V get(String key) { + CacheEntry entry = cache.get(key); + if (entry == null) { + return null; + } + if (ttlMs > 0 && System.currentTimeMillis() - entry.timestamp > ttlMs) { + cache.remove(key); + return null; + } + return entry.value; + } + + @Override + public void put(String key, V value) { + cache.put(key, new CacheEntry<>(value, System.currentTimeMillis())); + } + + @Override + public void remove(String key) { + cache.remove(key); + } + + @Override + public void clear() { + cache.clear(); + } + + /** + * Returns the number of entries in the cache. + * + * @return the cache size + */ + public int size() { + return cache.size(); + } + + private static class CacheEntry { + + final V value; + final long timestamp; + + CacheEntry(V value, long timestamp) { + this.value = value; + this.timestamp = timestamp; + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/middleware/package-info.java b/java/core/src/main/java/com/google/genkit/core/middleware/package-info.java new file mode 100644 index 0000000000..b6595c0ea7 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/middleware/package-info.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Middleware support for Genkit Java. + * + *

+ * This package provides a middleware pattern implementation for wrapping action + * execution with pre-processing and post-processing logic. Middleware can be + * used for: + *

    + *
  • Logging and monitoring
  • + *
  • Request/response transformation
  • + *
  • Caching
  • + *
  • Rate limiting
  • + *
  • Retry logic
  • + *
  • Validation
  • + *
  • Error handling
  • + *
+ * + *

+ * Example usage: + * + *

+ * {@code
+ * // Create a middleware chain
+ * MiddlewareChain chain = new MiddlewareChain<>();
+ * chain.use(CommonMiddleware.logging("myAction"));
+ * chain.use(CommonMiddleware.retry(3, 100));
+ *
+ * // Execute with middleware
+ * String result = chain.execute(input, context, (ctx, req) -> {
+ * 	// Actual action logic
+ * 	return "Hello, " + req;
+ * });
+ * }
+ * 
+ * + * @see com.google.genkit.core.middleware.Middleware + * @see com.google.genkit.core.middleware.MiddlewareChain + * @see com.google.genkit.core.middleware.CommonMiddleware + */ +package com.google.genkit.core.middleware; diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/GenkitSpanData.java b/java/core/src/main/java/com/google/genkit/core/tracing/GenkitSpanData.java new file mode 100644 index 0000000000..63b9fb5230 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/GenkitSpanData.java @@ -0,0 +1,465 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonGetter; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSetter; + +/** + * GenkitSpanData represents information about a trace span. This format matches + * the telemetry server API expectations. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class GenkitSpanData { + + @JsonProperty("spanId") + private String spanId; + + @JsonProperty("traceId") + private String traceId; + + @JsonProperty("parentSpanId") + private String parentSpanId; + + @JsonProperty("startTime") + private long startTime; + + @JsonProperty("endTime") + private long endTime; + + @JsonProperty("attributes") + private Map attributes; + + // displayName is required by the telemetry server schema, so always include it + // Note: We use @JsonGetter on the getter to ensure the null-check is applied + // during serialization + private String displayName = ""; + + @JsonProperty("links") + private List links; + + @JsonProperty("instrumentationLibrary") + private InstrumentationScope instrumentationScope; + + @JsonProperty("spanKind") + private String spanKind; + + @JsonProperty("sameProcessAsParentSpan") + private BoolValue sameProcessAsParentSpan; + + @JsonProperty("status") + private Status status; + + @JsonProperty("timeEvents") + @JsonInclude(JsonInclude.Include.NON_NULL) + private TimeEvents timeEvents; + + public GenkitSpanData() { + this.attributes = new HashMap<>(); + this.sameProcessAsParentSpan = new BoolValue(true); + this.status = new Status(); + // Don't initialize timeEvents - it should be null when there are no events + // The TypeScript schema defines timeEvents as optional + } + + // Getters and setters + + public String getSpanId() { + return spanId; + } + + public void setSpanId(String spanId) { + this.spanId = spanId; + } + + public String getTraceId() { + return traceId; + } + + public void setTraceId(String traceId) { + this.traceId = traceId; + } + + public String getParentSpanId() { + return parentSpanId; + } + + public void setParentSpanId(String parentSpanId) { + this.parentSpanId = parentSpanId; + } + + public long getStartTime() { + return startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + public long getEndTime() { + return endTime; + } + + public void setEndTime(long endTime) { + this.endTime = endTime; + } + + public Map getAttributes() { + return attributes; + } + + public void setAttributes(Map attributes) { + this.attributes = attributes; + } + + public void addAttribute(String key, Object value) { + this.attributes.put(key, value); + } + + @JsonGetter("displayName") + @JsonInclude(JsonInclude.Include.ALWAYS) + public String getDisplayName() { + // Never return null - telemetry server requires a string + return displayName != null ? displayName : ""; + } + + @JsonSetter("displayName") + public void setDisplayName(String displayName) { + // Never accept null - telemetry server requires a string + this.displayName = displayName != null ? displayName : ""; + } + + public List getLinks() { + return links; + } + + public void setLinks(List links) { + this.links = links; + } + + public InstrumentationScope getInstrumentationScope() { + return instrumentationScope; + } + + public void setInstrumentationScope(InstrumentationScope instrumentationScope) { + this.instrumentationScope = instrumentationScope; + } + + public String getSpanKind() { + return spanKind; + } + + public void setSpanKind(String spanKind) { + this.spanKind = spanKind; + } + + public BoolValue getSameProcessAsParentSpan() { + return sameProcessAsParentSpan; + } + + public void setSameProcessAsParentSpan(BoolValue sameProcessAsParentSpan) { + this.sameProcessAsParentSpan = sameProcessAsParentSpan; + } + + public Status getStatus() { + return status; + } + + public void setStatus(Status status) { + this.status = status; + } + + public TimeEvents getTimeEvents() { + return timeEvents; + } + + public void setTimeEvents(TimeEvents timeEvents) { + this.timeEvents = timeEvents; + } + + /** + * BoolValue wraps a boolean to match the expected JSON format. + */ + public static class BoolValue { + @JsonProperty("value") + private boolean value; + + public BoolValue() { + } + + public BoolValue(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + public void setValue(boolean value) { + this.value = value; + } + } + + /** + * Status represents the span status. + */ + public static class Status { + @JsonProperty("code") + private int code; + + @JsonProperty("message") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + private String message; + + public Status() { + this.code = 0; // OK + } + + public Status(int code, String message) { + this.code = code; + this.message = message; + } + + public int getCode() { + return code; + } + + public void setCode(int code) { + this.code = code; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + } + + /** + * InstrumentationScope represents the instrumentation library. + */ + public static class InstrumentationScope { + @JsonProperty("name") + private String name; + + @JsonProperty("version") + private String version; + + @JsonProperty("schemaUrl") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + private String schemaUrl; + + public InstrumentationScope() { + } + + public InstrumentationScope(String name, String version) { + this.name = name; + this.version = version; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getVersion() { + return version; + } + + public void setVersion(String version) { + this.version = version; + } + + public String getSchemaUrl() { + return schemaUrl; + } + + public void setSchemaUrl(String schemaUrl) { + this.schemaUrl = schemaUrl; + } + } + + /** + * Link describes the relationship between two Spans. + */ + public static class Link { + @JsonProperty("context") + private SpanContextData context; + + @JsonProperty("attributes") + private Map attributes; + + @JsonProperty("droppedAttributesCount") + private int droppedAttributesCount; + + public SpanContextData getContext() { + return context; + } + + public void setContext(SpanContextData context) { + this.context = context; + } + + public Map getAttributes() { + return attributes; + } + + public void setAttributes(Map attributes) { + this.attributes = attributes; + } + + public int getDroppedAttributesCount() { + return droppedAttributesCount; + } + + public void setDroppedAttributesCount(int droppedAttributesCount) { + this.droppedAttributesCount = droppedAttributesCount; + } + } + + /** + * SpanContextData contains identifying trace information about a Span. + */ + public static class SpanContextData { + @JsonProperty("traceId") + private String traceId; + + @JsonProperty("spanId") + private String spanId; + + @JsonProperty("isRemote") + private boolean isRemote; + + @JsonProperty("traceFlags") + private int traceFlags; + + public String getTraceId() { + return traceId; + } + + public void setTraceId(String traceId) { + this.traceId = traceId; + } + + public String getSpanId() { + return spanId; + } + + public void setSpanId(String spanId) { + this.spanId = spanId; + } + + public boolean isRemote() { + return isRemote; + } + + public void setRemote(boolean remote) { + isRemote = remote; + } + + public int getTraceFlags() { + return traceFlags; + } + + public void setTraceFlags(int traceFlags) { + this.traceFlags = traceFlags; + } + } + + /** + * TimeEvents holds time-based events. + */ + public static class TimeEvents { + @JsonProperty("timeEvent") + private List timeEvent; + + public List getTimeEvent() { + return timeEvent; + } + + public void setTimeEvent(List timeEvent) { + this.timeEvent = timeEvent; + } + } + + /** + * TimeEvent represents a time-based event. + */ + public static class TimeEvent { + @JsonProperty("time") + private long time; + + @JsonProperty("annotation") + private Annotation annotation; + + public long getTime() { + return time; + } + + public void setTime(long time) { + this.time = time; + } + + public Annotation getAnnotation() { + return annotation; + } + + public void setAnnotation(Annotation annotation) { + this.annotation = annotation; + } + } + + /** + * Annotation represents an annotation. + */ + public static class Annotation { + @JsonProperty("attributes") + private Map attributes; + + @JsonProperty("description") + private String description; + + public Map getAttributes() { + return attributes; + } + + public void setAttributes(Map attributes) { + this.attributes = attributes; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/HttpTelemetryClient.java b/java/core/src/main/java/com/google/genkit/core/tracing/HttpTelemetryClient.java new file mode 100644 index 0000000000..9f40082d16 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/HttpTelemetryClient.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * HTTP-based telemetry client that sends traces to the Genkit telemetry server. + */ +public class HttpTelemetryClient implements TelemetryClient { + + private static final Logger logger = LoggerFactory.getLogger(HttpTelemetryClient.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private final String serverUrl; + private final HttpClient httpClient; + + /** + * Creates a new HTTP telemetry client. + * + * @param serverUrl + * the URL of the telemetry server + */ + public HttpTelemetryClient(String serverUrl) { + this.serverUrl = serverUrl.endsWith("/") ? serverUrl.substring(0, serverUrl.length() - 1) : serverUrl; + this.httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(10)).build(); + } + + @Override + public void save(TraceData trace) throws Exception { + if (serverUrl == null || serverUrl.isEmpty()) { + logger.debug("Telemetry server URL not configured, skipping trace export"); + return; + } + + String json = objectMapper.writeValueAsString(trace); + + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(serverUrl + "/api/traces")) + .header("Content-Type", "application/json").header("Accept", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(json)).timeout(Duration.ofSeconds(30)).build(); + + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + if (response.statusCode() != 200) { + logger.warn("Failed to send trace to telemetry server: status={}, body={}", response.statusCode(), + response.body()); + throw new RuntimeException("Failed to send trace: HTTP " + response.statusCode()); + } + + logger.debug("Trace sent to telemetry server: traceId={}", trace.getTraceId()); + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/SpanContext.java b/java/core/src/main/java/com/google/genkit/core/tracing/SpanContext.java new file mode 100644 index 0000000000..85f5e4ec59 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/SpanContext.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +/** + * SpanContext contains trace and span identifiers for distributed tracing. + */ +public class SpanContext { + + private final String traceId; + private final String spanId; + private final String parentSpanId; + + /** + * Creates a new SpanContext. + * + * @param traceId + * the trace ID + * @param spanId + * the span ID + * @param parentSpanId + * the parent span ID, may be null + */ + public SpanContext(String traceId, String spanId, String parentSpanId) { + this.traceId = traceId; + this.spanId = spanId; + this.parentSpanId = parentSpanId; + } + + /** + * Returns the trace ID. + * + * @return the trace ID + */ + public String getTraceId() { + return traceId; + } + + /** + * Returns the span ID. + * + * @return the span ID + */ + public String getSpanId() { + return spanId; + } + + /** + * Returns the parent span ID. + * + * @return the parent span ID, or null if this is a root span + */ + public String getParentSpanId() { + return parentSpanId; + } + + /** + * Returns true if this span has a parent. + * + * @return true if this span has a parent + */ + public boolean hasParent() { + return parentSpanId != null && !parentSpanId.isEmpty(); + } + + @Override + public String toString() { + return "SpanContext{" + "traceId='" + traceId + '\'' + ", spanId='" + spanId + '\'' + ", parentSpanId='" + + parentSpanId + '\'' + '}'; + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/SpanMetadata.java b/java/core/src/main/java/com/google/genkit/core/tracing/SpanMetadata.java new file mode 100644 index 0000000000..43d3912870 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/SpanMetadata.java @@ -0,0 +1,155 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +import java.util.HashMap; +import java.util.Map; + +/** + * SpanMetadata contains metadata for a tracing span. + */ +public class SpanMetadata { + + private String name; + private String type; + private String subtype; + private Map attributes; + + /** + * Creates a new SpanMetadata. + */ + public SpanMetadata() { + this.attributes = new HashMap<>(); + } + + /** + * Creates a new SpanMetadata with the specified values. + * + * @param name + * the span name + * @param type + * the span type + * @param subtype + * the span subtype + * @param attributes + * additional attributes + */ + public SpanMetadata(String name, String type, String subtype, Map attributes) { + this.name = name; + this.type = type; + this.subtype = subtype; + this.attributes = attributes != null ? new HashMap<>(attributes) : new HashMap<>(); + } + + /** + * Creates a builder for SpanMetadata. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + // Getters and setters + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public String getSubtype() { + return subtype; + } + + public void setSubtype(String subtype) { + this.subtype = subtype; + } + + public Map getAttributes() { + return attributes; + } + + public void setAttributes(Map attributes) { + this.attributes = attributes; + } + + /** + * Adds an attribute to the span metadata. + * + * @param key + * the attribute key + * @param value + * the attribute value + * @return this SpanMetadata for chaining + */ + public SpanMetadata addAttribute(String key, Object value) { + this.attributes.put(key, value); + return this; + } + + /** + * Builder for SpanMetadata. + */ + public static class Builder { + private String name; + private String type; + private String subtype; + private Map attributes = new HashMap<>(); + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder type(String type) { + this.type = type; + return this; + } + + public Builder subtype(String subtype) { + this.subtype = subtype; + return this; + } + + public Builder attributes(Map attributes) { + this.attributes = new HashMap<>(attributes); + return this; + } + + public Builder addAttribute(String key, Object value) { + this.attributes.put(key, value); + return this; + } + + public SpanMetadata build() { + return new SpanMetadata(name, type, subtype, attributes); + } + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/TelemetryClient.java b/java/core/src/main/java/com/google/genkit/core/tracing/TelemetryClient.java new file mode 100644 index 0000000000..8a52efdf86 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/TelemetryClient.java @@ -0,0 +1,35 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +/** + * TelemetryClient interface for sending traces to a telemetry server. + */ +public interface TelemetryClient { + + /** + * Saves trace data to the telemetry server. + * + * @param trace + * the trace data to save + * @throws Exception + * if the save operation fails + */ + void save(TraceData trace) throws Exception; +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/TelemetryServerExporter.java b/java/core/src/main/java/com/google/genkit/core/tracing/TelemetryServerExporter.java new file mode 100644 index 0000000000..fa7bf47c38 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/TelemetryServerExporter.java @@ -0,0 +1,294 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; +import io.opentelemetry.sdk.common.CompletableResultCode; +import io.opentelemetry.sdk.trace.SpanProcessor; +import io.opentelemetry.sdk.trace.data.EventData; +import io.opentelemetry.sdk.trace.data.LinkData; +import io.opentelemetry.sdk.trace.data.SpanData; + +/** + * OpenTelemetry SpanProcessor that exports spans to the Genkit telemetry + * server. This enables traces to be visible in the Genkit Developer UI. + */ +public class TelemetryServerExporter implements SpanProcessor { + + private static final Logger logger = LoggerFactory.getLogger(TelemetryServerExporter.class); + private static final String INSTRUMENTATION_NAME = "genkit-java"; + private static final String INSTRUMENTATION_VERSION = "1.0.0"; + + private final AtomicReference clientRef = new AtomicReference<>(); + + // Buffer spans by trace ID to send complete traces + private final Map traceBuffer = new ConcurrentHashMap<>(); + + /** + * Creates a new TelemetryServerExporter. + */ + public TelemetryServerExporter() { + } + + /** + * Sets the telemetry client to use for exporting traces. + * + * @param client + * the telemetry client + */ + public void setClient(TelemetryClient client) { + this.clientRef.set(client); + logger.debug("Telemetry client configured"); + } + + /** + * Returns true if the exporter is configured with a client. + */ + public boolean isConfigured() { + return clientRef.get() != null; + } + + @Override + public void onStart(Context parentContext, io.opentelemetry.sdk.trace.ReadWriteSpan span) { + // No action needed on start + } + + @Override + public boolean isStartRequired() { + return false; + } + + @Override + public void onEnd(io.opentelemetry.sdk.trace.ReadableSpan span) { + TelemetryClient client = clientRef.get(); + if (client == null) { + logger.trace("No telemetry client configured, skipping span export"); + return; + } + + try { + SpanData otelSpanData = span.toSpanData(); + String traceId = otelSpanData.getTraceId(); + String spanId = otelSpanData.getSpanId(); + + // Convert OpenTelemetry span to our format + GenkitSpanData genkitSpanData = convertSpan(otelSpanData); + + // Get or create trace data + TraceData traceData = traceBuffer.computeIfAbsent(traceId, TraceData::new); + traceData.addSpan(genkitSpanData); + + // If this is a root span (no parent), set trace-level info and export + String parentSpanId = otelSpanData.getParentSpanId(); + if (parentSpanId == null || parentSpanId.isEmpty() || "0000000000000000".equals(parentSpanId)) { + traceData.setDisplayName(otelSpanData.getName()); + traceData.setStartTime(toMillis(otelSpanData.getStartEpochNanos())); + traceData.setEndTime(toMillis(otelSpanData.getEndEpochNanos())); + + // Export the trace + exportTrace(client, traceData); + + // Remove from buffer + traceBuffer.remove(traceId); + } else { + // For non-root spans, still try to export incrementally + // This ensures traces show up in the UI even before completion + exportTrace(client, traceData); + } + + } catch (Exception e) { + logger.error("Failed to export span to telemetry server", e); + } + } + + @Override + public boolean isEndRequired() { + return true; + } + + @Override + public CompletableResultCode shutdown() { + // Export any remaining buffered traces + TelemetryClient client = clientRef.get(); + if (client != null) { + for (TraceData trace : traceBuffer.values()) { + try { + client.save(trace); + } catch (Exception e) { + logger.error("Failed to export trace during shutdown", e); + } + } + } + traceBuffer.clear(); + return CompletableResultCode.ofSuccess(); + } + + @Override + public CompletableResultCode forceFlush() { + // Export all buffered traces + TelemetryClient client = clientRef.get(); + if (client != null) { + for (Map.Entry entry : traceBuffer.entrySet()) { + try { + client.save(entry.getValue()); + } catch (Exception e) { + logger.error("Failed to export trace during flush", e); + } + } + } + return CompletableResultCode.ofSuccess(); + } + + private void exportTrace(TelemetryClient client, TraceData traceData) { + try { + client.save(traceData); + } catch (Exception e) { + logger.error("Failed to export trace: traceId={}", traceData.getTraceId(), e); + } + } + + private GenkitSpanData convertSpan(SpanData otelSpan) { + GenkitSpanData span = new GenkitSpanData(); + + span.setSpanId(otelSpan.getSpanId()); + span.setTraceId(otelSpan.getTraceId()); + // displayName is required by the telemetry server schema - ensure it's never + // null + String spanName = otelSpan.getName(); + span.setDisplayName(spanName != null ? spanName : "unknown"); + span.setStartTime(toMillis(otelSpan.getStartEpochNanos())); + span.setEndTime(toMillis(otelSpan.getEndEpochNanos())); + span.setSpanKind(otelSpan.getKind().name()); + + String parentSpanId = otelSpan.getParentSpanId(); + if (parentSpanId != null && !parentSpanId.isEmpty() && !"0000000000000000".equals(parentSpanId)) { + span.setParentSpanId(parentSpanId); + } + + // Convert attributes + Map attributes = new HashMap<>(); + otelSpan.getAttributes().forEach((key, value) -> { + attributes.put(key.getKey(), value); + }); + span.setAttributes(attributes); + + // Convert status + GenkitSpanData.Status status = new GenkitSpanData.Status(); + status.setCode(convertStatusCode(otelSpan.getStatus().getStatusCode())); + if (otelSpan.getStatus().getDescription() != null) { + status.setMessage(otelSpan.getStatus().getDescription()); + } + span.setStatus(status); + + // Set instrumentation scope - name is required by the schema + GenkitSpanData.InstrumentationScope scope = new GenkitSpanData.InstrumentationScope(); + String scopeName = otelSpan.getInstrumentationScopeInfo().getName(); + scope.setName(scopeName != null && !scopeName.isEmpty() ? scopeName : "genkit-java"); + // Version is optional but default to 1.0.0 if not set + String version = otelSpan.getInstrumentationScopeInfo().getVersion(); + scope.setVersion(version != null ? version : "1.0.0"); + span.setInstrumentationScope(scope); + + // Convert events to time events + List events = otelSpan.getEvents(); + if (events != null && !events.isEmpty()) { + GenkitSpanData.TimeEvents timeEvents = new GenkitSpanData.TimeEvents(); + List timeEventList = new ArrayList<>(); + + for (EventData event : events) { + GenkitSpanData.TimeEvent timeEvent = new GenkitSpanData.TimeEvent(); + timeEvent.setTime(toMillis(event.getEpochNanos())); + + GenkitSpanData.Annotation annotation = new GenkitSpanData.Annotation(); + annotation.setDescription(event.getName()); + + Map eventAttrs = new HashMap<>(); + event.getAttributes().forEach((key, value) -> { + eventAttrs.put(key.getKey(), value); + }); + annotation.setAttributes(eventAttrs); + + timeEvent.setAnnotation(annotation); + timeEventList.add(timeEvent); + } + + timeEvents.setTimeEvent(timeEventList); + span.setTimeEvents(timeEvents); + } + + // Convert links + List links = otelSpan.getLinks(); + if (links != null && !links.isEmpty()) { + List linkList = new ArrayList<>(); + for (LinkData link : links) { + GenkitSpanData.Link l = new GenkitSpanData.Link(); + + GenkitSpanData.SpanContextData ctx = new GenkitSpanData.SpanContextData(); + ctx.setTraceId(link.getSpanContext().getTraceId()); + ctx.setSpanId(link.getSpanContext().getSpanId()); + ctx.setRemote(link.getSpanContext().isRemote()); + ctx.setTraceFlags(link.getSpanContext().getTraceFlags().asByte()); + l.setContext(ctx); + + Map linkAttrs = new HashMap<>(); + link.getAttributes().forEach((key, value) -> { + linkAttrs.put(key.getKey(), value); + }); + l.setAttributes(linkAttrs); + l.setDroppedAttributesCount(link.getTotalAttributeCount() - link.getAttributes().size()); + + linkList.add(l); + } + span.setLinks(linkList); + } + + // Set sameProcessAsParentSpan + span.setSameProcessAsParentSpan(new GenkitSpanData.BoolValue(!otelSpan.getSpanContext().isRemote())); + + return span; + } + + private int convertStatusCode(StatusCode statusCode) { + switch (statusCode) { + case OK : + return 0; + case ERROR : + return 2; + case UNSET : + default : + return 0; + } + } + + private long toMillis(long nanos) { + return TimeUnit.NANOSECONDS.toMillis(nanos); + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/TraceData.java b/java/core/src/main/java/com/google/genkit/core/tracing/TraceData.java new file mode 100644 index 0000000000..2cc4a6586c --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/TraceData.java @@ -0,0 +1,101 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +import java.util.HashMap; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * TraceData represents a complete trace with all its spans. This format matches + * the telemetry server API expectations. + */ +public class TraceData { + + @JsonProperty("traceId") + private String traceId; + + @JsonProperty("displayName") + @JsonInclude(JsonInclude.Include.NON_NULL) + private String displayName; + + @JsonProperty("startTime") + private long startTime; + + @JsonProperty("endTime") + private long endTime; + + @JsonProperty("spans") + private Map spans; + + public TraceData() { + this.spans = new HashMap<>(); + } + + public TraceData(String traceId) { + this.traceId = traceId; + this.spans = new HashMap<>(); + } + + public String getTraceId() { + return traceId; + } + + public void setTraceId(String traceId) { + this.traceId = traceId; + } + + public String getDisplayName() { + return displayName; + } + + public void setDisplayName(String displayName) { + this.displayName = displayName; + } + + public long getStartTime() { + return startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + public long getEndTime() { + return endTime; + } + + public void setEndTime(long endTime) { + this.endTime = endTime; + } + + public Map getSpans() { + return spans; + } + + public void setSpans(Map spans) { + this.spans = spans; + } + + public void addSpan(GenkitSpanData span) { + this.spans.put(span.getSpanId(), span); + } +} diff --git a/java/core/src/main/java/com/google/genkit/core/tracing/Tracer.java b/java/core/src/main/java/com/google/genkit/core/tracing/Tracer.java new file mode 100644 index 0000000000..c86c67fe12 --- /dev/null +++ b/java/core/src/main/java/com/google/genkit/core/tracing/Tracer.java @@ -0,0 +1,322 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.tracing; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiFunction; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.trace.SdkTracerProvider; + +/** + * Tracer provides tracing utilities for Genkit operations. It integrates with + * OpenTelemetry for distributed tracing. + */ +public final class Tracer { + + private static final Logger logger = LoggerFactory.getLogger(Tracer.class); + private static final String INSTRUMENTATION_NAME = "genkit-java"; + private static final AtomicBoolean initialized = new AtomicBoolean(false); + private static final ObjectMapper objectMapper = new ObjectMapper(); + private static volatile io.opentelemetry.api.trace.Tracer otelTracer; + private static volatile TelemetryServerExporter telemetryExporter; + private static volatile SdkTracerProvider tracerProvider; + private static volatile String configuredTelemetryServerUrl; + + static { + initializeTracer(); + } + + private Tracer() { + // Utility class + } + + /** + * Initializes the OpenTelemetry tracer with the telemetry exporter. + */ + private static synchronized void initializeTracer() { + if (initialized.compareAndSet(false, true)) { + try { + // Create the telemetry exporter + telemetryExporter = new TelemetryServerExporter(); + + // Create SDK tracer provider with our exporter + tracerProvider = SdkTracerProvider.builder().addSpanProcessor(telemetryExporter).build(); + + // Build the OpenTelemetry SDK - try to register globally + OpenTelemetrySdk openTelemetry; + try { + openTelemetry = OpenTelemetrySdk.builder().setTracerProvider(tracerProvider) + .buildAndRegisterGlobal(); + } catch (IllegalStateException e) { + // GlobalOpenTelemetry was already set - just build without registering + logger.debug("GlobalOpenTelemetry already set, building local SDK: {}", e.getMessage()); + openTelemetry = OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).build(); + } + + otelTracer = openTelemetry.getTracer(INSTRUMENTATION_NAME); + + logger.debug("OpenTelemetry tracer initialized with telemetry exporter"); + } catch (Exception e) { + logger.error("Failed to initialize OpenTelemetry tracer", e); + } + + // Check for environment variable for telemetry server + String telemetryServerUrl = System.getenv("GENKIT_TELEMETRY_SERVER"); + if (telemetryServerUrl != null && !telemetryServerUrl.isEmpty()) { + configureTelemetryServer(telemetryServerUrl); + } + } + } + + /** + * Configures the telemetry server URL for exporting traces. This is typically + * called when the CLI notifies the runtime of the telemetry server URL. + * + * @param serverUrl + * the telemetry server URL + */ + public static void configureTelemetryServer(String serverUrl) { + if (serverUrl != null && !serverUrl.isEmpty() && telemetryExporter != null) { + // Skip if already configured with the same URL + if (serverUrl.equals(configuredTelemetryServerUrl)) { + return; + } + telemetryExporter.setClient(new HttpTelemetryClient(serverUrl)); + configuredTelemetryServerUrl = serverUrl; + logger.info("Connected to telemetry server: {}", serverUrl); + } + } + + /** + * Returns true if the telemetry exporter is configured. + */ + public static boolean isTelemetryConfigured() { + return telemetryExporter != null && telemetryExporter.isConfigured(); + } + + /** + * Runs a function within a new tracing span. + * + * @param ctx + * the action context + * @param metadata + * the span metadata + * @param input + * the input to pass to the function + * @param fn + * the function to execute + * @param + * the input type + * @param + * the output type + * @return the function result + * @throws GenkitException + * if the function throws an exception + */ + public static O runInNewSpan(ActionContext ctx, SpanMetadata metadata, I input, + BiFunction fn) throws GenkitException { + String spanName = metadata.getName() != null ? metadata.getName() : "unknown"; + + // Determine if this is a root span + boolean isRoot = ctx.getSpanContext() == null; + + // Build the path for this span + String parentPath = isRoot ? "" : ctx.getSpanPath(); + String path = buildPath(spanName, parentPath, metadata.getType(), metadata.getSubtype()); + + Span span = otelTracer.spanBuilder(spanName).setSpanKind(SpanKind.INTERNAL).startSpan(); + + // Add genkit-specific attributes + span.setAttribute("genkit:name", spanName); + span.setAttribute("genkit:path", path); + span.setAttribute("genkit:isRoot", isRoot); + + // Add input as JSON + if (input != null) { + try { + span.setAttribute("genkit:input", objectMapper.writeValueAsString(input)); + } catch (JsonProcessingException e) { + span.setAttribute("genkit:input", input.toString()); + } + } + + // Add attributes from metadata + if (metadata.getType() != null) { + span.setAttribute("genkit:type", metadata.getType()); + } + if (metadata.getSubtype() != null) { + // Use genkit:metadata:subtype to match JS/Go SDK format + span.setAttribute("genkit:metadata:subtype", metadata.getSubtype()); + } + + // Add session and thread info from context for multi-turn conversation tracking + if (ctx.getSessionId() != null) { + span.setAttribute("genkit:sessionId", ctx.getSessionId()); + } + if (ctx.getThreadName() != null) { + span.setAttribute("genkit:threadName", ctx.getThreadName()); + } + + if (metadata.getAttributes() != null) { + for (Map.Entry entry : metadata.getAttributes().entrySet()) { + if (entry.getValue() instanceof String) { + span.setAttribute(entry.getKey(), (String) entry.getValue()); + } else if (entry.getValue() instanceof Long) { + span.setAttribute(entry.getKey(), (Long) entry.getValue()); + } else if (entry.getValue() instanceof Double) { + span.setAttribute(entry.getKey(), (Double) entry.getValue()); + } else if (entry.getValue() instanceof Boolean) { + span.setAttribute(entry.getKey(), (Boolean) entry.getValue()); + } else if (entry.getValue() != null) { + span.setAttribute(entry.getKey(), entry.getValue().toString()); + } + } + } + + io.opentelemetry.api.trace.SpanContext otelSpanContext = span.getSpanContext(); + SpanContext spanContext = new SpanContext(otelSpanContext.getTraceId(), otelSpanContext.getSpanId(), + ctx.getSpanContext() != null ? ctx.getSpanContext().getSpanId() : null); + + try (Scope scope = span.makeCurrent()) { + O result = fn.apply(spanContext, input); + + // Add output as JSON + if (result != null) { + try { + span.setAttribute("genkit:output", objectMapper.writeValueAsString(result)); + } catch (JsonProcessingException e) { + span.setAttribute("genkit:output", result.toString()); + } + } + + span.setAttribute("genkit:state", "success"); + span.setStatus(StatusCode.OK); + return result; + } catch (GenkitException e) { + span.setAttribute("genkit:state", "error"); + span.setStatus(StatusCode.ERROR, e.getMessage()); + span.recordException(e); + throw e; + } catch (RuntimeException e) { + // Re-throw RuntimeExceptions as-is (includes AgentHandoffException, + // ToolInterruptException, etc.) + span.setAttribute("genkit:state", "error"); + span.setStatus(StatusCode.ERROR, e.getMessage()); + span.recordException(e); + throw e; + } catch (Exception e) { + span.setAttribute("genkit:state", "error"); + span.setStatus(StatusCode.ERROR, e.getMessage()); + span.recordException(e); + throw new GenkitException("Span execution failed: " + e.getMessage(), e); + } finally { + span.end(); + } + } + + /** + * Builds an annotated path for the span. Format: + * /{name,t:type}/{name,t:type,s:subtype} + */ + private static String buildPath(String name, String parentPath, String type, String subtype) { + StringBuilder segment = new StringBuilder("{").append(name); + if (type != null && !type.isEmpty()) { + segment.append(",t:").append(type); + } + if (subtype != null && !subtype.isEmpty()) { + segment.append(",s:").append(subtype); + } + segment.append("}"); + + return (parentPath != null ? parentPath : "") + "/" + segment; + } + + /** + * Creates a new root span context. + * + * @return a new SpanContext with a unique trace ID + */ + public static SpanContext newRootSpanContext() { + String traceId = UUID.randomUUID().toString().replace("-", ""); + String spanId = UUID.randomUUID().toString().replace("-", "").substring(0, 16); + return new SpanContext(traceId, spanId, null); + } + + /** + * Creates a child span context from a parent. + * + * @param parent + * the parent span context + * @return a new child SpanContext + */ + public static SpanContext newChildSpanContext(SpanContext parent) { + String spanId = UUID.randomUUID().toString().replace("-", "").substring(0, 16); + return new SpanContext(parent.getTraceId(), spanId, parent.getSpanId()); + } + + /** + * Adds an event to the current span. + * + * @param name + * the event name + * @param attributes + * the event attributes + */ + public static void addEvent(String name, Map attributes) { + Span currentSpan = Span.current(); + if (currentSpan != null) { + io.opentelemetry.api.common.AttributesBuilder attrBuilder = io.opentelemetry.api.common.Attributes + .builder(); + if (attributes != null) { + for (Map.Entry entry : attributes.entrySet()) { + attrBuilder.put(entry.getKey(), entry.getValue()); + } + } + currentSpan.addEvent(name, attrBuilder.build()); + } + } + + /** + * Records an exception on the current span. + * + * @param exception + * the exception to record + */ + public static void recordException(Throwable exception) { + Span currentSpan = Span.current(); + if (currentSpan != null) { + currentSpan.recordException(exception); + } + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/ActionContextTest.java b/java/core/src/test/java/com/google/genkit/core/ActionContextTest.java new file mode 100644 index 0000000000..9d590f9476 --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/ActionContextTest.java @@ -0,0 +1,167 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for ActionContext. + */ +class ActionContextTest { + + private Registry registry; + + @BeforeEach + void setUp() { + registry = new DefaultRegistry(); + } + + @Test + void testConstructorWithAllParameters() { + String flowName = "testFlow"; + String spanPath = "/flow/testFlow"; + String sessionId = "session123"; + String threadName = "thread1"; + + ActionContext context = new ActionContext(null, flowName, spanPath, registry, sessionId, threadName); + + assertNull(context.getSpanContext()); + assertEquals(flowName, context.getFlowName()); + assertEquals(spanPath, context.getSpanPath()); + assertEquals(registry, context.getRegistry()); + assertEquals(sessionId, context.getSessionId()); + assertEquals(threadName, context.getThreadName()); + } + + @Test + void testConstructorWithFourParameters() { + String flowName = "testFlow"; + String spanPath = "/flow/testFlow"; + + ActionContext context = new ActionContext(null, flowName, spanPath, registry); + + assertNull(context.getSpanContext()); + assertEquals(flowName, context.getFlowName()); + assertEquals(spanPath, context.getSpanPath()); + assertEquals(registry, context.getRegistry()); + assertNull(context.getSessionId()); + assertNull(context.getThreadName()); + } + + @Test + void testConstructorWithThreeParameters() { + String flowName = "testFlow"; + + ActionContext context = new ActionContext(null, flowName, registry); + + assertNull(context.getSpanContext()); + assertEquals(flowName, context.getFlowName()); + assertNull(context.getSpanPath()); + assertEquals(registry, context.getRegistry()); + } + + @Test + void testConstructorWithRegistryOnly() { + ActionContext context = new ActionContext(registry); + + assertNull(context.getSpanContext()); + assertNull(context.getFlowName()); + assertNull(context.getSpanPath()); + assertEquals(registry, context.getRegistry()); + assertNull(context.getSessionId()); + assertNull(context.getThreadName()); + } + + @Test + void testWithFlowName() { + ActionContext context = new ActionContext(registry); + String newFlowName = "newFlow"; + + ActionContext newContext = context.withFlowName(newFlowName); + + assertEquals(newFlowName, newContext.getFlowName()); + assertEquals(registry, newContext.getRegistry()); + } + + @Test + void testWithSpanPath() { + ActionContext context = new ActionContext(registry); + String newSpanPath = "/flow/test/step1"; + + ActionContext newContext = context.withSpanPath(newSpanPath); + + assertEquals(newSpanPath, newContext.getSpanPath()); + assertEquals(registry, newContext.getRegistry()); + } + + @Test + void testWithSessionId() { + ActionContext context = new ActionContext(registry); + String sessionId = "session456"; + + ActionContext newContext = context.withSessionId(sessionId); + + assertEquals(sessionId, newContext.getSessionId()); + assertEquals(registry, newContext.getRegistry()); + } + + @Test + void testWithThreadName() { + ActionContext context = new ActionContext(registry); + String threadName = "worker-thread"; + + ActionContext newContext = context.withThreadName(threadName); + + assertEquals(threadName, newContext.getThreadName()); + assertEquals(registry, newContext.getRegistry()); + } + + @Test + void testContextImmutability() { + ActionContext original = new ActionContext(null, "flow1", "/path", registry, "session1", "thread1"); + + ActionContext modified = original.withFlowName("flow2"); + + // Original should be unchanged + assertEquals("flow1", original.getFlowName()); + assertEquals("flow2", modified.getFlowName()); + } + + @Test + void testChainedWith() { + ActionContext context = new ActionContext(registry).withFlowName("myFlow").withSpanPath("/flow/myFlow") + .withSessionId("session789").withThreadName("main-thread"); + + assertEquals("myFlow", context.getFlowName()); + assertEquals("/flow/myFlow", context.getSpanPath()); + assertEquals("session789", context.getSessionId()); + assertEquals("main-thread", context.getThreadName()); + } + + @Test + void testNullSpanContext() { + ActionContext context = new ActionContext(null, "testFlow", registry); + + assertNull(context.getSpanContext()); + assertEquals("testFlow", context.getFlowName()); + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/ActionTypeTest.java b/java/core/src/test/java/com/google/genkit/core/ActionTypeTest.java new file mode 100644 index 0000000000..95b9220df0 --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/ActionTypeTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +/** + * Unit tests for ActionType. + */ +class ActionTypeTest { + + @Test + void testFlowType() { + assertEquals("flow", ActionType.FLOW.toString()); + } + + @Test + void testModelType() { + assertEquals("model", ActionType.MODEL.toString()); + } + + @Test + void testEmbedderType() { + assertEquals("embedder", ActionType.EMBEDDER.toString()); + } + + @Test + void testRetrieverType() { + assertEquals("retriever", ActionType.RETRIEVER.toString()); + } + + @Test + void testIndexerType() { + assertEquals("indexer", ActionType.INDEXER.toString()); + } + + @Test + void testEvaluatorType() { + assertEquals("evaluator", ActionType.EVALUATOR.toString()); + } + + @Test + void testToolType() { + assertEquals("tool", ActionType.TOOL.toString()); + } + + @Test + void testPromptType() { + assertEquals("prompt", ActionType.PROMPT.toString()); + } + + @Test + void testExecutablePromptType() { + assertEquals("executable-prompt", ActionType.EXECUTABLE_PROMPT.toString()); + } + + @Test + void testUtilType() { + assertEquals("util", ActionType.UTIL.toString()); + } + + @Test + void testCustomType() { + assertEquals("custom", ActionType.CUSTOM.toString()); + } + + @Test + void testKeyFromName() { + String name = "myAction"; + String key = ActionType.FLOW.keyFromName(name); + + assertEquals("/flow/myAction", key); + } + + @Test + void testKeyFromNameModel() { + String name = "gpt-4"; + String key = ActionType.MODEL.keyFromName(name); + + assertEquals("/model/gpt-4", key); + } + + @Test + void testKeyFromNameEmbedder() { + String name = "text-embedding-ada"; + String key = ActionType.EMBEDDER.keyFromName(name); + + assertEquals("/embedder/text-embedding-ada", key); + } + + @Test + void testFromValue() { + assertEquals(ActionType.FLOW, ActionType.fromValue("flow")); + assertEquals(ActionType.MODEL, ActionType.fromValue("model")); + assertEquals(ActionType.EMBEDDER, ActionType.fromValue("embedder")); + assertEquals(ActionType.RETRIEVER, ActionType.fromValue("retriever")); + assertEquals(ActionType.INDEXER, ActionType.fromValue("indexer")); + assertEquals(ActionType.EVALUATOR, ActionType.fromValue("evaluator")); + assertEquals(ActionType.TOOL, ActionType.fromValue("tool")); + assertEquals(ActionType.PROMPT, ActionType.fromValue("prompt")); + assertEquals(ActionType.UTIL, ActionType.fromValue("util")); + } + + @Test + void testFromValueIsCaseSensitive() { + // The fromValue method is case-sensitive + assertThrows(IllegalArgumentException.class, () -> ActionType.fromValue("FLOW")); + assertThrows(IllegalArgumentException.class, () -> ActionType.fromValue("Flow")); + assertThrows(IllegalArgumentException.class, () -> ActionType.fromValue("MODEL")); + assertThrows(IllegalArgumentException.class, () -> ActionType.fromValue("Model")); + } + + @Test + void testFromValueUnknown() { + assertThrows(IllegalArgumentException.class, () -> ActionType.fromValue("unknown-type")); + } + + @ParameterizedTest + @EnumSource(ActionType.class) + void testAllTypesHaveStringValue(ActionType type) { + assertNotNull(type.toString()); + assertFalse(type.toString().isEmpty()); + } + + @ParameterizedTest + @EnumSource(ActionType.class) + void testKeyFromNameFormat(ActionType type) { + String key = type.keyFromName("testAction"); + + assertTrue(key.startsWith("/")); + assertTrue(key.contains("testAction")); + } + + @Test + void testAllTypesAreUnique() { + ActionType[] types = ActionType.values(); + + for (int i = 0; i < types.length; i++) { + for (int j = i + 1; j < types.length; j++) { + assertNotEquals(types[i].toString(), types[j].toString(), String + .format("ActionTypes %s and %s have same string value", types[i].name(), types[j].name())); + } + } + } + + @Test + void testEnumValues() { + ActionType[] types = ActionType.values(); + assertTrue(types.length > 0); + } + + @Test + void testEnumValueOf() { + assertEquals(ActionType.FLOW, ActionType.valueOf("FLOW")); + assertEquals(ActionType.MODEL, ActionType.valueOf("MODEL")); + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/DefaultRegistryTest.java b/java/core/src/test/java/com/google/genkit/core/DefaultRegistryTest.java new file mode 100644 index 0000000000..f6cf5f58ff --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/DefaultRegistryTest.java @@ -0,0 +1,279 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +/** + * Unit tests for DefaultRegistry. + */ +@ExtendWith(MockitoExtension.class) +class DefaultRegistryTest { + + private DefaultRegistry registry; + + @Mock + private Plugin mockPlugin; + + @Mock + private Action mockAction; + + @BeforeEach + void setUp() { + registry = new DefaultRegistry(); + } + + @Test + void testNewChild() { + Registry child = registry.newChild(); + + assertNotNull(child); + assertTrue(child.isChild()); + assertFalse(registry.isChild()); + } + + @Test + void testRegisterPlugin() { + registry.registerPlugin("test-plugin", mockPlugin); + + Plugin result = registry.lookupPlugin("test-plugin"); + assertNotNull(result); + assertEquals(mockPlugin, result); + } + + @Test + void testRegisterPluginDuplicate() { + registry.registerPlugin("test-plugin", mockPlugin); + + assertThrows(IllegalStateException.class, () -> registry.registerPlugin("test-plugin", mockPlugin)); + } + + @Test + void testRegisterAction() { + String actionKey = "/flow/test-action"; + + registry.registerAction(actionKey, mockAction); + + Action result = registry.lookupAction(actionKey); + assertNotNull(result); + assertEquals(mockAction, result); + } + + @Test + void testRegisterActionDuplicate() { + String actionKey = "/flow/test-action"; + registry.registerAction(actionKey, mockAction); + + assertThrows(IllegalStateException.class, () -> registry.registerAction(actionKey, mockAction)); + } + + @Test + void testRegisterValue() { + String valueName = "test-value"; + Object value = "test-object"; + + registry.registerValue(valueName, value); + + Object result = registry.lookupValue(valueName); + assertNotNull(result); + assertEquals(value, result); + } + + @Test + void testRegisterValueDuplicate() { + String valueName = "test-value"; + registry.registerValue(valueName, "value1"); + + assertThrows(IllegalStateException.class, () -> registry.registerValue(valueName, "value2")); + } + + @Test + void testRegisterSchema() { + String schemaName = "test-schema"; + Map schema = new HashMap<>(); + schema.put("type", "object"); + + registry.registerSchema(schemaName, schema); + + Map result = registry.lookupSchema(schemaName); + assertNotNull(result); + assertEquals(schema, result); + } + + @Test + void testRegisterSchemaDuplicate() { + String schemaName = "test-schema"; + Map schema = new HashMap<>(); + registry.registerSchema(schemaName, schema); + + assertThrows(IllegalStateException.class, () -> registry.registerSchema(schemaName, new HashMap<>())); + } + + @Test + void testChildRegistryLookupFromParent() { + String actionKey = "/flow/parent-action"; + registry.registerAction(actionKey, mockAction); + + Registry child = registry.newChild(); + + Action result = child.lookupAction(actionKey); + assertNotNull(result); + assertEquals(mockAction, result); + } + + @Test + void testChildRegistryOverridesParent() { + String actionKey = "/flow/test-action"; + + @SuppressWarnings("unchecked") + Action childAction = mock(Action.class); + + registry.registerAction(actionKey, mockAction); + Registry child = registry.newChild(); + child.registerAction(actionKey + "-child", childAction); + + // Child can access parent action + Action parentResult = child.lookupAction(actionKey); + assertNotNull(parentResult); + assertEquals(mockAction, parentResult); + + // Child has its own action + Action childResult = child.lookupAction(actionKey + "-child"); + assertNotNull(childResult); + assertEquals(childAction, childResult); + } + + @Test + void testLookupNonExistentPlugin() { + Plugin result = registry.lookupPlugin("non-existent"); + assertNull(result); + } + + @Test + void testLookupNonExistentAction() { + Action result = registry.lookupAction("/flow/non-existent"); + assertNull(result); + } + + @Test + void testLookupNonExistentValue() { + Object result = registry.lookupValue("non-existent"); + assertNull(result); + } + + @Test + void testLookupNonExistentSchema() { + Map result = registry.lookupSchema("non-existent"); + assertNull(result); + } + + @Test + void testListPlugins() { + @SuppressWarnings("unchecked") + Plugin plugin2 = mock(Plugin.class); + + registry.registerPlugin("plugin1", mockPlugin); + registry.registerPlugin("plugin2", plugin2); + + List plugins = registry.listPlugins(); + + assertEquals(2, plugins.size()); + assertTrue(plugins.contains(mockPlugin)); + assertTrue(plugins.contains(plugin2)); + } + + @Test + void testListActions() { + registry.registerAction("/flow/action1", mockAction); + + List> actions = registry.listActions(); + + assertEquals(1, actions.size()); + assertTrue(actions.contains(mockAction)); + } + + @Test + void testListValues() { + registry.registerValue("value1", "object1"); + registry.registerValue("value2", "object2"); + + Map values = registry.listValues(); + + assertEquals(2, values.size()); + assertEquals("object1", values.get("value1")); + assertEquals("object2", values.get("value2")); + } + + @Test + void testRegisterPartial() { + String partialName = "myPartial"; + String partialSource = "{{#each items}}{{this}}{{/each}}"; + + registry.registerPartial(partialName, partialSource); + + String result = registry.lookupPartial(partialName); + assertEquals(partialSource, result); + } + + @Test + void testRegisterHelper() { + String helperName = "myHelper"; + Object helper = new Object(); + + registry.registerHelper(helperName, helper); + + Object result = registry.lookupHelper(helperName); + assertEquals(helper, result); + } + + @Test + void testChildLookupPartialFromParent() { + String partialName = "parentPartial"; + String partialSource = "parent template"; + + registry.registerPartial(partialName, partialSource); + Registry child = registry.newChild(); + + String result = child.lookupPartial(partialName); + assertEquals(partialSource, result); + } + + @Test + void testChildLookupHelperFromParent() { + String helperName = "parentHelper"; + Object helper = new Object(); + + registry.registerHelper(helperName, helper); + Registry child = registry.newChild(); + + Object result = child.lookupHelper(helperName); + assertEquals(helper, result); + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/FlowTest.java b/java/core/src/test/java/com/google/genkit/core/FlowTest.java new file mode 100644 index 0000000000..70f04998b9 --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/FlowTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Flow. + */ +class FlowTest { + + private Registry registry; + + @BeforeEach + void setUp() { + registry = new DefaultRegistry(); + } + + @Test + void testDefineFlow() { + Flow flow = Flow.define(registry, "echoFlow", String.class, String.class, + (ctx, input) -> "Echo: " + input); + + assertNotNull(flow); + assertEquals("echoFlow", flow.getName()); + assertEquals(ActionType.FLOW, flow.getType()); + } + + @Test + void testFlowIsRegistered() { + Flow flow = Flow.define(registry, "testFlow", String.class, String.class, + (ctx, input) -> input.toUpperCase()); + + String key = ActionType.FLOW.keyFromName("testFlow"); + Action registered = registry.lookupAction(key); + + assertNotNull(registered); + // The registered action is the internal ActionDef, not the Flow wrapper + assertNotNull(registered.getDesc()); + } + + @Test + void testFlowRun() { + Flow flow = Flow.define(registry, "transformFlow", String.class, String.class, + (ctx, input) -> input.toLowerCase()); + + ActionContext ctx = new ActionContext(registry); + String result = flow.run(ctx, "HELLO WORLD"); + + assertEquals("hello world", result); + } + + @Test + void testFlowWithContext() { + Flow flow = Flow.define(registry, "contextFlow", String.class, String.class, + (ctx, input) -> { + // Flow should set the flow name in context + assertEquals("contextFlow", ctx.getFlowName()); + return input; + }); + + ActionContext ctx = new ActionContext(registry); + flow.run(ctx, "test"); + } + + @Test + void testFlowDesc() { + Flow flow = Flow.define(registry, "countFlow", String.class, Integer.class, + (ctx, input) -> input.length()); + + ActionDesc desc = flow.getDesc(); + + assertNotNull(desc); + assertEquals("countFlow", desc.getName()); + } + + @Test + void testDefineStreamingFlow() { + Flow flow = Flow.defineStreaming(registry, "streamingFlow", String.class, String.class, + (ctx, input, cb) -> { + if (cb != null) { + cb.accept("chunk1"); + cb.accept("chunk2"); + } + return "final result"; + }); + + assertNotNull(flow); + assertEquals("streamingFlow", flow.getName()); + assertEquals(ActionType.FLOW, flow.getType()); + } + + @Test + void testStreamingFlowWithCallback() { + StringBuilder chunks = new StringBuilder(); + + Flow flow = Flow.defineStreaming(registry, "chunkingFlow", String.class, String.class, + (ctx, input, cb) -> { + String[] words = input.split(" "); + for (String word : words) { + if (cb != null) { + cb.accept(word); + } + } + return input; + }); + + ActionContext ctx = new ActionContext(registry); + java.util.function.Consumer streamCallback = chunks::append; + String result = flow.run(ctx, "hello world", streamCallback); + + assertEquals("hello world", result); + assertEquals("helloworld", chunks.toString()); + } + + @Test + void testFlowRunThrowsGenkitException() { + Flow flow = Flow.define(registry, "errorFlow", String.class, String.class, + (ctx, input) -> { + throw new GenkitException("Intentional error"); + }); + + ActionContext ctx = new ActionContext(registry); + + assertThrows(GenkitException.class, () -> flow.run(ctx, "test")); + } + + @Test + void testMultipleFlowsInRegistry() { + Flow.define(registry, "flow1", String.class, String.class, (ctx, input) -> input); + Flow.define(registry, "flow2", String.class, Integer.class, (ctx, input) -> input.length()); + Flow.define(registry, "flow3", Integer.class, String.class, (ctx, input) -> String.valueOf(input)); + + assertNotNull(registry.lookupAction(ActionType.FLOW.keyFromName("flow1"))); + assertNotNull(registry.lookupAction(ActionType.FLOW.keyFromName("flow2"))); + assertNotNull(registry.lookupAction(ActionType.FLOW.keyFromName("flow3"))); + } + + @Test + void testFlowStepRunOutsideFlowThrowsException() { + ActionContext ctx = new ActionContext(registry); + + // Flow.run (step) should throw when not called from within a flow + java.util.function.Function stepFn = (v) -> "result"; + assertThrows(GenkitException.class, () -> Flow.run(ctx, "stepName", stepFn)); + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/GenkitExceptionTest.java b/java/core/src/test/java/com/google/genkit/core/GenkitExceptionTest.java new file mode 100644 index 0000000000..35150ad3a7 --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/GenkitExceptionTest.java @@ -0,0 +1,138 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for GenkitException. + */ +class GenkitExceptionTest { + + @Test + void testConstructorWithMessageOnly() { + String message = "Test error message"; + + GenkitException exception = new GenkitException(message); + + assertEquals(message, exception.getMessage()); + assertNull(exception.getCause()); + assertNull(exception.getErrorCode()); + assertNull(exception.getDetails()); + assertNull(exception.getTraceId()); + } + + @Test + void testConstructorWithMessageAndCause() { + String message = "Test error message"; + RuntimeException cause = new RuntimeException("Root cause"); + + GenkitException exception = new GenkitException(message, cause); + + assertEquals(message, exception.getMessage()); + assertEquals(cause, exception.getCause()); + assertNull(exception.getErrorCode()); + assertNull(exception.getDetails()); + assertNull(exception.getTraceId()); + } + + @Test + void testConstructorWithAllParameters() { + String message = "Test error message"; + RuntimeException cause = new RuntimeException("Root cause"); + String errorCode = "ERR_001"; + Map details = new HashMap<>(); + details.put("field", "value"); + String traceId = "trace-123"; + + GenkitException exception = new GenkitException(message, cause, errorCode, details, traceId); + + assertEquals(message, exception.getMessage()); + assertEquals(cause, exception.getCause()); + assertEquals(errorCode, exception.getErrorCode()); + assertEquals(details, exception.getDetails()); + assertEquals(traceId, exception.getTraceId()); + } + + @Test + void testIsRuntimeException() { + GenkitException exception = new GenkitException("Test"); + + assertTrue(exception instanceof RuntimeException); + } + + @Test + void testExceptionCanBeThrown() { + assertThrows(GenkitException.class, () -> { + throw new GenkitException("Test exception"); + }); + } + + @Test + void testExceptionChaining() { + Exception original = new Exception("Original error"); + GenkitException wrapped = new GenkitException("Wrapped error", original); + + Throwable cause = wrapped.getCause(); + assertNotNull(cause); + assertEquals("Original error", cause.getMessage()); + } + + @Test + void testNullCause() { + GenkitException exception = new GenkitException("Test", null, "ERR", null, null); + + assertNull(exception.getCause()); + } + + @Test + void testDetailsCanBeAnyObject() { + String stringDetails = "Simple string details"; + GenkitException exceptionWithString = new GenkitException("Test", null, "ERR", stringDetails, null); + assertEquals(stringDetails, exceptionWithString.getDetails()); + + Map mapDetails = Map.of("key", "value"); + GenkitException exceptionWithMap = new GenkitException("Test", null, "ERR", mapDetails, null); + assertEquals(mapDetails, exceptionWithMap.getDetails()); + } + + @Test + void testStackTrace() { + GenkitException exception = new GenkitException("Test"); + + StackTraceElement[] stackTrace = exception.getStackTrace(); + assertNotNull(stackTrace); + assertTrue(stackTrace.length > 0); + } + + @Test + void testToString() { + GenkitException exception = new GenkitException("Test message"); + + String string = exception.toString(); + assertNotNull(string); + assertTrue(string.contains("GenkitException")); + assertTrue(string.contains("Test message")); + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/JsonUtilsTest.java b/java/core/src/test/java/com/google/genkit/core/JsonUtilsTest.java new file mode 100644 index 0000000000..37c16b8a2f --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/JsonUtilsTest.java @@ -0,0 +1,213 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core; + +import static org.junit.jupiter.api.Assertions.*; + +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Unit tests for JsonUtils. + */ +class JsonUtilsTest { + + @Test + void testGetObjectMapper() { + ObjectMapper mapper = JsonUtils.getObjectMapper(); + assertNotNull(mapper); + // Same instance should be returned + assertSame(mapper, JsonUtils.getObjectMapper()); + } + + @Test + void testToJson() { + TestObject obj = new TestObject("test", 42); + + String json = JsonUtils.toJson(obj); + + assertNotNull(json); + assertTrue(json.contains("\"name\":\"test\"")); + assertTrue(json.contains("\"value\":42")); + } + + @Test + void testToJsonWithNull() { + String json = JsonUtils.toJson(null); + assertEquals("null", json); + } + + @Test + void testToJsonNode() { + TestObject obj = new TestObject("test", 42); + + JsonNode node = JsonUtils.toJsonNode(obj); + + assertNotNull(node); + assertEquals("test", node.get("name").asText()); + assertEquals(42, node.get("value").asInt()); + } + + @Test + void testFromJson() { + String json = "{\"name\":\"test\",\"value\":42}"; + + TestObject obj = JsonUtils.fromJson(json, TestObject.class); + + assertNotNull(obj); + assertEquals("test", obj.getName()); + assertEquals(42, obj.getValue()); + } + + @Test + void testFromJsonIgnoresUnknownProperties() { + String json = "{\"name\":\"test\",\"value\":42,\"unknown\":\"field\"}"; + + TestObject obj = JsonUtils.fromJson(json, TestObject.class); + + assertNotNull(obj); + assertEquals("test", obj.getName()); + assertEquals(42, obj.getValue()); + } + + @Test + void testFromJsonInvalidJson() { + String invalidJson = "{invalid}"; + + assertThrows(GenkitException.class, () -> JsonUtils.fromJson(invalidJson, TestObject.class)); + } + + @Test + void testFromJsonNode() { + JsonNode node = JsonUtils.getObjectMapper().createObjectNode().put("name", "test").put("value", 42); + + TestObject obj = JsonUtils.fromJsonNode(node, TestObject.class); + + assertNotNull(obj); + assertEquals("test", obj.getName()); + assertEquals(42, obj.getValue()); + } + + @Test + void testParseJson() { + String json = "{\"name\":\"test\",\"value\":42}"; + + JsonNode node = JsonUtils.parseJson(json); + + assertNotNull(node); + assertEquals("test", node.get("name").asText()); + assertEquals(42, node.get("value").asInt()); + } + + @Test + void testParseJsonInvalidJson() { + String invalidJson = "{invalid}"; + + assertThrows(GenkitException.class, () -> JsonUtils.parseJson(invalidJson)); + } + + @Test + void testToPrettyJson() { + TestObject obj = new TestObject("test", 42); + + String prettyJson = JsonUtils.toPrettyJson(obj); + + assertNotNull(prettyJson); + assertTrue(prettyJson.contains("\n")); // Should have newlines for pretty printing + assertTrue(prettyJson.contains("\"name\"")); + assertTrue(prettyJson.contains("\"value\"")); + } + + @Test + void testToJsonMap() { + Map map = new HashMap<>(); + map.put("key1", "value1"); + map.put("key2", 123); + + String json = JsonUtils.toJson(map); + + assertNotNull(json); + assertTrue(json.contains("\"key1\":\"value1\"")); + assertTrue(json.contains("\"key2\":123")); + } + + @Test + void testDateSerialization() { + Instant now = Instant.parse("2025-01-01T12:00:00Z"); + Map map = new HashMap<>(); + map.put("timestamp", now); + + String json = JsonUtils.toJson(map); + + assertNotNull(json); + // Should serialize as ISO-8601 string, not timestamps + assertTrue(json.contains("2025-01-01")); + } + + @Test + void testEmptyObjectSerialization() { + Object emptyObj = new Object() { + }; + + // Should not fail on empty beans + assertDoesNotThrow(() -> JsonUtils.toJson(emptyObj)); + } + + /** + * Test helper class. + */ + static class TestObject { + @JsonProperty("name") + private String name; + + @JsonProperty("value") + private int value; + + public TestObject() { + } + + public TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getValue() { + return value; + } + + public void setValue(int value) { + this.value = value; + } + } +} diff --git a/java/core/src/test/java/com/google/genkit/core/middleware/MiddlewareTest.java b/java/core/src/test/java/com/google/genkit/core/middleware/MiddlewareTest.java new file mode 100644 index 0000000000..0d3d145201 --- /dev/null +++ b/java/core/src/test/java/com/google/genkit/core/middleware/MiddlewareTest.java @@ -0,0 +1,384 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.core.middleware; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.DefaultRegistry; +import com.google.genkit.core.GenkitException; + +/** + * Tests for the Middleware classes. + */ +class MiddlewareTest { + + private ActionContext context; + + @BeforeEach + void setUp() { + context = new ActionContext(new DefaultRegistry()); + } + + @Test + void testSimpleMiddleware() { + Middleware middleware = (request, ctx, next) -> { + String modified = request.toUpperCase(); + return next.apply(modified, ctx); + }; + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(middleware); + + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: HELLO", result); + } + + @Test + void testMiddlewareChainOrder() { + List order = new ArrayList<>(); + + Middleware first = (request, ctx, next) -> { + order.add("first-before"); + String result = next.apply(request + "-first", ctx); + order.add("first-after"); + return result; + }; + + Middleware second = (request, ctx, next) -> { + order.add("second-before"); + String result = next.apply(request + "-second", ctx); + order.add("second-after"); + return result; + }; + + MiddlewareChain chain = MiddlewareChain.of(first, second); + + String result = chain.execute("input", context, (ctx, req) -> { + order.add("action"); + return req; + }); + + assertEquals("input-first-second", result); + assertEquals(List.of("first-before", "second-before", "action", "second-after", "first-after"), order); + } + + @Test + void testEmptyMiddlewareChain() { + MiddlewareChain chain = new MiddlewareChain<>(); + + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: hello", result); + } + + @Test + void testMiddlewareModifiesResponse() { + Middleware middleware = (request, ctx, next) -> { + String result = next.apply(request, ctx); + return result.toUpperCase(); + }; + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(middleware); + + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("RESULT: HELLO", result); + } + + @Test + void testLoggingMiddleware() { + Middleware loggingMiddleware = CommonMiddleware.logging("test"); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(loggingMiddleware); + + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: hello", result); + } + + @Test + void testRetryMiddleware() { + AtomicInteger attempts = new AtomicInteger(0); + + Middleware retryMiddleware = CommonMiddleware.retry(3, 10); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(retryMiddleware); + + String result = chain.execute("hello", context, (ctx, req) -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 3) { + throw new GenkitException("Simulated failure"); + } + return "Success after " + attempt + " attempts"; + }); + + assertEquals("Success after 3 attempts", result); + assertEquals(3, attempts.get()); + } + + @Test + void testRetryMiddlewareMaxRetriesExceeded() { + AtomicInteger attempts = new AtomicInteger(0); + + Middleware retryMiddleware = CommonMiddleware.retry(2, 10); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(retryMiddleware); + + GenkitException exception = assertThrows(GenkitException.class, () -> { + chain.execute("hello", context, (ctx, req) -> { + attempts.incrementAndGet(); + throw new GenkitException("Simulated failure"); + }); + }); + + assertTrue(exception.getMessage().contains("Simulated failure")); + assertEquals(3, attempts.get()); // Initial + 2 retries + } + + @Test + void testValidationMiddleware() { + Middleware validationMiddleware = CommonMiddleware.validate(request -> { + if (request == null || request.isEmpty()) { + throw new GenkitException("Request cannot be empty"); + } + }); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(validationMiddleware); + + // Valid request + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: hello", result); + + // Invalid request + GenkitException exception = assertThrows(GenkitException.class, () -> { + chain.execute("", context, (ctx, req) -> "Result: " + req); + }); + assertTrue(exception.getMessage().contains("empty")); + } + + @Test + void testTransformRequestMiddleware() { + Middleware transformMiddleware = CommonMiddleware.transformRequest(String::trim); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(transformMiddleware); + + String result = chain.execute(" hello ", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: hello", result); + } + + @Test + void testTransformResponseMiddleware() { + Middleware transformMiddleware = CommonMiddleware.transformResponse(String::toUpperCase); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(transformMiddleware); + + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("RESULT: HELLO", result); + } + + @Test + void testTimingMiddleware() { + List timings = new ArrayList<>(); + Middleware timingMiddleware = CommonMiddleware.timing(timings::add); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(timingMiddleware); + + String result = chain.execute("hello", context, (ctx, req) -> { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "Result: " + req; + }); + + assertEquals("Result: hello", result); + assertEquals(1, timings.size()); + assertTrue(timings.get(0) >= 50); + } + + @Test + void testCacheMiddleware() { + SimpleCache cache = new SimpleCache<>(); + AtomicInteger actionCalls = new AtomicInteger(0); + + Middleware cacheMiddleware = CommonMiddleware.cache(cache, request -> request); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(cacheMiddleware); + + // First call - should execute action + String result1 = chain.execute("hello", context, (ctx, req) -> { + actionCalls.incrementAndGet(); + return "Result: " + req; + }); + + // Second call - should use cache + String result2 = chain.execute("hello", context, (ctx, req) -> { + actionCalls.incrementAndGet(); + return "Result: " + req; + }); + + assertEquals("Result: hello", result1); + assertEquals("Result: hello", result2); + assertEquals(1, actionCalls.get()); // Action should only be called once + } + + @Test + void testConditionalMiddleware() { + Middleware upperCaseMiddleware = (request, ctx, next) -> { + return next.apply(request.toUpperCase(), ctx); + }; + + Middleware conditionalMiddleware = CommonMiddleware + .conditional((request, ctx) -> request.startsWith("transform:"), upperCaseMiddleware); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(conditionalMiddleware); + + // Should apply middleware + String result1 = chain.execute("transform:hello", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: TRANSFORM:HELLO", result1); + + // Should skip middleware + String result2 = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + assertEquals("Result: hello", result2); + } + + @Test + void testErrorHandlerMiddleware() { + Middleware errorHandler = CommonMiddleware + .errorHandler(e -> "Error handled: " + e.getMessage()); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(errorHandler); + + String result = chain.execute("hello", context, (ctx, req) -> { + throw new GenkitException("Something went wrong"); + }); + + assertEquals("Error handled: Something went wrong", result); + } + + @Test + void testRateLimitMiddleware() { + Middleware rateLimitMiddleware = CommonMiddleware.rateLimit(2, 1000); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(rateLimitMiddleware); + + // First two calls should succeed + String result1 = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + String result2 = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + + assertEquals("Result: hello", result1); + assertEquals("Result: hello", result2); + + // Third call should fail + GenkitException exception = assertThrows(GenkitException.class, () -> { + chain.execute("hello", context, (ctx, req) -> "Result: " + req); + }); + assertTrue(exception.getMessage().contains("Rate limit exceeded")); + } + + @Test + void testBeforeAfterMiddleware() { + List events = new ArrayList<>(); + + Middleware beforeAfterMiddleware = CommonMiddleware.beforeAfter( + (request, ctx) -> events.add("before: " + request), + (response, ctx) -> events.add("after: " + response)); + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(beforeAfterMiddleware); + + String result = chain.execute("hello", context, (ctx, req) -> "Result: " + req); + + assertEquals("Result: hello", result); + assertEquals(List.of("before: hello", "after: Result: hello"), events); + } + + @Test + void testMiddlewareChainCopy() { + Middleware middleware = (request, ctx, next) -> next.apply(request.toUpperCase(), ctx); + + MiddlewareChain original = new MiddlewareChain<>(); + original.use(middleware); + + MiddlewareChain copy = original.copy(); + + // Both should work + String result1 = original.execute("hello", context, (ctx, req) -> "Result: " + req); + String result2 = copy.execute("world", context, (ctx, req) -> "Result: " + req); + + assertEquals("Result: HELLO", result1); + assertEquals("Result: WORLD", result2); + } + + @Test + void testUseFirst() { + List order = new ArrayList<>(); + + Middleware first = (request, ctx, next) -> { + order.add("first"); + return next.apply(request, ctx); + }; + + Middleware second = (request, ctx, next) -> { + order.add("second"); + return next.apply(request, ctx); + }; + + MiddlewareChain chain = new MiddlewareChain<>(); + chain.use(first); + chain.useFirst(second); + + chain.execute("hello", context, (ctx, req) -> "Result: " + req); + + assertEquals(List.of("second", "first"), order); + } + + @Test + void testChainSize() { + MiddlewareChain chain = new MiddlewareChain<>(); + assertEquals(0, chain.size()); + assertTrue(chain.isEmpty()); + + chain.use((request, ctx, next) -> next.apply(request, ctx)); + assertEquals(1, chain.size()); + assertFalse(chain.isEmpty()); + + chain.clear(); + assertEquals(0, chain.size()); + assertTrue(chain.isEmpty()); + } +} diff --git a/java/genkit/pom.xml b/java/genkit/pom.xml new file mode 100644 index 0000000000..2570158a41 --- /dev/null +++ b/java/genkit/pom.xml @@ -0,0 +1,98 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../pom.xml + + + genkit + jar + Genkit + Genkit main module with Genkit class and reflection server + + + + com.google.genkit + genkit-core + ${project.version} + + + com.google.genkit + genkit-ai + ${project.version} + + + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + org.eclipse.jetty + jetty-server + + + + + org.slf4j + slf4j-api + + + ch.qos.logback + logback-classic + runtime + + + + + com.github.jknack + handlebars + + + + + org.junit.jupiter + junit-jupiter + test + + + org.mockito + mockito-core + test + + + diff --git a/java/genkit/src/main/java/com/google/genkit/Genkit.java b/java/genkit/src/main/java/com/google/genkit/Genkit.java new file mode 100644 index 0000000000..930a276ada --- /dev/null +++ b/java/genkit/src/main/java/com/google/genkit/Genkit.java @@ -0,0 +1,1829 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiFunction; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.ai.*; +import com.google.genkit.ai.evaluation.*; +import com.google.genkit.ai.session.*; +import com.google.genkit.core.*; +import com.google.genkit.core.middleware.Middleware; +import com.google.genkit.core.tracing.SpanMetadata; +import com.google.genkit.core.tracing.Tracer; +import com.google.genkit.prompt.DotPrompt; +import com.google.genkit.prompt.ExecutablePrompt; + +/** + * Genkit is the main entry point for the Genkit framework. + * + * It provides methods to define and run flows, configure AI models, and + * interact with the Genkit ecosystem. + */ +public class Genkit { + + private static final Logger logger = LoggerFactory.getLogger(Genkit.class); + + private final Registry registry; + private final List plugins; + private final GenkitOptions options; + private final Map> promptCache; + private final Map agentRegistry; + private ReflectionServer reflectionServer; + private EvaluationManager evaluationManager; + + /** + * Creates a new Genkit instance with default options. + */ + public Genkit() { + this(GenkitOptions.builder().build()); + } + + /** + * Creates a new Genkit instance with the given options. + * + * @param options + * the Genkit options + */ + public Genkit(GenkitOptions options) { + this.options = options; + this.registry = new DefaultRegistry(); + this.plugins = new ArrayList<>(); + this.promptCache = new ConcurrentHashMap<>(); + this.agentRegistry = new ConcurrentHashMap<>(); + } + + /** + * Creates a new Genkit builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a Genkit instance with the given plugins. + * + * @param plugins + * the plugins to use + * @return a configured Genkit instance + */ + public static Genkit create(Plugin... plugins) { + Builder builder = builder(); + for (Plugin plugin : plugins) { + builder.plugin(plugin); + } + return builder.build(); + } + + /** + * Initializes plugins. + */ + public void init() { + // Register utility actions + registerUtilityActions(); + + for (Plugin plugin : plugins) { + try { + List> actions = plugin.init(registry); + for (Action action : actions) { + String key = action.getType().keyFromName(action.getName()); + registry.registerAction(key, action); + } + logger.info("Initialized plugin: {}", plugin.getName()); + } catch (Exception e) { + logger.error("Failed to initialize plugin: {}", plugin.getName(), e); + throw new GenkitException("Failed to initialize plugin: " + plugin.getName(), e); + } + } + + // Start reflection server in dev mode + if (options.isDevMode()) { + startReflectionServer(); + } + } + + /** + * Registers utility actions like /util/generate. + */ + private void registerUtilityActions() { + GenerateAction.define(registry); + } + + /** + * Defines a flow. + * + * @param + * the input type + * @param + * the output type + * @param name + * the flow name + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param handler + * the flow handler + * @return the flow + */ + public Flow defineFlow(String name, Class inputClass, Class outputClass, + BiFunction handler) { + return Flow.define(registry, name, inputClass, outputClass, handler); + } + + /** + * Defines a flow with middleware. + * + * @param + * the input type + * @param + * the output type + * @param name + * the flow name + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param handler + * the flow handler + * @param middleware + * the middleware to apply + * @return the flow + */ + public Flow defineFlow(String name, Class inputClass, Class outputClass, + BiFunction handler, List> middleware) { + return Flow.define(registry, name, inputClass, outputClass, handler, middleware); + } + + /** + * Defines a flow with a simple handler. + * + * @param + * the input type + * @param + * the output type + * @param name + * the flow name + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param handler + * the flow handler + * @return the flow + */ + public Flow defineFlow(String name, Class inputClass, Class outputClass, + Function handler) { + return Flow.define(registry, name, inputClass, outputClass, (ctx, input) -> handler.apply(input)); + } + + /** + * Defines a flow with a simple handler and middleware. + * + * @param + * the input type + * @param + * the output type + * @param name + * the flow name + * @param inputClass + * the input class + * @param outputClass + * the output class + * @param handler + * the flow handler + * @param middleware + * the middleware to apply + * @return the flow + */ + public Flow defineFlow(String name, Class inputClass, Class outputClass, + Function handler, List> middleware) { + return Flow.define(registry, name, inputClass, outputClass, (ctx, input) -> handler.apply(input), middleware); + } + + /** + * Defines a tool. + * + * @param + * the input type + * @param + * the output type + * @param name + * the tool name + * @param description + * the tool description + * @param inputSchema + * the input JSON schema + * @param inputClass + * the input class + * @param handler + * the tool handler + * @return the tool + */ + public Tool defineTool(String name, String description, Map inputSchema, + Class inputClass, BiFunction handler) { + Tool tool = Tool.builder().name(name).description(description).inputSchema(inputSchema) + .inputClass(inputClass).handler(handler).build(); + tool.register(registry); + return tool; + } + + /** + * Loads a prompt by name from the prompts directory. + * + *

+ * This is similar to the JavaScript API: `ai.prompt('hello')`. The prompt is + * loaded from the configured promptDir (default: /prompts). The prompt is + * automatically registered as an action and cached for reuse. + * + *

+ * Example usage: + * + *

{@code
+   * ExecutablePrompt helloPrompt = genkit.prompt("hello", HelloInput.class);
+   * ModelResponse response = helloPrompt.generate(new HelloInput("John"));
+   * }
+ * + * @param + * the input type + * @param name + * the prompt name (without .prompt extension) + * @param inputClass + * the input class + * @return the executable prompt + * @throws GenkitException + * if the prompt cannot be loaded + */ + @SuppressWarnings("unchecked") + public ExecutablePrompt prompt(String name, Class inputClass) throws GenkitException { + return prompt(name, inputClass, null); + } + + /** + * Loads a prompt by name with an optional variant. + * + *

+ * Variants allow different versions of the same prompt to be tested. For + * example: "recipe" with variant "gemini25pro" loads + * "recipe.gemini25pro.prompt". + * + * @param + * the input type + * @param name + * the prompt name (without .prompt extension) + * @param inputClass + * the input class + * @param variant + * optional variant name (e.g., "gemini25pro") + * @return the executable prompt + * @throws GenkitException + * if the prompt cannot be loaded + */ + @SuppressWarnings("unchecked") + public ExecutablePrompt prompt(String name, Class inputClass, String variant) throws GenkitException { + // Build the cache key + String cacheKey = variant != null ? name + "." + variant : name; + + // Check cache first + DotPrompt dotPrompt = (DotPrompt) promptCache.get(cacheKey); + + if (dotPrompt == null) { + // Build the resource path + String promptDir = options.getPromptDir(); + String fileName = variant != null ? name + "." + variant + ".prompt" : name + ".prompt"; + String resourcePath = promptDir + "/" + fileName; + + // Load the prompt + dotPrompt = DotPrompt.loadFromResource(resourcePath); + promptCache.put(cacheKey, dotPrompt); + + // Auto-register as action + dotPrompt.register(registry, inputClass); + String registeredKey = ActionType.EXECUTABLE_PROMPT.keyFromName(dotPrompt.getName()); + logger.info("Loaded and registered prompt: {} as {} (variant: {})", name, registeredKey, variant); + } + + return new ExecutablePrompt<>(dotPrompt, registry, inputClass).withGenerateFunction(this::generate); + } + + /** + * Loads a prompt by name using a Map as input type. + * + *

+ * This is a convenience method when you don't want to define a specific input + * class. + * + * @param name + * the prompt name (without .prompt extension) + * @return the executable prompt with Map input + * @throws GenkitException + * if the prompt cannot be loaded + */ + @SuppressWarnings("unchecked") + public ExecutablePrompt> prompt(String name) throws GenkitException { + return prompt(name, (Class>) (Class) Map.class, null); + } + + /** + * Defines a prompt. + * + * @param + * the input type + * @param name + * the prompt name + * @param template + * the prompt template + * @param inputClass + * the input class + * @param renderer + * the prompt renderer + * @return the prompt + */ + public Prompt definePrompt(String name, String template, Class inputClass, + BiFunction renderer) { + Prompt prompt = Prompt.builder().name(name).template(template).inputClass(inputClass).renderer(renderer) + .build(); + prompt.register(registry); + return prompt; + } + + /** + * Registers a model. + * + * @param model + * the model to register + */ + public void registerModel(Model model) { + model.register(registry); + } + + /** + * Registers an embedder. + * + * @param embedder + * the embedder to register + */ + public void registerEmbedder(Embedder embedder) { + embedder.register(registry); + } + + /** + * Registers a retriever. + * + * @param retriever + * the retriever to register + */ + public void registerRetriever(Retriever retriever) { + retriever.register(registry); + } + + /** + * Registers an indexer. + * + * @param indexer + * the indexer to register + */ + public void registerIndexer(Indexer indexer) { + indexer.register(registry); + } + + /** + * Defines and registers a retriever. + * + *

+ * This is the preferred way to create retrievers as it automatically registers + * them with the Genkit registry. + * + *

+ * Example usage: + * + *

{@code
+   * Retriever myRetriever = genkit.defineRetriever("myStore/docs", (ctx, request) -> {
+   * 	// Find similar documents
+   * 	List docs = findSimilarDocs(request.getQuery());
+   * 	return new RetrieverResponse(docs);
+   * });
+   * }
+ * + * @param name + * the retriever name + * @param handler + * the retrieval function + * @return the registered retriever + */ + public Retriever defineRetriever(String name, + BiFunction handler) { + Retriever retriever = Retriever.builder().name(name).handler(handler).build(); + retriever.register(registry); + return retriever; + } + + /** + * Defines and registers an indexer. + * + *

+ * This is the preferred way to create indexers as it automatically registers + * them with the Genkit registry. + * + *

+ * Example usage: + * + *

{@code
+   * Indexer myIndexer = genkit.defineIndexer("myStore/docs", (ctx, request) -> {
+   * 	// Index the documents
+   * 	indexDocuments(request.getDocuments());
+   * 	return new IndexerResponse();
+   * });
+   * }
+ * + * @param name + * the indexer name + * @param handler + * the indexing function + * @return the registered indexer + */ + public Indexer defineIndexer(String name, BiFunction handler) { + Indexer indexer = Indexer.builder().name(name).handler(handler).build(); + indexer.register(registry); + return indexer; + } + + /** + * Gets a model by name. + * + * @param name + * the model name + * @return the model + */ + public Model getModel(String name) { + Action action = registry.lookupAction(ActionType.MODEL, name); + if (action == null) { + throw new GenkitException("Model not found: " + name); + } + return (Model) action; + } + + /** + * Gets an embedder by name. + * + * @param name + * the embedder name + * @return the embedder + */ + public Embedder getEmbedder(String name) { + Action action = registry.lookupAction(ActionType.EMBEDDER, name); + if (action == null) { + throw new GenkitException("Embedder not found: " + name); + } + return (Embedder) action; + } + + /** + * Gets a retriever by name. + * + * @param name + * the retriever name + * @return the retriever + */ + public Retriever getRetriever(String name) { + Action action = registry.lookupAction(ActionType.RETRIEVER, name); + if (action == null) { + throw new GenkitException("Retriever not found: " + name); + } + return (Retriever) action; + } + + /** + * Generates a model response using the specified options. + * + *

+ * This method handles tool execution automatically. If the model requests tool + * calls, this method will execute the tools, add the results to the + * conversation, and continue generation until the model produces a final + * response. + * + *

+ * When a tool throws a {@link ToolInterruptException}, the generation is halted + * and the response is returned with {@link FinishReason#INTERRUPTED}. The + * caller can then use {@link ResumeOptions} to continue generation after + * handling the interrupt. + * + *

+ * Example with interrupts: + * + *

{@code
+   * // First generation - may be interrupted
+   * ModelResponse response = genkit.generate(GenerateOptions.builder().model("googleai/gemini-pro")
+   * 		.prompt("Transfer $100 to account 12345").tools(List.of(confirmTransfer)).build());
+   *
+   * // Check if interrupted
+   * if (response.isInterrupted()) {
+   * 	Part interrupt = response.getInterrupts().get(0);
+   * 
+   * 	// Get user confirmation
+   * 	boolean confirmed = askUserForConfirmation();
+   * 
+   * 	// Resume with user response
+   * 	Part responseData = confirmTransfer.respond(interrupt, new ConfirmOutput(confirmed));
+   * 	ModelResponse resumed = genkit.generate(GenerateOptions.builder().model("googleai/gemini-pro")
+   * 			.messages(response.getMessages()).tools(List.of(confirmTransfer))
+   * 			.resume(ResumeOptions.builder().respond(responseData).build()).build());
+   * }
+   * }
+ * + * @param options + * the generate options + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(GenerateOptions options) throws GenkitException { + Model model = getModel(options.getModel()); + ModelRequest request = options.toModelRequest(); + ActionContext ctx = new ActionContext(registry); + + int maxTurns = options.getMaxTurns() != null ? options.getMaxTurns() : 5; + int turn = 0; + + // Handle resume option if provided + if (options.getResume() != null) { + request = handleResumeOption(request, options); + } + + while (turn < maxTurns) { + // Create span metadata for the model call + SpanMetadata modelSpanMetadata = SpanMetadata.builder().name(options.getModel()) + .type(ActionType.MODEL.getValue()).subtype("model").build(); + + String flowName = ctx.getFlowName(); + if (flowName != null) { + modelSpanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + final ModelRequest currentRequest = request; + ModelResponse response = Tracer.runInNewSpan(ctx, modelSpanMetadata, request, (spanCtx, req) -> { + return model.run(ctx.withSpanContext(spanCtx), currentRequest); + }); + + // Check if the model requested tool calls + List toolRequestParts = extractToolRequestParts(response); + if (toolRequestParts.isEmpty()) { + // No tool calls, return the response + return response; + } + + // Execute tools and handle interrupts + ToolExecutionResult toolResult = executeToolsWithInterruptHandling(ctx, toolRequestParts, + options.getTools()); + + // If there are interrupts, return immediately with interrupted response + if (!toolResult.getInterrupts().isEmpty()) { + return buildInterruptedResponse(response, toolResult); + } + + // Add the assistant message with tool requests + Message assistantMessage = response.getMessage(); + List updatedMessages = new java.util.ArrayList<>(request.getMessages()); + updatedMessages.add(assistantMessage); + + // Add tool response message + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResult.getResponses()); + updatedMessages.add(toolResponseMessage); + + // Update request with new messages for next turn + request = ModelRequest.builder().messages(updatedMessages).config(request.getConfig()) + .tools(request.getTools()).output(request.getOutput()).build(); + + turn++; + } + + throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + } + + /** + * Handles resume options by processing respond and restart directives. + */ + private ModelRequest handleResumeOption(ModelRequest request, GenerateOptions options) { + ResumeOptions resume = options.getResume(); + List messages = new java.util.ArrayList<>(request.getMessages()); + + if (messages.isEmpty()) { + throw new GenkitException("Cannot resume generation with no messages"); + } + + Message lastMessage = messages.get(messages.size() - 1); + if (lastMessage.getRole() != Role.MODEL) { + throw new GenkitException("Cannot resume unless the last message is from the model"); + } + + // Build tool response parts from resume options + List toolResponseParts = new java.util.ArrayList<>(); + + // Handle respond directives + if (resume.getRespond() != null) { + for (ToolResponse toolResponse : resume.getRespond()) { + Part responsePart = new Part(); + responsePart.setToolResponse(toolResponse); + Map metadata = new java.util.HashMap<>(); + metadata.put("interruptResponse", true); + responsePart.setMetadata(metadata); + toolResponseParts.add(responsePart); + } + } + + // Handle restart directives - execute the tools + if (resume.getRestart() != null) { + ActionContext ctx = new ActionContext(registry); + for (ToolRequest restartRequest : resume.getRestart()) { + Tool tool = findTool(restartRequest.getName(), options.getTools()); + if (tool == null) { + throw new GenkitException("Tool not found for restart: " + restartRequest.getName()); + } + + try { + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + Object result = typedTool.run(ctx, restartRequest.getInput()); + + Part responsePart = new Part(); + ToolResponse toolResponse = new ToolResponse(restartRequest.getRef(), restartRequest.getName(), + result); + responsePart.setToolResponse(toolResponse); + Map metadata = new java.util.HashMap<>(); + metadata.put("source", "restart"); + responsePart.setMetadata(metadata); + toolResponseParts.add(responsePart); + } catch (ToolInterruptException e) { + // Tool interrupted again during restart + throw new GenkitException( + "Tool '" + restartRequest.getName() + "' triggered an interrupt during restart. " + + "Re-interrupting during restart is not supported."); + } + } + } + + if (toolResponseParts.isEmpty()) { + throw new GenkitException("Resume options must contain either respond or restart directives"); + } + + // Add tool response message + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResponseParts); + Map toolMsgMetadata = new java.util.HashMap<>(); + toolMsgMetadata.put("resumed", true); + toolResponseMessage.setMetadata(toolMsgMetadata); + messages.add(toolResponseMessage); + + return ModelRequest.builder().messages(messages).config(request.getConfig()).tools(request.getTools()) + .output(request.getOutput()).build(); + } + + /** + * Builds an interrupted response from the model response and tool execution + * result. + */ + private ModelResponse buildInterruptedResponse(ModelResponse response, ToolExecutionResult toolResult) { + // Update the model message content with interrupt metadata + Message originalMessage = response.getMessage(); + List updatedContent = new java.util.ArrayList<>(); + + for (Part part : originalMessage.getContent()) { + if (part.getToolRequest() != null) { + ToolRequest toolRequest = part.getToolRequest(); + String key = toolRequest.getName() + "#" + (toolRequest.getRef() != null ? toolRequest.getRef() : ""); + + // Check if this tool request was interrupted + Part interruptPart = toolResult.getInterruptMap().get(key); + if (interruptPart != null) { + updatedContent.add(interruptPart); + } else { + // Check for pending output + Object pendingOutput = toolResult.getPendingOutputMap().get(key); + if (pendingOutput != null) { + Part pendingPart = new Part(); + pendingPart.setToolRequest(toolRequest); + Map metadata = part.getMetadata() != null + ? new java.util.HashMap<>(part.getMetadata()) + : new java.util.HashMap<>(); + metadata.put("pendingOutput", pendingOutput); + pendingPart.setMetadata(metadata); + updatedContent.add(pendingPart); + } else { + updatedContent.add(part); + } + } + } else { + updatedContent.add(part); + } + } + + Message updatedMessage = new Message(); + updatedMessage.setRole(originalMessage.getRole()); + updatedMessage.setContent(updatedContent); + updatedMessage.setMetadata(originalMessage.getMetadata()); + + // Create candidate with updated message + Candidate updatedCandidate = new Candidate(); + updatedCandidate.setMessage(updatedMessage); + updatedCandidate.setFinishReason(FinishReason.INTERRUPTED); + + return ModelResponse.builder().candidates(List.of(updatedCandidate)).usage(response.getUsage()) + .request(response.getRequest()).custom(response.getCustom()).latencyMs(response.getLatencyMs()) + .finishReason(FinishReason.INTERRUPTED).finishMessage("One or more tool calls resulted in interrupts.") + .interrupts(toolResult.getInterrupts()).build(); + } + + /** + * Extracts tool request parts from a model response. + */ + private List extractToolRequestParts(ModelResponse response) { + List parts = new java.util.ArrayList<>(); + if (response.getCandidates() != null) { + for (Candidate candidate : response.getCandidates()) { + if (candidate.getMessage() != null && candidate.getMessage().getContent() != null) { + for (Part part : candidate.getMessage().getContent()) { + if (part.getToolRequest() != null) { + parts.add(part); + } + } + } + } + } + return parts; + } + + /** + * Extracts tool requests from a model response. + */ + private List extractToolRequests(ModelResponse response) { + List requests = new java.util.ArrayList<>(); + if (response.getCandidates() != null) { + for (Candidate candidate : response.getCandidates()) { + if (candidate.getMessage() != null && candidate.getMessage().getContent() != null) { + for (Part part : candidate.getMessage().getContent()) { + if (part.getToolRequest() != null) { + requests.add(part.getToolRequest()); + } + } + } + } + } + return requests; + } + + /** + * Result of tool execution with interrupt handling. + */ + private static class ToolExecutionResult { + private final List responses; + private final List interrupts; + private final Map interruptMap; + private final Map pendingOutputMap; + + ToolExecutionResult(List responses, List interrupts, Map interruptMap, + Map pendingOutputMap) { + this.responses = responses; + this.interrupts = interrupts; + this.interruptMap = interruptMap; + this.pendingOutputMap = pendingOutputMap; + } + + List getResponses() { + return responses; + } + List getInterrupts() { + return interrupts; + } + Map getInterruptMap() { + return interruptMap; + } + Map getPendingOutputMap() { + return pendingOutputMap; + } + } + + /** + * Executes tools with interrupt handling. + */ + private ToolExecutionResult executeToolsWithInterruptHandling(ActionContext ctx, List toolRequestParts, + List> tools) { + + List responseParts = new java.util.ArrayList<>(); + List interrupts = new java.util.ArrayList<>(); + Map interruptMap = new java.util.HashMap<>(); + Map pendingOutputMap = new java.util.HashMap<>(); + + for (Part toolRequestPart : toolRequestParts) { + ToolRequest toolRequest = toolRequestPart.getToolRequest(); + String toolName = toolRequest.getName(); + String key = toolName + "#" + (toolRequest.getRef() != null ? toolRequest.getRef() : ""); + + // Find the tool + Tool tool = findTool(toolName, tools); + if (tool == null) { + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Map.of("error", "Tool not found: " + toolName)); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + continue; + } + + // Check if this is an interrupt tool (has "interrupt" metadata marker) + boolean isInterruptTool = tool.getMetadata() != null + && Boolean.TRUE.equals(tool.getMetadata().get("interrupt")); + + try { + // Convert input to the expected type + Object toolInput = toolRequest.getInput(); + Class inputClass = tool.getInputClass(); + if (inputClass != null && toolInput != null && !inputClass.isInstance(toolInput)) { + toolInput = JsonUtils.convert(toolInput, inputClass); + } + + // Execute the tool + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + Object result = typedTool.run(ctx, toolInput); + + // Create tool response part + Part responsePart = new Part(); + ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolName, result); + responsePart.setToolResponse(toolResponse); + responseParts.add(responsePart); + + // Store pending output in case other tools interrupt + pendingOutputMap.put(key, result); + + logger.debug("Executed tool '{}' successfully", toolName); + + } catch (ToolInterruptException e) { + // Tool interrupted - store the interrupt + Map interruptMetadata = e.getMetadata(); + + Part interruptPart = new Part(); + interruptPart.setToolRequest(toolRequest); + Map metadata = toolRequestPart.getMetadata() != null + ? new java.util.HashMap<>(toolRequestPart.getMetadata()) + : new java.util.HashMap<>(); + metadata.put("interrupt", + interruptMetadata != null && !interruptMetadata.isEmpty() ? interruptMetadata : true); + interruptPart.setMetadata(metadata); + + interrupts.add(interruptPart); + interruptMap.put(key, interruptPart); + + logger.debug("Tool '{}' triggered interrupt", toolName); + + } catch (Exception e) { + logger.error("Tool execution failed for '{}': {}", toolName, e.getMessage()); + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Map.of("error", "Tool execution failed: " + e.getMessage())); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + } + } + + return new ToolExecutionResult(responseParts, interrupts, interruptMap, pendingOutputMap); + } + + /** + * Executes tools and returns the response parts. + */ + private List executeTools(ActionContext ctx, List toolRequests, List> tools) { + List responseParts = new java.util.ArrayList<>(); + + for (ToolRequest toolRequest : toolRequests) { + String toolName = toolRequest.getName(); + Object toolInput = toolRequest.getInput(); + + // Find the tool + Tool tool = findTool(toolName, tools); + if (tool == null) { + // Tool not found, create an error response + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Map.of("error", "Tool not found: " + toolName)); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + continue; + } + + try { + // Execute the tool + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + Object result = typedTool.run(ctx, toolInput); + + // Create tool response part + Part responsePart = new Part(); + ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolName, result); + responsePart.setToolResponse(toolResponse); + responseParts.add(responsePart); + + logger.debug("Executed tool '{}' successfully", toolName); + } catch (Exception e) { + logger.error("Tool execution failed for '{}': {}", toolName, e.getMessage()); + Part errorPart = new Part(); + ToolResponse errorResponse = new ToolResponse(toolRequest.getRef(), toolName, + Map.of("error", "Tool execution failed: " + e.getMessage())); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + } + } + + return responseParts; + } + + /** + * Finds a tool by name. + */ + private Tool findTool(String toolName, List> tools) { + if (tools != null) { + for (Tool tool : tools) { + if (tool.getName().equals(toolName)) { + return tool; + } + } + } + + // Also try to find in registry + Action action = registry.lookupAction(ActionType.TOOL, toolName); + if (action instanceof Tool) { + return (Tool) action; + } + + return null; + } + + /** + * Generates a streaming model response using the specified options. + * + *

+ * This method invokes the model with streaming enabled, calling the provided + * callback for each chunk of the response as it arrives. This is useful for + * displaying responses incrementally to users. + * + *

+ * Example usage: + * + *

{@code
+   * StringBuilder result = new StringBuilder();
+   * ModelResponse response = genkit.generateStream(
+   * 		GenerateOptions.builder().model("openai/gpt-4o").prompt("Tell me a story").build(), chunk -> {
+   * 			System.out.print(chunk.getText());
+   * 			result.append(chunk.getText());
+   * 		});
+   * }
+ * + * @param options + * the generate options + * @param streamCallback + * callback invoked for each response chunk + * @return the final aggregated model response + * @throws GenkitException + * if generation fails or model doesn't support streaming + */ + public ModelResponse generateStream(GenerateOptions options, + java.util.function.Consumer streamCallback) throws GenkitException { + Model model = getModel(options.getModel()); + if (!model.supportsStreaming()) { + throw new GenkitException("Model " + options.getModel() + " does not support streaming"); + } + ModelRequest request = options.toModelRequest(); + ActionContext ctx = new ActionContext(registry); + + int maxTurns = options.getMaxTurns() != null ? options.getMaxTurns() : 5; + int turn = 0; + + while (turn < maxTurns) { + // Create span metadata for the model call + SpanMetadata modelSpanMetadata = SpanMetadata.builder().name(options.getModel()) + .type(ActionType.MODEL.getValue()).subtype("model").build(); + + String flowName = ctx.getFlowName(); + if (flowName != null) { + modelSpanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + final ModelRequest currentRequest = request; + ModelResponse response = Tracer.runInNewSpan(ctx, modelSpanMetadata, request, (spanCtx, req) -> { + return model.run(ctx.withSpanContext(spanCtx), currentRequest, streamCallback); + }); + + // Check if the model requested tool calls + List toolRequests = extractToolRequests(response); + if (toolRequests.isEmpty()) { + // No tool calls, return the response + return response; + } + + // Execute tools and build tool response messages + List toolResponseParts = executeTools(ctx, toolRequests, options.getTools()); + + // Add the assistant message with tool requests + Message assistantMessage = response.getMessage(); + List updatedMessages = new java.util.ArrayList<>(request.getMessages()); + updatedMessages.add(assistantMessage); + + // Add tool response message + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResponseParts); + updatedMessages.add(toolResponseMessage); + + // Update request with new messages for next turn + request = ModelRequest.builder().messages(updatedMessages).config(request.getConfig()) + .tools(request.getTools()).output(request.getOutput()).build(); + + turn++; + } + + throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + } + + /** + * Generates a model response with a simple prompt. + * + * @param modelName + * the model name + * @param prompt + * the prompt text + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(String modelName, String prompt) throws GenkitException { + return generate(GenerateOptions.builder().model(modelName).prompt(prompt).build()); + } + + /** + * Embeds documents using the specified embedder. + * + * @param embedderName + * the embedder name + * @param documents + * the documents to embed + * @return the embed response + * @throws GenkitException + * if embedding fails + */ + public EmbedResponse embed(String embedderName, List documents) throws GenkitException { + Embedder embedder = getEmbedder(embedderName); + EmbedRequest request = new EmbedRequest(documents); + ActionContext ctx = new ActionContext(registry); + return embedder.run(ctx, request); + } + + /** + * Retrieves documents using the specified retriever. + * + *

+ * This is the primary method for retrieval in RAG workflows. The returned + * documents can be passed directly to {@code generate()} via the + * {@code .docs()} option. + * + *

+ * Example usage: + * + *

{@code
+   * // Retrieve relevant documents
+   * List docs = genkit.retrieve("myStore/docs", "What is the capital of France?");
+   * 
+   * // Use documents in generation
+   * ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini")
+   * 		.prompt("Answer the question based on context").docs(docs).build());
+   * }
+ * + * @param retrieverName + * the retriever name + * @param query + * the query text + * @return the list of retrieved documents + * @throws GenkitException + * if retrieval fails + */ + public List retrieve(String retrieverName, String query) throws GenkitException { + Retriever retriever = getRetriever(retrieverName); + RetrieverRequest request = RetrieverRequest.fromText(query); + ActionContext ctx = new ActionContext(registry); + RetrieverResponse response = retriever.run(ctx, request); + return response.getDocuments(); + } + + /** + * Retrieves documents using the specified retriever with options. + * + *

+ * Example usage: + * + *

{@code
+   * List docs = genkit.retrieve("myStore/docs", "query", RetrieverParams.builder().k(5).build());
+   * }
+ * + * @param retrieverName + * the retriever name + * @param query + * the query text + * @param options + * retrieval options (e.g., k for number of results) + * @return the list of retrieved documents + * @throws GenkitException + * if retrieval fails + */ + public List retrieve(String retrieverName, String query, RetrieverRequest.RetrieverOptions options) + throws GenkitException { + Retriever retriever = getRetriever(retrieverName); + RetrieverRequest request = RetrieverRequest.fromText(query); + request.setOptions(options); + ActionContext ctx = new ActionContext(registry); + RetrieverResponse response = retriever.run(ctx, request); + return response.getDocuments(); + } + + /** + * Retrieves documents using a Document as the query. + * + * @param retrieverName + * the retriever name + * @param query + * the query document + * @return the list of retrieved documents + * @throws GenkitException + * if retrieval fails + */ + public List retrieve(String retrieverName, Document query) throws GenkitException { + Retriever retriever = getRetriever(retrieverName); + RetrieverRequest request = new RetrieverRequest(query); + ActionContext ctx = new ActionContext(registry); + RetrieverResponse response = retriever.run(ctx, request); + return response.getDocuments(); + } + + /** + * Indexes documents using the specified indexer. + * + *

+ * Example usage: + * + *

{@code
+   * List docs = List.of(Document.fromText("Paris is the capital of France."),
+   * 		Document.fromText("Berlin is the capital of Germany."));
+   * genkit.index("myStore/docs", docs);
+   * }
+ * + * @param indexerName + * the indexer name + * @param documents + * the documents to index + * @throws GenkitException + * if indexing fails + */ + public void index(String indexerName, List documents) throws GenkitException { + Indexer indexer = getIndexer(indexerName); + IndexerRequest request = new IndexerRequest(documents); + ActionContext ctx = new ActionContext(registry); + indexer.run(ctx, request); + } + + /** + * Gets an indexer by name. + * + * @param name + * the indexer name + * @return the indexer + */ + public Indexer getIndexer(String name) { + Action action = registry.lookupAction(ActionType.INDEXER, name); + if (action == null) { + throw new GenkitException("Indexer not found: " + name); + } + return (Indexer) action; + } + + /** + * Runs a flow by name. + * + * @param + * the input type + * @param + * the output type + * @param flowName + * the flow name + * @param input + * the flow input + * @return the flow output + * @throws GenkitException + * if execution fails + */ + @SuppressWarnings("unchecked") + public O runFlow(String flowName, I input) throws GenkitException { + Action action = registry.lookupAction(ActionType.FLOW, flowName); + if (action == null) { + throw new GenkitException("Flow not found: " + flowName); + } + Flow flow = (Flow) action; + ActionContext ctx = new ActionContext(registry); + return flow.run(ctx, input); + } + + /** + * Gets the registry. + * + * @return the registry + */ + public Registry getRegistry() { + return registry; + } + + /** + * Gets the options. + * + * @return the options + */ + public GenkitOptions getOptions() { + return options; + } + + /** + * Gets the registered plugins. + * + * @return the plugins + */ + public List getPlugins() { + return plugins; + } + + /** + * Starts the reflection server for dev tools integration. + */ + private void startReflectionServer() { + try { + int port = options.getReflectionPort(); + reflectionServer = new ReflectionServer(registry, port); + reflectionServer.start(); + logger.info("Reflection server started on port {}", port); + + // Write runtime file with matching runtime ID + RuntimeFileWriter.write(port, reflectionServer.getRuntimeId()); + } catch (Exception e) { + logger.error("Failed to start reflection server", e); + throw new GenkitException("Failed to start reflection server", e); + } + } + + /** + * Stops the Genkit instance and cleans up resources. + */ + public void stop() { + if (reflectionServer != null) { + try { + reflectionServer.stop(); + RuntimeFileWriter.cleanup(); + } catch (Exception e) { + logger.warn("Error stopping reflection server", e); + } + } + } + + // ========================================================================= + // Session Methods + // ========================================================================= + + /** + * Creates a new session with default options. + * + *

+ * Sessions provide stateful multi-turn conversations with automatic history + * persistence. Each session can have multiple named conversation threads. + * + *

+ * Example usage: + * + *

{@code
+   * Session session = genkit.createSession();
+   * Chat chat = session
+   * 		.chat(ChatOptions.builder().model("openai/gpt-4o").system("You are a helpful assistant.").build());
+   * chat.send("Hello!");
+   * }
+ * + * @param + * the session state type + * @return a new session + */ + public Session createSession() { + return Session.create(registry, SessionOptions.builder().build(), agentRegistry); + } + + /** + * Creates a new session with the given options. + * + *

+ * Example usage: + * + *

{@code
+   * // With custom state
+   * Session session = genkit
+   * 		.createSession(SessionOptions.builder().initialState(new MyState("John")).build());
+   *
+   * // With custom store and session ID
+   * Session session = genkit.createSession(SessionOptions.builder()
+   * 		.store(new RedisSessionStore<>()).sessionId("my-session-123").initialState(new MyState()).build());
+   * }
+ * + * @param + * the session state type + * @param options + * the session options + * @return a new session + */ + public Session createSession(SessionOptions options) { + return Session.create(registry, options, agentRegistry); + } + + /** + * Loads an existing session from a store. + * + *

+ * Example usage: + * + *

{@code
+   * CompletableFuture> sessionFuture = genkit.loadSession("session-123",
+   * 		SessionOptions.builder().store(mySessionStore).build());
+   * Session session = sessionFuture.get();
+   * if (session != null) {
+   * 	Chat chat = session.chat();
+   * 	// Continue conversation...
+   * }
+   * }
+ * + * @param + * the session state type + * @param sessionId + * the session ID to load + * @param options + * the session options (must include store) + * @return a CompletableFuture containing the session, or null if not found + */ + public CompletableFuture> loadSession(String sessionId, SessionOptions options) { + return Session.load(registry, sessionId, options, agentRegistry); + } + + /** + * Creates a simple chat without session persistence. + * + *

+ * This is a convenience method for quick interactions without full session + * management. Use {@link #createSession()} for persistent multi-turn + * conversations. + * + *

+ * Example usage: + * + *

{@code
+   * Chat chat = genkit
+   * 		.chat(ChatOptions.builder().model("openai/gpt-4o").system("You are a helpful assistant.").build());
+   * ModelResponse response = chat.send("Hello!");
+   * }
+ * + * @param + * the state type (usually Void for simple chats) + * @param options + * the chat options + * @return a new chat instance + */ + public Chat chat(ChatOptions options) { + Session session = createSession(); + return session.chat(options); + } + + // ========================================================================= + // Agent and Interrupt Methods + // ========================================================================= + + /** + * Defines an agent that can be used as a tool in multi-agent systems. + * + *

+ * Agents are specialized conversational components that can be delegated to by + * other agents. When an agent is called as a tool, it takes over the + * conversation with its own system prompt, model, and tools. + * + *

+ * Example usage: + * + *

{@code
+   * // Define a specialized agent
+   * Agent reservationAgent = genkit.defineAgent(AgentConfig.builder().name("reservationAgent")
+   * 		.description("Handles restaurant reservations").system("You are a reservation specialist...")
+   * 		.model("openai/gpt-4o").tools(List.of(reservationTool, lookupTool)).build());
+   *
+   * // Use in a parent agent
+   * Agent triageAgent = genkit.defineAgent(
+   * 		AgentConfig.builder().name("triageAgent").description("Routes customer requests to specialists")
+   * 				.system("You route customer requests to the appropriate specialist")
+   * 				.agents(List.of(reservationAgent.getConfig())).build());
+   *
+   * // Start chat with triage agent
+   * Chat chat = genkit.chat(ChatOptions.builder().model("openai/gpt-4o").system(triageAgent.getSystem())
+   * 		.tools(triageAgent.getAllTools(agentRegistry)).build());
+   * }
+ * + * @param config + * the agent configuration + * @return the created agent + */ + public Agent defineAgent(AgentConfig config) { + Agent agent = new Agent(config); + // Register the agent as a tool + registry.registerAction(ActionType.TOOL, agent.asTool()); + // Register in agent registry for getAllTools lookup + agentRegistry.put(config.getName(), agent); + return agent; + } + + /** + * Gets an agent by name. + * + * @param name + * the agent name + * @return the agent, or null if not found + */ + public Agent getAgent(String name) { + return agentRegistry.get(name); + } + + /** + * Gets the agent registry. + * + *

+ * This returns an unmodifiable view of all registered agents. + * + * @return the agent registry + */ + public Map getAgentRegistry() { + return java.util.Collections.unmodifiableMap(agentRegistry); + } + + /** + * Gets all tools for an agent, including sub-agent tools. + * + *

+ * This is a convenience method that collects all tools from an agent, including + * tools from any sub-agents defined in its configuration. + * + * @param agent + * the agent + * @return the list of all tools + */ + public List> getAllToolsForAgent(Agent agent) { + return agent.getAllTools(agentRegistry); + } + + /** + * Defines an interrupt tool for human-in-the-loop interactions. + * + *

+ * Interrupts allow tools to pause generation and request user input. When a + * tool throws a {@link ToolInterruptException}, the chat returns early with the + * interrupt information, allowing the application to collect user input and + * resume. + * + *

+ * Example usage: + * + *

{@code
+   * // Define an interrupt for confirming actions
+   * Tool confirmInterrupt = genkit.defineInterrupt(InterruptConfig
+   * 		.builder().name("confirm").description("Asks user to confirm an action")
+   * 		.inputType(ConfirmInput.class).outputType(ConfirmOutput.class).build());
+   *
+   * // Use in a chat with tools
+   * Chat chat = genkit.chat(
+   * 		ChatOptions.builder().model("openai/gpt-4o").tools(List.of(someActionTool, confirmInterrupt)).build());
+   *
+   * ModelResponse response = chat.send("Book a table for 4");
+   *
+   * // Check for interrupts
+   * if (chat.hasPendingInterrupts()) {
+   * 	List interrupts = chat.getPendingInterrupts();
+   * 	// Show UI to user, collect response
+   * 	ConfirmOutput userResponse = getUserConfirmation(interrupts.get(0));
+   * 
+   * 	// Resume with user response
+   * 	response = chat.send("",
+   * 			SendOptions.builder().resumeOptions(
+   * 					ResumeOptions.builder().respond(List.of(interrupts.get(0).respond(userResponse))).build())
+   * 					.build());
+   * }
+   * }
+ * + * @param + * the interrupt input type + * @param + * the interrupt output type (user response) + * @param config + * the interrupt configuration + * @return the interrupt as a tool + */ + public Tool defineInterrupt(InterruptConfig config) { + Map inputSchema = config.getInputSchema(); + if (inputSchema == null) { + inputSchema = new java.util.HashMap<>(); + inputSchema.put("type", "object"); + } + + Map outputSchema = config.getOutputSchema(); + if (outputSchema == null) { + outputSchema = new java.util.HashMap<>(); + outputSchema.put("type", "object"); + } + + Tool interruptTool = new Tool<>(config.getName(), + config.getDescription() != null ? config.getDescription() : "Interrupt: " + config.getName(), + inputSchema, outputSchema, config.getInputType(), (ctx, input) -> { + // Build metadata from input - create a mutable copy since user may return + // immutable map + Map metadata = new java.util.HashMap<>(); + if (config.getRequestMetadata() != null) { + metadata.putAll(config.getRequestMetadata().apply(input)); + } + metadata.put("interrupt", true); + metadata.put("interruptName", config.getName()); + metadata.put("input", input); + + // Throw interrupt exception - this never returns + throw new ToolInterruptException(metadata); + }); + + // Register the interrupt tool + registry.registerAction(ActionType.TOOL, interruptTool); + return interruptTool; + } + + /** + * Gets the current session from the context. + * + *

+ * This method can be called from within tool execution to access the current + * session state. It uses a thread-local context that is set during chat + * execution. + * + *

+ * Example usage: + * + *

{@code
+   * Tool myTool = genkit.defineTool("myTool", Input.class, Output.class, (ctx, input) -> {
+   * 	Session session = genkit.currentSession();
+   * 	if (session != null) {
+   * 		Object state = session.getState();
+   * 		// Use session state...
+   * 	}
+   * 	return new Output();
+   * });
+   * }
+ * + * @param + * the session state type + * @return the current session, or null if not in a session context + */ + @SuppressWarnings("unchecked") + public Session currentSession() { + return (Session) SessionContext.currentSession(); + } + + // ========================================================================= + // Evaluation Methods + // ========================================================================= + + /** + * Defines a new evaluator and registers it with the registry. + * + *

+ * Evaluators assess the quality of AI outputs. They can be used to: + *

    + *
  • Score outputs based on various criteria (accuracy, relevance, etc.)
  • + *
  • Compare outputs against reference data
  • + *
  • Run automated quality checks in CI/CD pipelines
  • + *
+ * + *

+ * Example usage: + * + *

{@code
+   * genkit.defineEvaluator("myEvaluator", "My Evaluator", "Checks output quality", (dataPoint, options) -> {
+   * 	// Evaluate the output
+   * 	double score = calculateScore(dataPoint.getOutput());
+   * 	return EvalResponse.builder().testCaseId(dataPoint.getTestCaseId())
+   * 			.evaluation(Score.builder().score(score).build()).build();
+   * });
+   * }
+ * + * @param + * the options type + * @param name + * the evaluator name + * @param displayName + * the display name shown in the UI + * @param definition + * description of what the evaluator measures + * @param evaluatorFn + * the evaluation function + * @return the created evaluator + */ + public Evaluator defineEvaluator(String name, String displayName, String definition, + EvaluatorFn evaluatorFn) { + return Evaluator.define(registry, name, displayName, definition, evaluatorFn); + } + + /** + * Defines a new evaluator with full options. + * + * @param + * the options type + * @param name + * the evaluator name + * @param displayName + * the display name shown in the UI + * @param definition + * description of what the evaluator measures + * @param isBilled + * whether using this evaluator incurs costs + * @param optionsClass + * the class for evaluator-specific options + * @param evaluatorFn + * the evaluation function + * @return the created evaluator + */ + public Evaluator defineEvaluator(String name, String displayName, String definition, boolean isBilled, + Class optionsClass, EvaluatorFn evaluatorFn) { + return Evaluator.define(registry, name, displayName, definition, isBilled, optionsClass, evaluatorFn); + } + + /** + * Gets an evaluator by name. + * + * @param name + * the evaluator name + * @return the evaluator + * @throws GenkitException + * if evaluator not found + */ + @SuppressWarnings("unchecked") + public Evaluator getEvaluator(String name) { + Action action = registry.lookupAction(ActionType.EVALUATOR, name); + if (action == null) { + throw new GenkitException("Evaluator not found: " + name); + } + return (Evaluator) action; + } + + /** + * Runs an evaluation using the specified request. + * + *

+ * This method: + *

    + *
  1. Loads the dataset
  2. + *
  3. Runs inference on the target action
  4. + *
  5. Executes all specified evaluators
  6. + *
  7. Stores and returns the results
  8. + *
+ * + * @param request + * the evaluation request + * @return the evaluation run key + * @throws Exception + * if evaluation fails + */ + public EvalRunKey evaluate(RunEvaluationRequest request) throws Exception { + return getEvaluationManager().runEvaluation(request); + } + + /** + * Gets the evaluation manager. + * + * @return the evaluation manager + */ + public synchronized EvaluationManager getEvaluationManager() { + if (evaluationManager == null) { + evaluationManager = new EvaluationManager(registry); + } + return evaluationManager; + } + + /** + * Gets the dataset store. + * + * @return the dataset store + */ + public DatasetStore getDatasetStore() { + return getEvaluationManager().getDatasetStore(); + } + + /** + * Gets the eval store. + * + * @return the eval store + */ + public EvalStore getEvalStore() { + return getEvaluationManager().getEvalStore(); + } + + /** + * Builder for Genkit. + */ + public static class Builder { + private final List plugins = new ArrayList<>(); + private GenkitOptions options = GenkitOptions.builder().build(); + + /** + * Sets the Genkit options. + * + * @param options + * the options + * @return this builder + */ + public Builder options(GenkitOptions options) { + this.options = options; + return this; + } + + /** + * Adds a plugin. + * + * @param plugin + * the plugin to add + * @return this builder + */ + public Builder plugin(Plugin plugin) { + this.plugins.add(plugin); + return this; + } + + /** + * Enables dev mode. + * + * @return this builder + */ + public Builder devMode() { + this.options = GenkitOptions.builder().devMode(true).build(); + return this; + } + + /** + * Sets the reflection port. + * + * @param port + * the port number + * @return this builder + */ + public Builder reflectionPort(int port) { + this.options = GenkitOptions.builder().devMode(options.isDevMode()).reflectionPort(port).build(); + return this; + } + + /** + * Builds the Genkit instance. + * + * @return the configured Genkit instance + */ + public Genkit build() { + Genkit genkit = new Genkit(options); + genkit.plugins.addAll(plugins); + genkit.init(); + return genkit; + } + } +} diff --git a/java/genkit/src/main/java/com/google/genkit/GenkitOptions.java b/java/genkit/src/main/java/com/google/genkit/GenkitOptions.java new file mode 100644 index 0000000000..13ffc00da5 --- /dev/null +++ b/java/genkit/src/main/java/com/google/genkit/GenkitOptions.java @@ -0,0 +1,134 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit; + +/** + * GenkitOptions contains configuration options for Genkit. + */ +public class GenkitOptions { + + private final boolean devMode; + private final int reflectionPort; + private final String projectRoot; + private final String promptDir; + + private GenkitOptions(Builder builder) { + this.devMode = builder.devMode; + this.reflectionPort = builder.reflectionPort; + this.projectRoot = builder.projectRoot; + this.promptDir = builder.promptDir; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Returns whether dev mode is enabled. + * + * @return true if dev mode is enabled + */ + public boolean isDevMode() { + return devMode; + } + + /** + * Returns the reflection server port. + * + * @return the port number + */ + public int getReflectionPort() { + return reflectionPort; + } + + /** + * Returns the project root directory. + * + * @return the project root path + */ + public String getProjectRoot() { + return projectRoot; + } + + /** + * Returns the prompt directory path (relative to resources or absolute). + * Defaults to "/prompts" for loading from classpath resources. + * + * @return the prompt directory path + */ + public String getPromptDir() { + return promptDir; + } + + /** + * Builder for GenkitOptions. + */ + public static class Builder { + private boolean devMode = isDevModeFromEnv(); + private int reflectionPort = getReflectionPortFromEnv(); + private String projectRoot = System.getProperty("user.dir"); + private String promptDir = "/prompts"; + + private static boolean isDevModeFromEnv() { + String env = System.getenv("GENKIT_ENV"); + return "dev".equals(env) || env == null; + } + + private static int getReflectionPortFromEnv() { + String port = System.getenv("GENKIT_REFLECTION_PORT"); + if (port != null) { + try { + return Integer.parseInt(port); + } catch (NumberFormatException e) { + // fall through to default + } + } + return 3100; + } + + public Builder devMode(boolean devMode) { + this.devMode = devMode; + return this; + } + + public Builder reflectionPort(int reflectionPort) { + this.reflectionPort = reflectionPort; + return this; + } + + public Builder projectRoot(String projectRoot) { + this.projectRoot = projectRoot; + return this; + } + + public Builder promptDir(String promptDir) { + this.promptDir = promptDir; + return this; + } + + public GenkitOptions build() { + return new GenkitOptions(this); + } + } +} diff --git a/java/genkit/src/main/java/com/google/genkit/ReflectionServer.java b/java/genkit/src/main/java/com/google/genkit/ReflectionServer.java new file mode 100644 index 0000000000..2d483f443d --- /dev/null +++ b/java/genkit/src/main/java/com/google/genkit/ReflectionServer.java @@ -0,0 +1,516 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit; + +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.util.Callback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.genkit.ai.evaluation.*; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionDesc; +import com.google.genkit.core.ActionRunResult; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.JsonUtils; +import com.google.genkit.core.Registry; +import com.google.genkit.core.tracing.Tracer; + +/** + * ReflectionServer provides an HTTP API for the Genkit Developer UI to interact + * with. + * + * It exposes endpoints for listing actions, running actions, querying traces, + * and evaluation. + */ +public class ReflectionServer { + + private static final Logger logger = LoggerFactory.getLogger(ReflectionServer.class); + + private final Registry registry; + private final int port; + private Server server; + private String runtimeId; + private EvaluationManager evaluationManager; + + /** + * Creates a new ReflectionServer. + * + * @param registry + * the Genkit registry + * @param port + * the port to listen on + */ + public ReflectionServer(Registry registry, int port) { + this.registry = registry; + this.port = port; + this.runtimeId = "java-" + ProcessHandle.current().pid() + "-" + System.currentTimeMillis(); + this.evaluationManager = new EvaluationManager(registry); + } + + /** + * Gets the runtime ID. + */ + public String getRuntimeId() { + return runtimeId; + } + + /** + * Starts the reflection server. + * + * @throws Exception + * if the server fails to start + */ + public void start() throws Exception { + server = new Server(); + + // Configure connector with extended idle timeout for long-running operations + // (e.g., video generation can take several minutes) + ServerConnector connector = new ServerConnector(server); + connector.setPort(port); + connector.setIdleTimeout(900000); // 15 minutes idle timeout + server.addConnector(connector); + + server.setHandler(new ReflectionHandler()); + server.start(); + logger.info("Reflection server started on port {}", port); + } + + /** + * Stops the reflection server. + * + * @throws Exception + * if the server fails to stop + */ + public void stop() throws Exception { + if (server != null) { + server.stop(); + logger.info("Reflection server stopped"); + } + } + + /** + * Handler for reflection API requests using Jetty 12 Handler.Abstract. + */ + private class ReflectionHandler extends Handler.Abstract { + + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception { + String target = request.getHttpURI().getPath(); + String method = request.getMethod(); + + // Enable CORS + response.getHeaders().add("Access-Control-Allow-Origin", "*"); + response.getHeaders().add("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); + response.getHeaders().add("Access-Control-Allow-Headers", "Content-Type, Accept"); + + if ("OPTIONS".equals(method)) { + response.setStatus(200); + callback.succeeded(); + return true; + } + + // Handle health check separately (no JSON content type needed) + if ("/api/__health".equals(target)) { + String query = request.getHttpURI().getQuery(); + String idParam = null; + if (query != null) { + for (String param : query.split("&")) { + if (param.startsWith("id=")) { + idParam = param.substring(3); + break; + } + } + } + // If ID is provided, it must match our runtime ID + if (idParam != null && !idParam.equals(runtimeId)) { + response.setStatus(503); + callback.succeeded(); + return true; + } + response.setStatus(200); + callback.succeeded(); + return true; + } + + response.getHeaders().add("Content-Type", "application/json"); + + try { + String result; + int status = 200; + + if ("/api/actions".equals(target)) { + result = handleListActions(); + } else if (target.startsWith("/api/actions/")) { + String actionKey = target.substring("/api/actions/".length()); + result = handleGetAction(actionKey); + } else if ("/api/runAction".equals(target)) { + String body = readRequestBody(request); + result = handleRunAction(body); + } else if ("/api/notify".equals(target) && "POST".equals(method)) { + String body = readRequestBody(request); + result = handleNotify(body); + } else if (target.startsWith("/api/envs/") && target.contains("/traces")) { + result = handleListTraces(); + } + // Dataset endpoints + else if ("/api/datasets".equals(target) && "GET".equals(method)) { + result = handleListDatasets(); + } else if ("/api/datasets".equals(target) && "POST".equals(method)) { + String body = readRequestBody(request); + result = handleCreateDataset(body); + } else if (target.startsWith("/api/datasets/") && "GET".equals(method)) { + String datasetId = target.substring("/api/datasets/".length()); + result = handleGetDataset(datasetId); + } else if (target.startsWith("/api/datasets/") && "PUT".equals(method)) { + String body = readRequestBody(request); + result = handleUpdateDataset(body); + } else if (target.startsWith("/api/datasets/") && "DELETE".equals(method)) { + String datasetId = target.substring("/api/datasets/".length()); + result = handleDeleteDataset(datasetId); + } + // Evaluation endpoints + else if ("/api/evalRuns".equals(target) && "GET".equals(method)) { + result = handleListEvalRuns(); + } else if (target.startsWith("/api/evalRuns/") && "GET".equals(method)) { + String evalRunId = target.substring("/api/evalRuns/".length()); + result = handleGetEvalRun(evalRunId); + } else if (target.startsWith("/api/evalRuns/") && "DELETE".equals(method)) { + String evalRunId = target.substring("/api/evalRuns/".length()); + result = handleDeleteEvalRun(evalRunId); + } else if ("/api/runEvaluation".equals(target) && "POST".equals(method)) { + String body = readRequestBody(request); + result = handleRunEvaluation(body); + } else { + status = 404; + result = createErrorResponse(5, "Not found", null); // NOT_FOUND code = 5 + } + + response.setStatus(status); + byte[] bytes = result.getBytes(StandardCharsets.UTF_8); + response.write(true, ByteBuffer.wrap(bytes), callback); + + } catch (Exception e) { + logger.error("Error handling request", e); + response.setStatus(500); + String errorMessage = e.getMessage() != null ? e.getMessage() : "Unknown error"; + String stacktrace = getStackTraceString(e); + // For HTTP 500 errors, send error status directly (no wrapper) + String errorJson = createErrorStatus(2, errorMessage, stacktrace); // INTERNAL error code = 2 + byte[] bytes = errorJson.getBytes(StandardCharsets.UTF_8); + response.write(true, ByteBuffer.wrap(bytes), callback); + } + + return true; + } + + private String getStackTraceString(Throwable e) { + java.io.StringWriter sw = new java.io.StringWriter(); + e.printStackTrace(new java.io.PrintWriter(sw)); + return sw.toString(); + } + + /** + * Creates a structured error status JSON string (without wrapper). Format: + * {code, message, details: {stack}} Used for HTTP 500 error responses where the + * body IS the error. + */ + private String createErrorStatus(int code, String message, String stack) { + Map errorDetails = new HashMap<>(); + if (stack != null) { + errorDetails.put("stack", stack); + } + + Map errorStatus = new HashMap<>(); + errorStatus.put("code", code); + errorStatus.put("message", message); + errorStatus.put("details", errorDetails); + + return JsonUtils.toJson(errorStatus); + } + + /** + * Creates a wrapped error response JSON string. Format: {error: {code, message, + * details: {stack}}} Used for inline errors in 200 OK responses (e.g., action + * not found). + */ + private String createErrorResponse(int code, String message, String stack) { + Map errorDetails = new HashMap<>(); + if (stack != null) { + errorDetails.put("stack", stack); + } + + Map errorStatus = new HashMap<>(); + errorStatus.put("code", code); + errorStatus.put("message", message); + errorStatus.put("details", errorDetails); + + Map errorResponse = new HashMap<>(); + errorResponse.put("error", errorStatus); + + return JsonUtils.toJson(errorResponse); + } + + private String readRequestBody(Request request) throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + InputStream is = Request.asInputStream(request); + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = is.read(buffer)) != -1) { + baos.write(buffer, 0, bytesRead); + } + return baos.toString(StandardCharsets.UTF_8); + } + + private String handleListActions() { + List> actions = registry.listActions(); + + // Return as object keyed by action key (not array) + Map> actionMap = new HashMap<>(); + + for (Action action : actions) { + Map actionInfo = new HashMap<>(); + ActionDesc desc = action.getDesc(); + String key; + if (desc != null) { + key = desc.getKey(); + actionInfo.put("key", key); + actionInfo.put("name", desc.getName()); + actionInfo.put("description", desc.getDescription()); + actionInfo.put("inputSchema", desc.getInputSchema()); + actionInfo.put("outputSchema", desc.getOutputSchema()); + actionInfo.put("metadata", desc.getMetadata()); + } else { + key = action.getType().keyFromName(action.getName()); + actionInfo.put("key", key); + actionInfo.put("name", action.getName()); + } + actionMap.put(key, actionInfo); + } + + return JsonUtils.toJson(actionMap); + } + + private String handleGetAction(String actionKey) { + Action action = registry.lookupAction(actionKey); + if (action == null) { + return createErrorResponse(5, "Action not found: " + actionKey, null); // NOT_FOUND code = 5 + } + + Map actionInfo = new HashMap<>(); + ActionDesc desc = action.getDesc(); + if (desc != null) { + actionInfo.put("key", desc.getKey()); + actionInfo.put("name", desc.getName()); + actionInfo.put("description", desc.getDescription()); + actionInfo.put("inputSchema", desc.getInputSchema()); + actionInfo.put("outputSchema", desc.getOutputSchema()); + actionInfo.put("metadata", desc.getMetadata()); + } else { + actionInfo.put("key", actionKey); + actionInfo.put("name", action.getName()); + } + + return JsonUtils.toJson(actionInfo); + } + + private String handleRunAction(String body) throws GenkitException { + JsonNode requestNode = JsonUtils.parseJson(body); + + String key = requestNode.has("key") ? requestNode.get("key").asText() : null; + JsonNode input = requestNode.has("input") ? requestNode.get("input") : null; + + if (key == null) { + throw new GenkitException("Missing 'key' in request body"); + } + + Action action = registry.lookupAction(key); + if (action == null) { + throw new GenkitException("Action not found: " + key); + } + + ActionContext context = new ActionContext(registry); + + ActionRunResult result = action.runJsonWithTelemetry(context, input, null); + + Map response = new HashMap<>(); + response.put("result", result.getResult()); + response.put("traceId", result.getTraceId()); + + return JsonUtils.toJson(response); + } + + private String handleListTraces() { + // Return empty traces for now - this would need a proper trace store + return "{\"traces\":[]}"; + } + + /** + * Handle the notify endpoint from the Genkit CLI. This is used to receive + * configuration like the telemetry server URL. + */ + private String handleNotify(String body) { + try { + JsonNode requestNode = JsonUtils.parseJson(body); + + String telemetryServerUrl = requestNode.has("telemetryServerUrl") + ? requestNode.get("telemetryServerUrl").asText() + : null; + int reflectionApiSpecVersion = requestNode.has("reflectionApiSpecVersion") + ? requestNode.get("reflectionApiSpecVersion").asInt() + : 0; + + if (telemetryServerUrl != null && !telemetryServerUrl.isEmpty()) { + // Configure the telemetry exporter with the server URL + Tracer.configureTelemetryServer(telemetryServerUrl); + } + + // Warn if version mismatch + if (reflectionApiSpecVersion != 0 && reflectionApiSpecVersion != 1) { + logger.warn("Genkit CLI version may not be compatible with runtime library."); + } + + return "\"OK\""; + } catch (Exception e) { + logger.error("Error handling notify request", e); + String errorMessage = e.getMessage() != null ? e.getMessage() : "Unknown error"; + String stacktrace = getStackTraceString(e); + return createErrorResponse(2, errorMessage, stacktrace); // INTERNAL error code = 2 + } + } + + // ===================================================================== + // Dataset Endpoints + // ===================================================================== + + private String handleListDatasets() { + try { + List datasets = evaluationManager.getDatasetStore().listDatasets(); + return JsonUtils.toJson(datasets); + } catch (Exception e) { + logger.error("Error listing datasets", e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleGetDataset(String datasetId) { + try { + List dataset = evaluationManager.getDatasetStore().getDataset(datasetId); + return JsonUtils.toJson(dataset); + } catch (Exception e) { + logger.error("Error getting dataset: {}", datasetId, e); + return createErrorResponse(5, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleCreateDataset(String body) { + try { + CreateDatasetRequest request = JsonUtils.fromJson(body, CreateDatasetRequest.class); + DatasetMetadata metadata = evaluationManager.getDatasetStore().createDataset(request); + return JsonUtils.toJson(metadata); + } catch (Exception e) { + logger.error("Error creating dataset", e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleUpdateDataset(String body) { + try { + UpdateDatasetRequest request = JsonUtils.fromJson(body, UpdateDatasetRequest.class); + DatasetMetadata metadata = evaluationManager.getDatasetStore().updateDataset(request); + return JsonUtils.toJson(metadata); + } catch (Exception e) { + logger.error("Error updating dataset", e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleDeleteDataset(String datasetId) { + try { + evaluationManager.getDatasetStore().deleteDataset(datasetId); + return "{}"; + } catch (Exception e) { + logger.error("Error deleting dataset: {}", datasetId, e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + // ===================================================================== + // Evaluation Endpoints + // ===================================================================== + + private String handleListEvalRuns() { + try { + List evalRuns = evaluationManager.getEvalStore().list(); + return JsonUtils.toJson(evalRuns); + } catch (Exception e) { + logger.error("Error listing eval runs", e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleGetEvalRun(String evalRunId) { + try { + EvalRun evalRun = evaluationManager.getEvalStore().load(evalRunId); + if (evalRun == null) { + return createErrorResponse(5, "Eval run not found: " + evalRunId, null); + } + return JsonUtils.toJson(evalRun); + } catch (Exception e) { + logger.error("Error getting eval run: {}", evalRunId, e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleDeleteEvalRun(String evalRunId) { + try { + evaluationManager.getEvalStore().delete(evalRunId); + return "{}"; + } catch (Exception e) { + logger.error("Error deleting eval run: {}", evalRunId, e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + + private String handleRunEvaluation(String body) { + try { + RunEvaluationRequest request = JsonUtils.fromJson(body, RunEvaluationRequest.class); + EvalRunKey evalRunKey = evaluationManager.runEvaluation(request); + return JsonUtils.toJson(evalRunKey); + } catch (Exception e) { + logger.error("Error running evaluation", e); + return createErrorResponse(2, e.getMessage(), getStackTraceString(e)); + } + } + } +} diff --git a/java/genkit/src/main/java/com/google/genkit/RuntimeFileWriter.java b/java/genkit/src/main/java/com/google/genkit/RuntimeFileWriter.java new file mode 100644 index 0000000000..c9bbd21918 --- /dev/null +++ b/java/genkit/src/main/java/com/google/genkit/RuntimeFileWriter.java @@ -0,0 +1,215 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit; + +import java.io.IOException; +import java.nio.file.*; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.HashMap; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * RuntimeFileWriter writes runtime discovery files for the Genkit Dev UI. + * + * The Dev UI discovers running Genkit instances by looking for JSON files in + * the .genkit/runtimes directory. + */ +public class RuntimeFileWriter { + + private static final Logger logger = LoggerFactory.getLogger(RuntimeFileWriter.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + private static Path currentRuntimeFile; + + /** + * Writes a runtime file for Dev UI discovery. Uses findProjectRoot() to locate + * the project root by searching up for pom.xml. + * + * @param port + * the reflection server port + * @param runtimeId + * the runtime ID from the reflection server + */ + public static void write(int port, String runtimeId) { + String projectRoot = findProjectRoot(); + writeRuntimeFile(port, projectRoot, runtimeId); + } + + /** + * Finds the project root by searching up from the current directory. + * Prioritizes package.json to match the genkit CLI's behavior - this ensures + * the Java runtime writes to the same .genkit directory the CLI reads from. + * Falls back to pom.xml/build.gradle only if no package.json is found. + * + * @return the project root directory path + */ + private static String findProjectRoot() { + Path dir = Paths.get(System.getProperty("user.dir")).toAbsolutePath(); + + // First pass: Look for package.json (CLI primary marker) to ensure we match CLI + // behavior + // The CLI looks for package.json first, so we need to find the same root it + // uses + Path cliRoot = null; + Path currentDir = dir; + while (currentDir != null) { + Path packageJson = currentDir.resolve("package.json"); + if (Files.exists(packageJson)) { + cliRoot = currentDir; + logger.debug("Found CLI project root at: {} (found package.json)", currentDir); + break; + } + Path parent = currentDir.getParent(); + if (parent == null || parent.equals(currentDir)) { + break; + } + currentDir = parent; + } + + // If we found a package.json (CLI root), use that + if (cliRoot != null) { + return cliRoot.toString(); + } + + // Second pass: Fall back to Java/other markers if no package.json found + String[] fallbackMarkers = {"pom.xml", "build.gradle", "go.mod", "pyproject.toml", "requirements.txt"}; + currentDir = dir; + while (currentDir != null) { + for (String marker : fallbackMarkers) { + Path markerFile = currentDir.resolve(marker); + if (Files.exists(markerFile)) { + logger.debug("Found project root at: {} (found {})", currentDir, marker); + return currentDir.toString(); + } + } + + Path parent = currentDir.getParent(); + if (parent == null || parent.equals(currentDir)) { + logger.warn("Could not find project root, using current directory"); + return System.getProperty("user.dir"); + } + currentDir = parent; + } + + return System.getProperty("user.dir"); + } + + /** + * Writes a runtime file for Dev UI discovery. + * + * @param port + * the reflection server port + * @param projectRoot + * the project root directory + * @param runtimeId + * the runtime ID + */ + public static void writeRuntimeFile(int port, String projectRoot, String runtimeId) { + try { + Path runtimesDir = getRuntimesDir(projectRoot); + Files.createDirectories(runtimesDir); + + Path runtimeFile = runtimesDir.resolve(runtimeId + ".json"); + + // Use ISO 8601 format like Go does: 2025-12-21T16:12:32Z + // Replace colons with underscores for filename compatibility + String timestamp = Instant.now().atOffset(ZoneOffset.UTC).format(DateTimeFormatter.ISO_INSTANT).replace(":", + "_"); + + Map runtimeInfo = new HashMap<>(); + runtimeInfo.put("id", runtimeId); + runtimeInfo.put("pid", ProcessHandle.current().pid()); + runtimeInfo.put("reflectionServerUrl", "http://localhost:" + port); + runtimeInfo.put("timestamp", timestamp); + runtimeInfo.put("genkitVersion", "java/1.0.0"); + runtimeInfo.put("reflectionApiSpecVersion", 1); + + String json = objectMapper.writeValueAsString(runtimeInfo); + Files.writeString(runtimeFile, json); + currentRuntimeFile = runtimeFile; + + logger.info("Runtime file written: {}", runtimeFile); + } catch (IOException e) { + logger.error("Failed to write runtime file", e); + } + } + + /** + * Cleans up the runtime file using the current directory as project root. + */ + public static void cleanup() { + if (currentRuntimeFile != null) { + try { + Files.deleteIfExists(currentRuntimeFile); + logger.info("Runtime file removed: {}", currentRuntimeFile); + currentRuntimeFile = null; + } catch (IOException e) { + logger.error("Failed to remove runtime file", e); + } + } else { + removeRuntimeFile(System.getProperty("user.dir")); + } + } + + /** + * Removes the runtime file. + * + * @param projectRoot + * the project root directory + */ + public static void removeRuntimeFile(String projectRoot) { + try { + Path runtimesDir = getRuntimesDir(projectRoot); + long pid = ProcessHandle.current().pid(); + + // Find and delete files matching our PID + if (Files.exists(runtimesDir)) { + try (DirectoryStream stream = Files.newDirectoryStream(runtimesDir, "*.json")) { + for (Path file : stream) { + try { + String content = Files.readString(file); + Map info = objectMapper.readValue(content, Map.class); + if (info.get("pid") != null && ((Number) info.get("pid")).longValue() == pid) { + Files.delete(file); + logger.info("Runtime file removed: {}", file); + } + } catch (Exception e) { + // Ignore files we can't read + } + } + } + } + } catch (IOException e) { + logger.error("Failed to remove runtime file", e); + } + } + + /** + * Gets the runtimes directory path. + */ + private static Path getRuntimesDir(String projectRoot) { + return Paths.get(projectRoot, ".genkit", "runtimes"); + } +} diff --git a/java/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java b/java/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java new file mode 100644 index 0000000000..671dff2dc0 --- /dev/null +++ b/java/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java @@ -0,0 +1,550 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.prompt; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +import com.github.jknack.handlebars.Context; +import com.github.jknack.handlebars.Handlebars; +import com.github.jknack.handlebars.Template; +import com.github.jknack.handlebars.io.StringTemplateSource; +import com.github.jknack.handlebars.io.TemplateLoader; +import com.github.jknack.handlebars.io.TemplateSource; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.Model; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.ModelResponseChunk; +import com.google.genkit.ai.Prompt; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.Registry; + +/** + * DotPrompt provides support for .prompt files using Handlebars templating. + * + * .prompt files are structured text files with YAML frontmatter containing + * configuration options and a Handlebars template body. + * + * Partials are supported by files starting with underscore (e.g., + * _style.prompt). Partials are automatically loaded when referenced in + * templates. + */ +public class DotPrompt { + + /** Registry of partials (for debugging/introspection). */ + private static final Map registeredPartials = new ConcurrentHashMap<>(); + + /** Custom TemplateLoader that resolves partials from our registry. */ + private static final TemplateLoader partialLoader = new TemplateLoader() { + @Override + public TemplateSource sourceAt(String location) throws IOException { + String partial = registeredPartials.get(location); + if (partial != null) { + return new StringTemplateSource(location, partial); + } + throw new IOException("Partial not found: " + location); + } + + @Override + public String resolve(String location) { + return location; + } + + @Override + public String getPrefix() { + return ""; + } + + @Override + public String getSuffix() { + return ""; + } + + @Override + public void setPrefix(String prefix) { + } + + @Override + public void setSuffix(String suffix) { + } + + @Override + public void setCharset(Charset charset) { + } + + @Override + public Charset getCharset() { + return StandardCharsets.UTF_8; + } + }; + + /** Shared Handlebars instance with registered partials. */ + private static final Handlebars sharedHandlebars = new Handlebars(partialLoader); + + private final String name; + private final String model; + private final String template; + private final Map inputSchema; + private final GenerationConfig config; + private final Handlebars handlebars; + + /** + * Creates a new DotPrompt. + * + * @param name + * the prompt name + * @param model + * the default model name + * @param template + * the Handlebars template + * @param inputSchema + * the input JSON schema + * @param config + * the default generation config + */ + public DotPrompt(String name, String model, String template, Map inputSchema, + GenerationConfig config) { + this.name = name; + this.model = model; + this.template = template; + this.inputSchema = inputSchema; + this.config = config; + this.handlebars = sharedHandlebars; // Use shared instance with registered partials + } + + /** + * Registers a partial template that can be included in other prompts. Partials + * are referenced using {{>partialName}} syntax in templates. + * + * @param name + * the partial name (without underscore prefix or .prompt extension) + * @param source + * the partial template source + * @throws GenkitException + * if registration fails + */ + public static void registerPartial(String name, String source) throws GenkitException { + // Extract just the template body (skip frontmatter if present) + String templateBody = source; + if (source.startsWith("---")) { + int endIndex = source.indexOf("---", 3); + if (endIndex > 0) { + templateBody = source.substring(endIndex + 3).trim(); + } + } + registeredPartials.put(name, templateBody); + } + + /** + * Loads and registers a partial from a resource file. The partial name is + * derived from the filename (without underscore prefix and .prompt extension). + * + * @param resourcePath + * the resource path (e.g., "/prompts/_style.prompt") + * @throws GenkitException + * if loading fails + */ + public static void loadPartialFromResource(String resourcePath) throws GenkitException { + try (InputStream is = DotPrompt.class.getResourceAsStream(resourcePath)) { + if (is == null) { + throw new GenkitException("Partial resource not found: " + resourcePath); + } + String content = new String(is.readAllBytes(), StandardCharsets.UTF_8); + + // Extract partial name from path + String name = resourcePath; + if (name.contains("/")) { + name = name.substring(name.lastIndexOf('/') + 1); + } + if (name.startsWith("_")) { + name = name.substring(1); + } + if (name.endsWith(".prompt")) { + name = name.substring(0, name.length() - 7); + } + + registerPartial(name, content); + } catch (IOException e) { + throw new GenkitException("Failed to load partial resource: " + resourcePath, e); + } + } + + /** + * Returns the names of all registered partials. + * + * @return set of partial names + */ + public static java.util.Set getRegisteredPartialNames() { + return registeredPartials.keySet(); + } + + /** + * Loads a DotPrompt from a resource file. Automatically loads any partials + * referenced in the template from the same directory. Partials should be named + * with underscore prefix (e.g., _style.prompt). + * + * @param + * the input type + * @param resourcePath + * the resource path + * @return the loaded DotPrompt + * @throws GenkitException + * if loading fails + */ + public static DotPrompt loadFromResource(String resourcePath) throws GenkitException { + try (InputStream is = DotPrompt.class.getResourceAsStream(resourcePath)) { + if (is == null) { + throw new GenkitException("Resource not found: " + resourcePath); + } + String content = new String(is.readAllBytes(), StandardCharsets.UTF_8); + + // Get the directory path for loading partials + String directory = resourcePath.contains("/") + ? resourcePath.substring(0, resourcePath.lastIndexOf('/')) + : ""; + + // Auto-load partials referenced in the template + autoLoadPartials(content, directory); + + return parse(resourcePath, content); + } catch (IOException e) { + throw new GenkitException("Failed to load prompt resource: " + resourcePath, e); + } + } + + /** + * Scans template content for partial references ({{>partialName}}) and loads + * them. Partials are loaded from the same directory with underscore prefix. + */ + private static void autoLoadPartials(String content, String directory) { + // Find all partial references: {{>partialName}} or {{> partialName}} + java.util.regex.Pattern pattern = java.util.regex.Pattern.compile("\\{\\{>\\s*([\\w-]+)"); + java.util.regex.Matcher matcher = pattern.matcher(content); + + while (matcher.find()) { + String partialName = matcher.group(1); + + // Skip if already registered + if (registeredPartials.containsKey(partialName)) { + continue; + } + + // Try to load the partial from resource + String partialPath = directory + "/_" + partialName + ".prompt"; + try (InputStream partialIs = DotPrompt.class.getResourceAsStream(partialPath)) { + if (partialIs != null) { + String partialContent = new String(partialIs.readAllBytes(), StandardCharsets.UTF_8); + registerPartial(partialName, partialContent); + } + // If partial not found, Handlebars will report the error when rendering + } catch (IOException e) { + // Ignore - partial loading is best-effort, Handlebars will report if missing + } + } + } + + /** + * Parses a DotPrompt from its string content. + * + * @param + * the input type + * @param name + * the prompt name + * @param content + * the prompt file content + * @return the parsed DotPrompt + * @throws GenkitException + * if parsing fails + */ + public static DotPrompt parse(String name, String content) throws GenkitException { + // Split frontmatter from template + String template = content; + String model = null; + Map inputSchema = null; + GenerationConfig config = null; + + if (content.startsWith("---")) { + int endIndex = content.indexOf("---", 3); + if (endIndex > 0) { + String frontmatter = content.substring(3, endIndex).trim(); + template = content.substring(endIndex + 3).trim(); + + // Simple YAML parsing for common fields + for (String line : frontmatter.split("\n")) { + line = line.trim(); + if (line.startsWith("model:")) { + model = line.substring(6).trim(); + } + } + } + } + + // Clean up the name (remove extension) + if (name.endsWith(".prompt")) { + name = name.substring(0, name.length() - 7); + } + if (name.contains("/")) { + name = name.substring(name.lastIndexOf('/') + 1); + } + + return new DotPrompt<>(name, model, template, inputSchema, config); + } + + /** + * Renders the prompt with the given input. + * + * @param input + * the input data + * @return the rendered prompt text + * @throws GenkitException + * if rendering fails + */ + public String render(I input) throws GenkitException { + try { + Template compiledTemplate = handlebars.compileInline(template); + Context context = Context.newBuilder(input).build(); + return compiledTemplate.apply(context); + } catch (IOException e) { + throw new GenkitException("Failed to render prompt template", e); + } + } + + /** + * Renders the prompt and creates a ModelRequest. + * + * @param input + * the input data + * @return the model request + * @throws GenkitException + * if rendering fails + */ + public ModelRequest toModelRequest(I input) throws GenkitException { + String rendered = render(input); + + ModelRequest.Builder builder = ModelRequest.builder().addUserMessage(rendered); + + if (config != null) { + // Convert GenerationConfig to Map + Map configMap = new HashMap<>(); + if (config.getTemperature() != null) { + configMap.put("temperature", config.getTemperature()); + } + if (config.getMaxOutputTokens() != null) { + configMap.put("maxOutputTokens", config.getMaxOutputTokens()); + } + if (config.getTopP() != null) { + configMap.put("topP", config.getTopP()); + } + if (config.getTopK() != null) { + configMap.put("topK", config.getTopK()); + } + // Include custom config for model-specific options + if (config.getCustom() != null) { + configMap.putAll(config.getCustom()); + } + builder.config(configMap); + } + + return builder.build(); + } + + /** + * Creates a Prompt action from this DotPrompt. + * + * @param inputClass + * the input class + * @return the Prompt action + */ + public Prompt toPrompt(Class inputClass) { + return Prompt.builder().name(name).model(model).template(template).inputSchema(inputSchema).config(config) + .inputClass(inputClass).renderer((ctx, input) -> toModelRequest(input)).build(); + } + + /** + * Registers this DotPrompt as an action. + * + * @param registry + * the registry + * @param inputClass + * the input class + */ + public void register(Registry registry, Class inputClass) { + Prompt prompt = toPrompt(inputClass); + prompt.register(registry); + } + + /** + * Generates a response using this prompt with the given registry. + * + *

+ * This method allows generating directly from a DotPrompt without needing to go + * through ExecutablePrompt. The model is looked up from the registry using the + * model name specified in the prompt. + * + * @param registry + * the registry to look up the model + * @param input + * the prompt input + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(Registry registry, I input) throws GenkitException { + return generate(registry, input, null, null); + } + + /** + * Generates a response using this prompt with custom options. + * + * @param registry + * the registry to look up the model + * @param input + * the prompt input + * @param options + * optional generation options to override prompt defaults + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(Registry registry, I input, GenerateOptions options) throws GenkitException { + return generate(registry, input, options, null); + } + + /** + * Generates a response using this prompt with streaming. + * + * @param registry + * the registry to look up the model + * @param input + * the prompt input + * @param options + * optional generation options + * @param streamCallback + * callback for streaming chunks + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(Registry registry, I input, GenerateOptions options, + Consumer streamCallback) throws GenkitException { + ModelRequest request = toModelRequest(input); + String modelName = resolveModel(options); + + Model modelAction = getModel(registry, modelName); + ActionContext ctx = new ActionContext(registry); + + if (options != null && options.getConfig() != null) { + request = mergeConfig(request, options); + } + + if (streamCallback != null) { + return modelAction.run(ctx, request, streamCallback); + } else { + return modelAction.run(ctx, request); + } + } + + // Private helper methods for generation + + private String resolveModel(GenerateOptions options) { + if (options != null && options.getModel() != null && !options.getModel().isEmpty()) { + return options.getModel(); + } + if (model == null || model.isEmpty()) { + throw new GenkitException("No model specified in prompt or options"); + } + return model; + } + + private Model getModel(Registry registry, String modelName) { + Action action = registry.lookupAction(ActionType.MODEL, modelName); + if (action == null) { + String key = ActionType.MODEL.keyFromName(modelName); + action = registry.lookupAction(key); + } + if (action == null) { + throw new GenkitException("Model not found: " + modelName); + } + if (!(action instanceof Model)) { + throw new GenkitException("Action is not a model: " + modelName); + } + return (Model) action; + } + + private ModelRequest mergeConfig(ModelRequest request, GenerateOptions options) { + GenerationConfig optionsConfig = options.getConfig(); + if (optionsConfig == null) { + return request; + } + + Map configMap = new HashMap<>(); + if (request.getConfig() != null) { + configMap.putAll(request.getConfig()); + } + + if (optionsConfig.getTemperature() != null) { + configMap.put("temperature", optionsConfig.getTemperature()); + } + if (optionsConfig.getMaxOutputTokens() != null) { + configMap.put("maxOutputTokens", optionsConfig.getMaxOutputTokens()); + } + if (optionsConfig.getTopP() != null) { + configMap.put("topP", optionsConfig.getTopP()); + } + if (optionsConfig.getTopK() != null) { + configMap.put("topK", optionsConfig.getTopK()); + } + + return ModelRequest.builder().messages(request.getMessages()).config(configMap).tools(request.getTools()) + .output(request.getOutput()).build(); + } + + // Getters + + public String getName() { + return name; + } + + public String getModel() { + return model; + } + + public String getTemplate() { + return template; + } + + public Map getInputSchema() { + return inputSchema; + } + + public GenerationConfig getConfig() { + return config; + } +} diff --git a/java/genkit/src/main/java/com/google/genkit/prompt/ExecutablePrompt.java b/java/genkit/src/main/java/com/google/genkit/prompt/ExecutablePrompt.java new file mode 100644 index 0000000000..4bb974387c --- /dev/null +++ b/java/genkit/src/main/java/com/google/genkit/prompt/ExecutablePrompt.java @@ -0,0 +1,376 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.prompt; + +import java.util.Map; +import java.util.function.Consumer; + +import com.google.genkit.ai.*; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.Registry; + +/** + * ExecutablePrompt wraps a DotPrompt and provides direct generation + * capabilities. + * + *

+ * This class allows prompts to be called directly for generation, similar to + * the JavaScript API: `const response = await helloPrompt({ name: 'John' });` + * + *

+ * In Java, this becomes: + * + *

{@code
+ * ExecutablePrompt helloPrompt = genkit.prompt("hello", HelloInput.class);
+ * ModelResponse response = helloPrompt.generate(new HelloInput("John"));
+ * }
+ * + *

+ * Or for streaming: + * + *

{@code
+ * helloPrompt.stream(input, chunk -> System.out.println(chunk.getText()));
+ * }
+ * + * @param + * the input type for the prompt + */ +public class ExecutablePrompt { + + private final DotPrompt dotPrompt; + private final Registry registry; + private final Class inputClass; + private GenerateFunction generateFunction; + + /** + * Functional interface for the generate function. This allows ExecutablePrompt + * to use Genkit.generate() for tool/interrupt support. + */ + @FunctionalInterface + public interface GenerateFunction { + ModelResponse generate(GenerateOptions options) throws GenkitException; + } + + /** + * Creates a new ExecutablePrompt. + * + * @param dotPrompt + * the underlying DotPrompt + * @param registry + * the Genkit registry + * @param inputClass + * the input class for type checking + */ + public ExecutablePrompt(DotPrompt dotPrompt, Registry registry, Class inputClass) { + this.dotPrompt = dotPrompt; + this.registry = registry; + this.inputClass = inputClass; + } + + /** + * Sets the generate function to use Genkit.generate() for tool/interrupt + * support. + * + * @param generateFunction + * the generate function + * @return this for chaining + */ + public ExecutablePrompt withGenerateFunction(GenerateFunction generateFunction) { + this.generateFunction = generateFunction; + return this; + } + + /** + * Generates a response using the default model specified in the prompt. + * + * @param input + * the prompt input + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(I input) throws GenkitException { + return generate(input, null); + } + + /** + * Generates a response with custom options. + * + *

+ * If a generateFunction is set (via Genkit), this uses Genkit.generate() which + * supports tools and interrupts. Otherwise, it calls the model directly. + * + * @param input + * the prompt input + * @param options + * optional generation options to override prompt defaults + * @return the model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse generate(I input, GenerateOptions options) throws GenkitException { + ModelRequest request = dotPrompt.toModelRequest(input); + String modelName = resolveModel(options); + + // If we have a generate function (from Genkit), use it for tool/interrupt + // support + if (generateFunction != null) { + GenerateOptions.Builder genOptions = GenerateOptions.builder().model(modelName) + .messages(request.getMessages()); + + // Add system message if present + if (request.getMessages() != null && !request.getMessages().isEmpty()) { + Message systemMsg = request.getMessages().stream().filter(m -> m.getRole() == Role.SYSTEM).findFirst() + .orElse(null); + if (systemMsg != null && systemMsg.getContent() != null && !systemMsg.getContent().isEmpty()) { + genOptions.system(systemMsg.getContent().get(0).getText()); + } + } + + // Add tools if present in options + if (options != null && options.getTools() != null) { + genOptions.tools(options.getTools()); + } + + // Add resume options if present + if (options != null && options.getResume() != null) { + genOptions.resume(options.getResume()); + } + + // Add config if present + if (options != null && options.getConfig() != null) { + genOptions.config(options.getConfig()); + } else if (dotPrompt.getConfig() != null) { + genOptions.config(dotPrompt.getConfig()); + } + + return generateFunction.generate(genOptions.build()); + } + + // Fall back to direct model call (no tool/interrupt support) + Model model = getModel(modelName); + ActionContext ctx = new ActionContext(registry); + + // Merge generation config from options if provided + if (options != null && options.getConfig() != null) { + // Override with options config + request = mergeConfig(request, options); + } + + return model.run(ctx, request); + } + + /** + * Generates a response with streaming. + * + * @param input + * the prompt input + * @param streamCallback + * callback for streaming chunks + * @return the final model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse stream(I input, Consumer streamCallback) throws GenkitException { + return stream(input, null, streamCallback); + } + + /** + * Generates a response with streaming and custom options. + * + * @param input + * the prompt input + * @param options + * optional generation options + * @param streamCallback + * callback for streaming chunks + * @return the final model response + * @throws GenkitException + * if generation fails + */ + public ModelResponse stream(I input, GenerateOptions options, Consumer streamCallback) + throws GenkitException { + ModelRequest request = dotPrompt.toModelRequest(input); + String modelName = resolveModel(options); + + Model model = getModel(modelName); + ActionContext ctx = new ActionContext(registry); + + if (options != null && options.getConfig() != null) { + request = mergeConfig(request, options); + } + + return model.run(ctx, request, streamCallback); + } + + /** + * Renders the prompt template without generating. + * + * @param input + * the prompt input + * @return the rendered prompt text + * @throws GenkitException + * if rendering fails + */ + public String render(I input) throws GenkitException { + return dotPrompt.render(input); + } + + /** + * Gets the ModelRequest that would be sent to the model. + * + * @param input + * the prompt input + * @return the model request + * @throws GenkitException + * if conversion fails + */ + public ModelRequest toModelRequest(I input) throws GenkitException { + return dotPrompt.toModelRequest(input); + } + + /** + * Converts this executable prompt to a Prompt action. + * + * @return the Prompt action + */ + public Prompt toPrompt() { + return dotPrompt.toPrompt(inputClass); + } + + /** + * Registers this prompt as an action in the registry. + */ + public void register() { + dotPrompt.register(registry, inputClass); + } + + /** + * Gets the underlying DotPrompt. + * + * @return the DotPrompt + */ + public DotPrompt getDotPrompt() { + return dotPrompt; + } + + /** + * Gets the prompt name. + * + * @return the name + */ + public String getName() { + return dotPrompt.getName(); + } + + /** + * Gets the default model name. + * + * @return the model name + */ + public String getModel() { + return dotPrompt.getModel(); + } + + /** + * Gets the template. + * + * @return the template + */ + public String getTemplate() { + return dotPrompt.getTemplate(); + } + + /** + * Gets the generation config. + * + * @return the config + */ + public GenerationConfig getConfig() { + return dotPrompt.getConfig(); + } + + // Private helper methods + + private String resolveModel(GenerateOptions options) { + // Options model takes precedence + if (options != null && options.getModel() != null && !options.getModel().isEmpty()) { + return options.getModel(); + } + // Fall back to prompt's default model + String model = dotPrompt.getModel(); + if (model == null || model.isEmpty()) { + throw new GenkitException("No model specified in prompt or options"); + } + return model; + } + + private Model getModel(String modelName) { + // Try direct lookup first + com.google.genkit.core.Action action = registry.lookupAction(com.google.genkit.core.ActionType.MODEL, + modelName); + + if (action == null) { + // Try with model/ prefix + String key = com.google.genkit.core.ActionType.MODEL.keyFromName(modelName); + action = registry.lookupAction(key); + } + + if (action == null) { + throw new GenkitException("Model not found: " + modelName); + } + + if (!(action instanceof Model)) { + throw new GenkitException("Action is not a model: " + modelName); + } + + return (Model) action; + } + + private ModelRequest mergeConfig(ModelRequest request, GenerateOptions options) { + GenerationConfig optionsConfig = options.getConfig(); + if (optionsConfig == null) { + return request; + } + + // Build new config map merging prompt config with options config + Map configMap = new java.util.HashMap<>(); + if (request.getConfig() != null) { + configMap.putAll(request.getConfig()); + } + + // Override with options config + if (optionsConfig.getTemperature() != null) { + configMap.put("temperature", optionsConfig.getTemperature()); + } + if (optionsConfig.getMaxOutputTokens() != null) { + configMap.put("maxOutputTokens", optionsConfig.getMaxOutputTokens()); + } + if (optionsConfig.getTopP() != null) { + configMap.put("topP", optionsConfig.getTopP()); + } + if (optionsConfig.getTopK() != null) { + configMap.put("topK", optionsConfig.getTopK()); + } + + return ModelRequest.builder().messages(request.getMessages()).config(configMap).tools(request.getTools()) + .output(request.getOutput()).build(); + } +} diff --git a/java/genkit/src/test/java/com/google/genkit/GenkitOptionsTest.java b/java/genkit/src/test/java/com/google/genkit/GenkitOptionsTest.java new file mode 100644 index 0000000000..ebed968902 --- /dev/null +++ b/java/genkit/src/test/java/com/google/genkit/GenkitOptionsTest.java @@ -0,0 +1,134 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for GenkitOptions. + */ +class GenkitOptionsTest { + + @Test + void testDefaultBuilder() { + GenkitOptions options = GenkitOptions.builder().build(); + + assertNotNull(options); + // Default values may depend on environment, so just check not null + assertNotNull(options.getProjectRoot()); + } + + @Test + void testDevMode() { + GenkitOptions options = GenkitOptions.builder().devMode(true).build(); + + assertTrue(options.isDevMode()); + + options = GenkitOptions.builder().devMode(false).build(); + + assertFalse(options.isDevMode()); + } + + @Test + void testReflectionPort() { + GenkitOptions options = GenkitOptions.builder().reflectionPort(5000).build(); + + assertEquals(5000, options.getReflectionPort()); + } + + @Test + void testProjectRoot() { + GenkitOptions options = GenkitOptions.builder().projectRoot("/custom/project/root").build(); + + assertEquals("/custom/project/root", options.getProjectRoot()); + } + + @Test + void testPromptDir() { + GenkitOptions options = GenkitOptions.builder().promptDir("/custom/prompts").build(); + + assertEquals("/custom/prompts", options.getPromptDir()); + } + + @Test + void testDefaultPromptDir() { + GenkitOptions options = GenkitOptions.builder().build(); + + assertEquals("/prompts", options.getPromptDir()); + } + + @Test + void testBuilderChaining() { + GenkitOptions options = GenkitOptions.builder().devMode(true).reflectionPort(4321).projectRoot("/my/project") + .promptDir("/my/prompts").build(); + + assertTrue(options.isDevMode()); + assertEquals(4321, options.getReflectionPort()); + assertEquals("/my/project", options.getProjectRoot()); + assertEquals("/my/prompts", options.getPromptDir()); + } + + @Test + void testMultipleBuildCalls() { + GenkitOptions.Builder builder = GenkitOptions.builder().devMode(true).reflectionPort(3000); + + GenkitOptions options1 = builder.build(); + GenkitOptions options2 = builder.build(); + + // Both should have same values + assertEquals(options1.isDevMode(), options2.isDevMode()); + assertEquals(options1.getReflectionPort(), options2.getReflectionPort()); + } + + @Test + void testBuilderModificationAfterBuild() { + GenkitOptions.Builder builder = GenkitOptions.builder().devMode(true); + + GenkitOptions options1 = builder.build(); + + builder.devMode(false); + GenkitOptions options2 = builder.build(); + + // Options1 should not be affected + assertTrue(options1.isDevMode()); + assertFalse(options2.isDevMode()); + } + + @Test + void testDifferentPortValues() { + for (int port : new int[]{0, 1, 1024, 3000, 8080, 65535}) { + GenkitOptions options = GenkitOptions.builder().reflectionPort(port).build(); + + assertEquals(port, options.getReflectionPort()); + } + } + + @Test + void testProjectRootVariations() { + String[] paths = {"/", "/usr/local", "C:\\Users\\test", "relative/path", "./current", "../parent"}; + + for (String path : paths) { + GenkitOptions options = GenkitOptions.builder().projectRoot(path).build(); + + assertEquals(path, options.getProjectRoot()); + } + } +} diff --git a/java/genkit/src/test/java/com/google/genkit/GenkitTest.java b/java/genkit/src/test/java/com/google/genkit/GenkitTest.java new file mode 100644 index 0000000000..4156adbab1 --- /dev/null +++ b/java/genkit/src/test/java/com/google/genkit/GenkitTest.java @@ -0,0 +1,188 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.Flow; + +/** + * Unit tests for Genkit. + */ +class GenkitTest { + + @Test + void testDefaultConstructor() { + Genkit genkit = new Genkit(); + assertNotNull(genkit); + } + + @Test + void testConstructorWithOptions() { + GenkitOptions options = GenkitOptions.builder().devMode(false).reflectionPort(4000).build(); + + Genkit genkit = new Genkit(options); + assertNotNull(genkit); + } + + @Test + void testBuilder() { + Genkit genkit = Genkit.builder().build(); + assertNotNull(genkit); + } + + @Test + void testDefineFlowWithBiFunction() { + Genkit genkit = new Genkit(); + + Flow flow = genkit.defineFlow("echoFlow", String.class, String.class, + (ctx, input) -> "Echo: " + input); + + assertNotNull(flow); + assertEquals("echoFlow", flow.getName()); + assertEquals(ActionType.FLOW, flow.getType()); + } + + @Test + void testDefineFlowWithFunction() { + Genkit genkit = new Genkit(); + + Flow flow = genkit.defineFlow("lengthFlow", String.class, Integer.class, String::length); + + assertNotNull(flow); + assertEquals("lengthFlow", flow.getName()); + } + + @Test + void testFlowExecution() { + Genkit genkit = new Genkit(); + + Flow flow = genkit.defineFlow("upperCaseFlow", String.class, String.class, + input -> input.toUpperCase()); + + ActionContext ctx = new ActionContext(null, null, null); + String result = flow.run(ctx, "hello"); + + assertEquals("HELLO", result); + } + + @Test + void testMultipleFlows() { + Genkit genkit = new Genkit(); + + Flow flow1 = genkit.defineFlow("flow1", String.class, String.class, s -> s); + Flow flow2 = genkit.defineFlow("flow2", Integer.class, String.class, + i -> String.valueOf(i)); + + assertNotNull(flow1); + assertNotNull(flow2); + assertEquals("flow1", flow1.getName()); + assertEquals("flow2", flow2.getName()); + } + + @Test + void testFlowWithComplexInput() { + Genkit genkit = new Genkit(); + + Flow flow = genkit.defineFlow("complexFlow", TestInput.class, TestOutput.class, + input -> new TestOutput(input.getName(), input.getValue() * 2)); + + assertNotNull(flow); + + ActionContext ctx = new ActionContext(null, null, null); + TestOutput result = flow.run(ctx, new TestInput("test", 21)); + + assertEquals("test", result.getName()); + assertEquals(42, result.getValue()); + } + + @Test + void testGetRegistry() { + Genkit genkit = new Genkit(); + assertNotNull(genkit.getRegistry()); + } + + /** + * Test input class. + */ + static class TestInput { + private String name; + private int value; + + public TestInput() { + } + + public TestInput(String name, int value) { + this.name = name; + this.value = value; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getValue() { + return value; + } + + public void setValue(int value) { + this.value = value; + } + } + + /** + * Test output class. + */ + static class TestOutput { + private String name; + private int value; + + public TestOutput() { + } + + public TestOutput(String name, int value) { + this.name = name; + this.value = value; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getValue() { + return value; + } + + public void setValue(int value) { + this.value = value; + } + } +} diff --git a/java/plugins/google-genai/pom.xml b/java/plugins/google-genai/pom.xml new file mode 100644 index 0000000000..f0279ca48a --- /dev/null +++ b/java/plugins/google-genai/pom.xml @@ -0,0 +1,70 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + genkit-plugin-google-genai + jar + Genkit Google GenAI Plugin + Google GenAI (Gemini) model plugin for Genkit using the official google-genai SDK + + + + com.google.genkit + genkit + ${project.version} + + + + + com.google.genai + google-genai + 1.32.0 + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GeminiEmbedder.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GeminiEmbedder.java new file mode 100644 index 0000000000..423625fa46 --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GeminiEmbedder.java @@ -0,0 +1,178 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import java.util.ArrayList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genai.Client; +import com.google.genai.types.ContentEmbedding; +import com.google.genai.types.EmbedContentConfig; +import com.google.genai.types.EmbedContentResponse; +import com.google.genai.types.HttpOptions; +import com.google.genkit.ai.EmbedRequest; +import com.google.genkit.ai.EmbedResponse; +import com.google.genkit.ai.Embedder; +import com.google.genkit.ai.EmbedderInfo; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Gemini embedder implementation using the official Google GenAI SDK. + */ +public class GeminiEmbedder extends Embedder { + + private static final Logger logger = LoggerFactory.getLogger(GeminiEmbedder.class); + + private final String modelName; + private final GoogleGenAIPluginOptions options; + private final Client client; + + /** + * Creates a new GeminiEmbedder. + * + * @param modelName + * the embedding model name (e.g., "text-embedding-004", + * "gemini-embedding-001") + * @param options + * the plugin options + */ + public GeminiEmbedder(String modelName, GoogleGenAIPluginOptions options) { + super("googleai/" + modelName, createEmbedderInfo(modelName), (ctx, req) -> { + throw new GenkitException("Handler not initialized"); + }); + this.modelName = modelName; + this.options = options; + this.client = createClient(); + } + + private Client createClient() { + Client.Builder builder = Client.builder(); + + if (options.isVertexAI()) { + builder.vertexAI(true); + if (options.getProject() != null) { + builder.project(options.getProject()); + } + if (options.getLocation() != null) { + builder.location(options.getLocation()); + } + if (options.getApiKey() != null) { + builder.apiKey(options.getApiKey()); + } + } else { + builder.apiKey(options.getApiKey()); + } + + HttpOptions httpOptions = options.toHttpOptions(); + if (httpOptions != null) { + builder.httpOptions(httpOptions); + } + + return builder.build(); + } + + private static EmbedderInfo createEmbedderInfo(String modelName) { + EmbedderInfo info = new EmbedderInfo(); + info.setLabel("Google AI " + modelName); + + // Default dimensions for Gemini embedding models + switch (modelName) { + case "text-embedding-004" : + case "text-embedding-005" : + info.setDimensions(768); + break; + case "gemini-embedding-001" : + info.setDimensions(768); + break; + case "text-multilingual-embedding-002" : + info.setDimensions(768); + break; + default : + info.setDimensions(768); // Default + } + + return info; + } + + @Override + public EmbedResponse run(ActionContext context, EmbedRequest request) { + if (request == null) { + throw new GenkitException("Embed request is required."); + } + if (request.getDocuments() == null || request.getDocuments().isEmpty()) { + throw new GenkitException("Embed request must contain at least one document."); + } + + try { + return callGeminiEmbed(request); + } catch (Exception e) { + throw new GenkitException("Gemini Embedding API call failed: " + e.getMessage(), e); + } + } + + private EmbedResponse callGeminiEmbed(EmbedRequest request) { + List embeddings = new ArrayList<>(); + + for (com.google.genkit.ai.Document doc : request.getDocuments()) { + String text = doc.text(); + if (text == null || text.isEmpty()) { + logger.warn("Document has empty text, skipping"); + continue; + } + + // Build embed config + EmbedContentConfig.Builder configBuilder = EmbedContentConfig.builder(); + + // Apply options from config + if (request.getOptions() != null) { + if (request.getOptions().containsKey("taskType")) { + String taskType = (String) request.getOptions().get("taskType"); + configBuilder.taskType(taskType); + } + if (request.getOptions().containsKey("title")) { + configBuilder.title((String) request.getOptions().get("title")); + } + if (request.getOptions().containsKey("outputDimensionality")) { + configBuilder.outputDimensionality( + ((Number) request.getOptions().get("outputDimensionality")).intValue()); + } + } + + EmbedContentResponse response = client.models.embedContent(modelName, text, configBuilder.build()); + + if (response.embeddings().isPresent() && !response.embeddings().get().isEmpty()) { + ContentEmbedding embedding = response.embeddings().get().get(0); + if (embedding.values().isPresent()) { + List values = embedding.values().get(); + float[] floatValues = new float[values.size()]; + for (int i = 0; i < values.size(); i++) { + floatValues[i] = values.get(i); + } + embeddings.add(new EmbedResponse.Embedding(floatValues)); + } + } + } + + return new EmbedResponse(embeddings); + } +} diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GeminiModel.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GeminiModel.java new file mode 100644 index 0000000000..f9f7396004 --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GeminiModel.java @@ -0,0 +1,590 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import java.util.*; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genai.Client; +import com.google.genai.ResponseStream; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.HarmBlockThreshold; +import com.google.genai.types.HarmCategory; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.SafetySetting; +import com.google.genai.types.Schema; +import com.google.genai.types.ThinkingConfig; +import com.google.genai.types.Type; +import com.google.genkit.ai.Media; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.Model; +import com.google.genkit.ai.ModelInfo; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.ModelResponseChunk; +import com.google.genkit.ai.Role; +import com.google.genkit.ai.ToolDefinition; +import com.google.genkit.ai.ToolRequest; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.ai.Usage; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Gemini model implementation using the official Google GenAI SDK. + */ +public class GeminiModel implements Model { + + private static final Logger logger = LoggerFactory.getLogger(GeminiModel.class); + + private final String modelName; + private final GoogleGenAIPluginOptions options; + private final Client client; + private final ObjectMapper objectMapper; + private final ModelInfo info; + + /** + * Creates a new GeminiModel. + * + * @param modelName + * the model name (e.g., "gemini-2.0-flash", "gemini-2.5-pro") + * @param options + * the plugin options + */ + public GeminiModel(String modelName, GoogleGenAIPluginOptions options) { + this.modelName = modelName; + this.options = options; + this.objectMapper = new ObjectMapper(); + this.client = createClient(); + this.info = createModelInfo(); + } + + private Client createClient() { + Client.Builder builder = Client.builder(); + + if (options.isVertexAI()) { + builder.vertexAI(true); + if (options.getProject() != null) { + builder.project(options.getProject()); + } + if (options.getLocation() != null) { + builder.location(options.getLocation()); + } + // Vertex AI can also use API key for express mode + if (options.getApiKey() != null) { + builder.apiKey(options.getApiKey()); + } + } else { + builder.apiKey(options.getApiKey()); + } + + // Apply HTTP options if configured + HttpOptions httpOptions = options.toHttpOptions(); + if (httpOptions != null) { + builder.httpOptions(httpOptions); + } + + return builder.build(); + } + + private ModelInfo createModelInfo() { + ModelInfo info = new ModelInfo(); + info.setLabel("Google AI " + modelName); + + ModelInfo.ModelCapabilities caps = new ModelInfo.ModelCapabilities(); + caps.setMultiturn(true); + caps.setMedia(true); // Gemini models support multimodal + caps.setTools(!isTTSModel()); // TTS models don't support tools + caps.setSystemRole(!isTTSModel()); + caps.setOutput(Set.of("text", "json")); + info.setSupports(caps); + + return info; + } + + private boolean isTTSModel() { + return modelName.endsWith("-tts"); + } + + private boolean isImageModel() { + return modelName.contains("-image"); + } + + @Override + public String getName() { + return "googleai/" + modelName; + } + + @Override + public ModelInfo getInfo() { + return info; + } + + @Override + public boolean supportsStreaming() { + return true; + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request) { + try { + return callGemini(request); + } catch (Exception e) { + throw new GenkitException("Gemini API call failed: " + e.getMessage(), e); + } + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request, Consumer streamCallback) { + if (streamCallback == null) { + return run(context, request); + } + try { + return callGeminiStreaming(request, streamCallback); + } catch (Exception e) { + throw new GenkitException("Gemini streaming API call failed: " + e.getMessage(), e); + } + } + + private ModelResponse callGemini(ModelRequest request) { + GenerateContentConfig config = buildConfig(request); + List contents = buildContents(request); + + GenerateContentResponse response = client.models.generateContent(modelName, contents, config); + + return parseResponse(response); + } + + private ModelResponse callGeminiStreaming(ModelRequest request, Consumer streamCallback) { + GenerateContentConfig config = buildConfig(request); + List contents = buildContents(request); + + StringBuilder fullContent = new StringBuilder(); + List toolRequests = new ArrayList<>(); + String finishReason = null; + + ResponseStream responseStream = client.models.generateContentStream(modelName, + contents, config); + + try { + for (GenerateContentResponse chunk : responseStream) { + String text = chunk.text(); + if (text != null && !text.isEmpty()) { + fullContent.append(text); + ModelResponseChunk responseChunk = ModelResponseChunk.text(text); + streamCallback.accept(responseChunk); + } + + // Handle tool calls in streaming + if (chunk.candidates().isPresent()) { + for (com.google.genai.types.Candidate candidate : chunk.candidates().get()) { + if (candidate.finishReason().isPresent()) { + finishReason = candidate.finishReason().get().toString(); + } + if (candidate.content().isPresent()) { + Content candidateContent = candidate.content().get(); + if (candidateContent.parts().isPresent()) { + for (com.google.genai.types.Part part : candidateContent.parts().get()) { + if (part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName(fc.name().orElse("")); + if (fc.args().isPresent()) { + toolRequest.setInput(fc.args().get()); + } + toolRequests.add(toolRequest); + } + } + } + } + } + } + } + } finally { + responseStream.close(); + } + + // Build final response + ModelResponse response = new ModelResponse(); + List candidates = new ArrayList<>(); + com.google.genkit.ai.Candidate candidate = new com.google.genkit.ai.Candidate(); + + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + if (fullContent.length() > 0) { + com.google.genkit.ai.Part textPart = new com.google.genkit.ai.Part(); + textPart.setText(fullContent.toString()); + parts.add(textPart); + } + + for (ToolRequest toolRequest : toolRequests) { + com.google.genkit.ai.Part toolPart = new com.google.genkit.ai.Part(); + toolPart.setToolRequest(toolRequest); + parts.add(toolPart); + } + + message.setContent(parts); + candidate.setMessage(message); + candidate.setFinishReason(mapFinishReason(finishReason)); + + candidates.add(candidate); + response.setCandidates(candidates); + + return response; + } + + private GenerateContentConfig buildConfig(ModelRequest request) { + GenerateContentConfig.Builder configBuilder = GenerateContentConfig.builder(); + + // System instruction + Message systemMessage = findSystemMessage(request); + if (systemMessage != null) { + Content systemContent = Content + .fromParts(com.google.genai.types.Part.fromText(getTextFromMessage(systemMessage))); + configBuilder.systemInstruction(systemContent); + } + + // Generation config from request + Map config = request.getConfig(); + if (config != null) { + if (config.containsKey("temperature")) { + configBuilder.temperature(((Number) config.get("temperature")).floatValue()); + } + if (config.containsKey("maxOutputTokens")) { + configBuilder.maxOutputTokens(((Number) config.get("maxOutputTokens")).intValue()); + } + if (config.containsKey("topP")) { + configBuilder.topP(((Number) config.get("topP")).floatValue()); + } + if (config.containsKey("topK")) { + configBuilder.topK(Float.valueOf(((Number) config.get("topK")).floatValue())); + } + if (config.containsKey("stopSequences")) { + @SuppressWarnings("unchecked") + List stopSequences = (List) config.get("stopSequences"); + configBuilder.stopSequences(stopSequences); + } + if (config.containsKey("candidateCount")) { + configBuilder.candidateCount(((Number) config.get("candidateCount")).intValue()); + } + + // Safety settings + if (config.containsKey("safetySettings")) { + @SuppressWarnings("unchecked") + List> safetySettingsConfig = (List>) config + .get("safetySettings"); + List safetySettings = new ArrayList<>(); + for (Map setting : safetySettingsConfig) { + safetySettings + .add(SafetySetting.builder().category(HarmCategory.Known.valueOf(setting.get("category"))) + .threshold(HarmBlockThreshold.Known.valueOf(setting.get("threshold"))).build()); + } + configBuilder.safetySettings(safetySettings); + } + + // Thinking config for Gemini 2.5+ + if (config.containsKey("thinkingConfig")) { + @SuppressWarnings("unchecked") + Map thinkingConfig = (Map) config.get("thinkingConfig"); + ThinkingConfig.Builder thinkingBuilder = ThinkingConfig.builder(); + if (thinkingConfig.containsKey("thinkingBudget")) { + thinkingBuilder.thinkingBudget(((Number) thinkingConfig.get("thinkingBudget")).intValue()); + } + configBuilder.thinkingConfig(thinkingBuilder); + } + + // JSON response schema + if (config.containsKey("responseSchema")) { + configBuilder.responseMimeType("application/json"); + @SuppressWarnings("unchecked") + Map schemaMap = (Map) config.get("responseSchema"); + configBuilder.responseSchema(convertToSchema(schemaMap)); + } + } + + // Tools + if (request.getTools() != null && !request.getTools().isEmpty()) { + List tools = new ArrayList<>(); + for (ToolDefinition toolDef : request.getTools()) { + FunctionDeclaration.Builder funcBuilder = FunctionDeclaration.builder().name(toolDef.getName()) + .description(toolDef.getDescription() != null ? toolDef.getDescription() : ""); + + if (toolDef.getInputSchema() != null) { + funcBuilder.parameters(convertToSchema(toolDef.getInputSchema())); + } + + tools.add(com.google.genai.types.Tool.builder().functionDeclarations(funcBuilder.build()).build()); + } + configBuilder.tools(tools); + } + + return configBuilder.build(); + } + + private Schema convertToSchema(Map inputSchema) { + Schema.Builder schemaBuilder = Schema.builder(); + + if (inputSchema.containsKey("type")) { + String type = (String) inputSchema.get("type"); + schemaBuilder.type(Type.Known.valueOf(type.toUpperCase())); + } + + if (inputSchema.containsKey("description")) { + schemaBuilder.description((String) inputSchema.get("description")); + } + + if (inputSchema.containsKey("properties")) { + @SuppressWarnings("unchecked") + Map properties = (Map) inputSchema.get("properties"); + Map schemaProperties = new HashMap<>(); + for (Map.Entry entry : properties.entrySet()) { + @SuppressWarnings("unchecked") + Map propSchema = (Map) entry.getValue(); + schemaProperties.put(entry.getKey(), convertToSchema(propSchema)); + } + schemaBuilder.properties(schemaProperties); + } + + if (inputSchema.containsKey("required")) { + @SuppressWarnings("unchecked") + List required = (List) inputSchema.get("required"); + schemaBuilder.required(required); + } + + if (inputSchema.containsKey("items")) { + @SuppressWarnings("unchecked") + Map items = (Map) inputSchema.get("items"); + schemaBuilder.items(convertToSchema(items)); + } + + if (inputSchema.containsKey("enum")) { + @SuppressWarnings("unchecked") + List enumValues = (List) inputSchema.get("enum"); + schemaBuilder.enum_(enumValues); + } + + return schemaBuilder.build(); + } + + private List buildContents(ModelRequest request) { + List contents = new ArrayList<>(); + + for (Message message : request.getMessages()) { + // Skip system messages (handled separately in config) + if (message.getRole() == Role.SYSTEM) { + continue; + } + + List parts = new ArrayList<>(); + + for (com.google.genkit.ai.Part part : message.getContent()) { + if (part.getText() != null) { + parts.add(com.google.genai.types.Part.fromText(part.getText())); + } else if (part.getMedia() != null) { + Media media = part.getMedia(); + String url = media.getUrl(); + String contentType = media.getContentType(); + + if (url.startsWith("data:")) { + // Inline data URL + String base64Data = url.substring(url.indexOf(",") + 1); + if (contentType == null) { + contentType = url.substring(url.indexOf(":") + 1, url.indexOf(";")); + } + parts.add(com.google.genai.types.Part.fromBytes(Base64.getDecoder().decode(base64Data), + contentType)); + } else if (url.startsWith("gs://") || url.startsWith("http://") || url.startsWith("https://")) { + // File URI + parts.add(com.google.genai.types.Part.fromUri(url, + contentType != null ? contentType : "application/octet-stream")); + } + } else if (part.getToolRequest() != null) { + // Tool request (function call from model) + ToolRequest toolReq = part.getToolRequest(); + FunctionCall.Builder fcBuilder = FunctionCall.builder().name(toolReq.getName()); + if (toolReq.getInput() != null) { + @SuppressWarnings("unchecked") + Map args = (Map) toolReq.getInput(); + fcBuilder.args(args); + } + parts.add(com.google.genai.types.Part.builder().functionCall(fcBuilder.build()).build()); + } else if (part.getToolResponse() != null) { + // Tool response + ToolResponse toolResp = part.getToolResponse(); + FunctionResponse.Builder frBuilder = FunctionResponse.builder().name(toolResp.getName()); + if (toolResp.getOutput() != null) { + @SuppressWarnings("unchecked") + Map response = toolResp.getOutput() instanceof Map + ? (Map) toolResp.getOutput() + : Map.of("result", toolResp.getOutput()); + frBuilder.response(response); + } + parts.add(com.google.genai.types.Part.builder().functionResponse(frBuilder.build()).build()); + } + } + + // Convert Genkit role to Gemini role + String geminiRole = toGeminiRole(message.getRole()); + Content content = Content.builder().role(geminiRole).parts(parts).build(); + contents.add(content); + } + + return contents; + } + + /** + * Converts Genkit role to Gemini role. Gemini only supports "user" and "model" + * roles. TOOL role maps to "user" as it represents the user providing function + * results. + */ + private String toGeminiRole(Role role) { + switch (role) { + case USER : + return "user"; + case MODEL : + return "model"; + case TOOL : + // Tool responses are sent as user role in Gemini + return "user"; + default : + return "user"; + } + } + + private Message findSystemMessage(ModelRequest request) { + for (Message message : request.getMessages()) { + if (message.getRole() == Role.SYSTEM) { + return message; + } + } + return null; + } + + private String getTextFromMessage(Message message) { + StringBuilder sb = new StringBuilder(); + for (com.google.genkit.ai.Part part : message.getContent()) { + if (part.getText() != null) { + sb.append(part.getText()); + } + } + return sb.toString(); + } + + private ModelResponse parseResponse(GenerateContentResponse response) { + ModelResponse modelResponse = new ModelResponse(); + List candidates = new ArrayList<>(); + + if (response.candidates().isPresent()) { + for (com.google.genai.types.Candidate candidate : response.candidates().get()) { + com.google.genkit.ai.Candidate genkitCandidate = new com.google.genkit.ai.Candidate(); + + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + if (candidate.content().isPresent()) { + Content content = candidate.content().get(); + if (content.parts().isPresent()) { + for (com.google.genai.types.Part part : content.parts().get()) { + // Text content + if (part.text().isPresent()) { + com.google.genkit.ai.Part textPart = new com.google.genkit.ai.Part(); + textPart.setText(part.text().get()); + parts.add(textPart); + } + + // Function call + if (part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + com.google.genkit.ai.Part toolPart = new com.google.genkit.ai.Part(); + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setName(fc.name().orElse("")); + if (fc.args().isPresent()) { + toolRequest.setInput(fc.args().get()); + } + toolPart.setToolRequest(toolRequest); + parts.add(toolPart); + } + } + } + } + + message.setContent(parts); + genkitCandidate.setMessage(message); + + // Map finish reason + if (candidate.finishReason().isPresent()) { + genkitCandidate.setFinishReason(mapFinishReason(candidate.finishReason().get().toString())); + } + + candidates.add(genkitCandidate); + } + } + + modelResponse.setCandidates(candidates); + + // Usage metadata + if (response.usageMetadata().isPresent()) { + com.google.genai.types.GenerateContentResponseUsageMetadata usage = response.usageMetadata().get(); + Usage genkitUsage = new Usage(); + if (usage.promptTokenCount().isPresent()) { + genkitUsage.setInputTokens(usage.promptTokenCount().get()); + } + if (usage.candidatesTokenCount().isPresent()) { + genkitUsage.setOutputTokens(usage.candidatesTokenCount().get()); + } + if (usage.totalTokenCount().isPresent()) { + genkitUsage.setTotalTokens(usage.totalTokenCount().get()); + } + modelResponse.setUsage(genkitUsage); + } + + return modelResponse; + } + + private com.google.genkit.ai.FinishReason mapFinishReason(String reason) { + if (reason == null) { + return com.google.genkit.ai.FinishReason.OTHER; + } + switch (reason.toUpperCase()) { + case "STOP" : + return com.google.genkit.ai.FinishReason.STOP; + case "MAX_TOKENS" : + case "LENGTH" : + return com.google.genkit.ai.FinishReason.LENGTH; + case "SAFETY" : + return com.google.genkit.ai.FinishReason.BLOCKED; + case "RECITATION" : + return com.google.genkit.ai.FinishReason.BLOCKED; + default : + return com.google.genkit.ai.FinishReason.OTHER; + } + } +} diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GoogleGenAIPlugin.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GoogleGenAIPlugin.java new file mode 100644 index 0000000000..98d4599089 --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GoogleGenAIPlugin.java @@ -0,0 +1,239 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.core.Action; +import com.google.genkit.core.Plugin; + +/** + * Google GenAI plugin for Genkit using the official Google GenAI SDK. + * + *

+ * This plugin provides access to Google's Gemini models for: + *

    + *
  • Text generation (Gemini 2.0, 2.5, 3.0 series)
  • + *
  • Multimodal content (images, video, audio)
  • + *
  • Embeddings (text-embedding-004, gemini-embedding-001)
  • + *
  • Function calling/tools
  • + *
+ * + *

+ * Supports both: + *

    + *
  • Gemini Developer API (with API key)
  • + *
  • Vertex AI API (with GCP credentials)
  • + *
+ * + *

+ * Example usage: + * + *

{@code
+ * // Using Gemini Developer API with API key
+ * Genkit genkit = Genkit.builder().addPlugin(GoogleGenAIPlugin.create()) // Uses GOOGLE_API_KEY env var
+ * 		.build();
+ *
+ * // Using Vertex AI
+ * Genkit genkit = Genkit.builder().addPlugin(GoogleGenAIPlugin.create(
+ * 		GoogleGenAIPluginOptions.builder().vertexAI(true).project("my-project").location("us-central1").build()))
+ * 		.build();
+ *
+ * // Generate content
+ * GenerateResponse response = genkit
+ * 		.generate(GenerateOptions.builder().model("googleai/gemini-2.0-flash").prompt("Hello, world!").build());
+ * }
+ */ +public class GoogleGenAIPlugin implements Plugin { + + private static final Logger logger = LoggerFactory.getLogger(GoogleGenAIPlugin.class); + + /** + * Supported Gemini models for text/multimodal generation. + */ + public static final List SUPPORTED_MODELS = Arrays.asList( + // Gemini 3.0 series + "gemini-3-pro-preview", + // Gemini 2.5 series + "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite", + // Gemini 2.0 series + "gemini-2.0-flash", "gemini-2.0-flash-lite", + // Gemini 1.5 series (still widely used) + "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b", + // Gemma models + "gemma-3-12b-it", "gemma-3-27b-it", "gemma-3-4b-it", "gemma-3-1b-it", "gemma-3n-e4b-it"); + + /** + * Supported embedding models. + */ + public static final List SUPPORTED_EMBEDDING_MODELS = Arrays.asList("text-embedding-004", + "text-embedding-005", "gemini-embedding-001", "text-multilingual-embedding-002"); + + /** + * Supported image generation models (Imagen). Note: imagen-4.0-* models are + * supported by the Gemini Developer API. imagen-3.0-* models require Vertex AI. + */ + public static final List SUPPORTED_IMAGE_MODELS = Arrays.asList("imagen-4.0-generate-001", + "imagen-4.0-fast-generate-001"); + + /** + * Supported TTS models. + */ + public static final List SUPPORTED_TTS_MODELS = Arrays.asList("gemini-2.5-flash-preview-tts", + "gemini-2.5-pro-preview-tts"); + + /** + * Supported video generation models (Veo). + */ + public static final List SUPPORTED_VEO_MODELS = Arrays.asList("veo-2.0-generate-001", + "veo-3.0-generate-001", "veo-3.0-fast-generate-001", "veo-3.1-generate-preview", + "veo-3.1-fast-generate-preview"); + + private final GoogleGenAIPluginOptions options; + + /** + * Creates a GoogleGenAIPlugin with default options. Reads API key from + * GOOGLE_API_KEY or GEMINI_API_KEY environment variable. + */ + public GoogleGenAIPlugin() { + this(GoogleGenAIPluginOptions.builder().build()); + } + + /** + * Creates a GoogleGenAIPlugin with the specified options. + * + * @param options + * the plugin options + */ + public GoogleGenAIPlugin(GoogleGenAIPluginOptions options) { + this.options = options; + } + + /** + * Creates a GoogleGenAIPlugin with the specified API key. + * + * @param apiKey + * the Google API key + * @return a new GoogleGenAIPlugin + */ + public static GoogleGenAIPlugin create(String apiKey) { + return new GoogleGenAIPlugin(GoogleGenAIPluginOptions.builder().apiKey(apiKey).build()); + } + + /** + * Creates a GoogleGenAIPlugin using environment variables for configuration. + * + * @return a new GoogleGenAIPlugin + */ + public static GoogleGenAIPlugin create() { + return new GoogleGenAIPlugin(); + } + + /** + * Creates a GoogleGenAIPlugin with the specified options. + * + * @param options + * the plugin options + * @return a new GoogleGenAIPlugin + */ + public static GoogleGenAIPlugin create(GoogleGenAIPluginOptions options) { + return new GoogleGenAIPlugin(options); + } + + /** + * Creates a GoogleGenAIPlugin configured for Vertex AI. + * + * @param project + * the GCP project ID + * @param location + * the GCP location + * @return a new GoogleGenAIPlugin configured for Vertex AI + */ + public static GoogleGenAIPlugin vertexAI(String project, String location) { + return new GoogleGenAIPlugin( + GoogleGenAIPluginOptions.builder().vertexAI(true).project(project).location(location).build()); + } + + @Override + public String getName() { + return "googleai"; + } + + @Override + public List> init() { + List> actions = new ArrayList<>(); + + // Register chat/generation models + for (String modelName : SUPPORTED_MODELS) { + GeminiModel model = new GeminiModel(modelName, options); + actions.add(model); + logger.debug("Created Gemini model: {}", modelName); + } + + // Register embedding models + for (String modelName : SUPPORTED_EMBEDDING_MODELS) { + GeminiEmbedder embedder = new GeminiEmbedder(modelName, options); + actions.add(embedder); + logger.debug("Created Gemini embedder: {}", modelName); + } + + // Register image generation (Imagen) models + for (String modelName : SUPPORTED_IMAGE_MODELS) { + ImagenModel model = new ImagenModel(modelName, options); + actions.add(model); + logger.debug("Created Imagen model: {}", modelName); + } + + // Register TTS models + for (String modelName : SUPPORTED_TTS_MODELS) { + TtsModel model = new TtsModel(modelName, options); + actions.add(model); + logger.debug("Created TTS model: {}", modelName); + } + + // Register video generation (Veo) models + for (String modelName : SUPPORTED_VEO_MODELS) { + VeoModel model = new VeoModel(modelName, options); + actions.add(model); + logger.debug("Created Veo model: {}", modelName); + } + + String backend = options.isVertexAI() ? "Vertex AI" : "Gemini Developer API"; + logger.info( + "Google GenAI plugin initialized with {} models, {} embedders, {} image models, {} TTS models, and {} video models using {}", + SUPPORTED_MODELS.size(), SUPPORTED_EMBEDDING_MODELS.size(), SUPPORTED_IMAGE_MODELS.size(), + SUPPORTED_TTS_MODELS.size(), SUPPORTED_VEO_MODELS.size(), backend); + + return actions; + } + + /** + * Gets the plugin options. + * + * @return the options + */ + public GoogleGenAIPluginOptions getOptions() { + return options; + } +} diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GoogleGenAIPluginOptions.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GoogleGenAIPluginOptions.java new file mode 100644 index 0000000000..43579ec159 --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/GoogleGenAIPluginOptions.java @@ -0,0 +1,283 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import com.google.genai.types.HttpOptions; + +/** + * Options for configuring the Google GenAI plugin. + * + *

+ * The plugin can be configured to use either: + *

    + *
  • Gemini Developer API (default): Set the API key
  • + *
  • Vertex AI API: Set project, location, and enable vertexAI
  • + *
+ */ +public class GoogleGenAIPluginOptions { + + private final String apiKey; + private final String project; + private final String location; + private final boolean vertexAI; + private final String apiVersion; + private final String baseUrl; + private final int timeout; + + private GoogleGenAIPluginOptions(Builder builder) { + this.apiKey = builder.apiKey; + this.project = builder.project; + this.location = builder.location; + this.vertexAI = builder.vertexAI; + this.apiVersion = builder.apiVersion; + this.baseUrl = builder.baseUrl; + this.timeout = builder.timeout; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the API key. + * + * @return the API key + */ + public String getApiKey() { + return apiKey; + } + + /** + * Gets the Google Cloud project ID (for Vertex AI). + * + * @return the project ID + */ + public String getProject() { + return project; + } + + /** + * Gets the Google Cloud location (for Vertex AI). + * + * @return the location + */ + public String getLocation() { + return location; + } + + /** + * Returns whether to use Vertex AI backend. + * + * @return true if using Vertex AI, false for Gemini Developer API + */ + public boolean isVertexAI() { + return vertexAI; + } + + /** + * Gets the API version. + * + * @return the API version + */ + public String getApiVersion() { + return apiVersion; + } + + /** + * Gets the base URL override. + * + * @return the base URL + */ + public String getBaseUrl() { + return baseUrl; + } + + /** + * Gets the request timeout in milliseconds. + * + * @return the timeout in milliseconds + */ + public int getTimeout() { + return timeout; + } + + /** + * Converts these options to HttpOptions for the Google GenAI SDK. + * + * @return HttpOptions + */ + public HttpOptions toHttpOptions() { + HttpOptions.Builder builder = HttpOptions.builder(); + if (apiVersion != null) { + builder.apiVersion(apiVersion); + } + if (baseUrl != null) { + builder.baseUrl(baseUrl); + } + if (timeout > 0) { + builder.timeout(timeout); + } + return builder.build(); + } + + /** + * Builder for GoogleGenAIPluginOptions. + */ + public static class Builder { + private String apiKey = getApiKeyFromEnv(); + private String project = getProjectFromEnv(); + private String location = getLocationFromEnv(); + private boolean vertexAI = getVertexAIFromEnv(); + private String apiVersion; + private String baseUrl; + private int timeout = 600000; // 10 minutes default (in milliseconds) + + private static String getApiKeyFromEnv() { + // GOOGLE_API_KEY takes precedence over GEMINI_API_KEY (legacy) + String apiKey = System.getenv("GOOGLE_API_KEY"); + if (apiKey == null || apiKey.isEmpty()) { + apiKey = System.getenv("GEMINI_API_KEY"); + } + return apiKey; + } + + private static String getProjectFromEnv() { + return System.getenv("GOOGLE_CLOUD_PROJECT"); + } + + private static String getLocationFromEnv() { + String location = System.getenv("GOOGLE_CLOUD_LOCATION"); + return location != null ? location : "us-central1"; + } + + private static boolean getVertexAIFromEnv() { + String useVertexAI = System.getenv("GOOGLE_GENAI_USE_VERTEXAI"); + return "true".equalsIgnoreCase(useVertexAI); + } + + /** + * Sets the API key for Gemini Developer API. + * + * @param apiKey + * the API key + * @return this builder + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets the Google Cloud project ID for Vertex AI. + * + * @param project + * the project ID + * @return this builder + */ + public Builder project(String project) { + this.project = project; + return this; + } + + /** + * Sets the Google Cloud location for Vertex AI. + * + * @param location + * the location + * @return this builder + */ + public Builder location(String location) { + this.location = location; + return this; + } + + /** + * Sets whether to use Vertex AI backend. + * + * @param vertexAI + * true to use Vertex AI, false for Gemini Developer API + * @return this builder + */ + public Builder vertexAI(boolean vertexAI) { + this.vertexAI = vertexAI; + return this; + } + + /** + * Sets the API version. + * + * @param apiVersion + * the API version (e.g., "v1", "v1beta") + * @return this builder + */ + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + + /** + * Sets the base URL override. + * + * @param baseUrl + * the base URL + * @return this builder + */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the request timeout in milliseconds. + * + * @param timeout + * the timeout in milliseconds + * @return this builder + */ + public Builder timeout(int timeout) { + this.timeout = timeout; + return this; + } + + /** + * Builds the GoogleGenAIPluginOptions. + * + * @return the built options + */ + public GoogleGenAIPluginOptions build() { + // Validate configuration + if (!vertexAI && (apiKey == null || apiKey.isEmpty())) { + throw new IllegalStateException("Google API key is required for Gemini Developer API. " + + "Set GOOGLE_API_KEY or GEMINI_API_KEY environment variable, " + + "or provide it in options, or enable vertexAI mode."); + } + if (vertexAI && (project == null || project.isEmpty()) && (apiKey == null || apiKey.isEmpty())) { + throw new IllegalStateException( + "For Vertex AI, either set GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION " + + "environment variables, or provide an API key for express mode."); + } + return new GoogleGenAIPluginOptions(this); + } + } +} diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/ImagenModel.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/ImagenModel.java new file mode 100644 index 0000000000..2de2f8530b --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/ImagenModel.java @@ -0,0 +1,343 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genai.Client; +import com.google.genai.types.GenerateImagesConfig; +import com.google.genai.types.GenerateImagesResponse; +import com.google.genai.types.GeneratedImage; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.Image; +import com.google.genai.types.PersonGeneration; +import com.google.genkit.ai.FinishReason; +import com.google.genkit.ai.Media; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.Model; +import com.google.genkit.ai.ModelInfo; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.ModelResponseChunk; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Role; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Imagen model implementation for image generation using the official Google + * GenAI SDK. + * + *

+ * Imagen is Google's text-to-image model that generates high-quality images + * from text prompts. + * + *

+ * Configuration options (passed via request.config): + *

    + *
  • numberOfImages - Number of images to generate (1-4)
  • + *
  • aspectRatio - Aspect ratio: "1:1", "3:4", "4:3", "9:16", "16:9"
  • + *
  • personGeneration - Control people generation: "dont_allow", + * "allow_adult", "allow_all"
  • + *
  • negativePrompt - Description of what to avoid in the generated + * images
  • + *
  • outputMimeType - MIME type of output: "image/png" or "image/jpeg"
  • + *
+ */ +public class ImagenModel implements Model { + + private static final Logger logger = LoggerFactory.getLogger(ImagenModel.class); + + private final String modelName; + private final GoogleGenAIPluginOptions options; + private final Client client; + private final ModelInfo info; + + /** + * Creates a new ImagenModel. + * + * @param modelName + * the model name (e.g., "imagen-3.0-generate-002") + * @param options + * the plugin options + */ + public ImagenModel(String modelName, GoogleGenAIPluginOptions options) { + this.modelName = modelName; + this.options = options; + this.client = createClient(); + this.info = createModelInfo(); + } + + private Client createClient() { + Client.Builder builder = Client.builder(); + + if (options.isVertexAI()) { + builder.vertexAI(true); + if (options.getProject() != null) { + builder.project(options.getProject()); + } + if (options.getLocation() != null) { + builder.location(options.getLocation()); + } + if (options.getApiKey() != null) { + builder.apiKey(options.getApiKey()); + } + } else { + builder.apiKey(options.getApiKey()); + } + + HttpOptions httpOptions = options.toHttpOptions(); + if (httpOptions != null) { + builder.httpOptions(httpOptions); + } + + return builder.build(); + } + + private ModelInfo createModelInfo() { + ModelInfo info = new ModelInfo(); + info.setLabel("Google AI " + modelName); + + ModelInfo.ModelCapabilities caps = new ModelInfo.ModelCapabilities(); + caps.setMultiturn(false); // Image generation is single-turn + caps.setMedia(true); // Input can include reference images + caps.setTools(false); // No tool support + caps.setSystemRole(false); // No system role + caps.setOutput(Set.of("media")); // Outputs media (images) + info.setSupports(caps); + + return info; + } + + @Override + public String getName() { + return "googleai/" + modelName; + } + + @Override + public ModelInfo getInfo() { + return info; + } + + @Override + public boolean supportsStreaming() { + return false; // Image generation doesn't support streaming + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request) { + try { + return generateImages(request); + } catch (Exception e) { + throw new GenkitException("Imagen API call failed: " + e.getMessage(), e); + } + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request, Consumer streamCallback) { + // Image generation doesn't support streaming, just call the regular method + return run(context, request); + } + + private ModelResponse generateImages(ModelRequest request) { + String prompt = extractPrompt(request); + if (prompt == null || prompt.isEmpty()) { + throw new GenkitException("Prompt is required for image generation"); + } + + GenerateImagesConfig config = buildConfig(request); + + logger.debug("Generating images with model {} and prompt: {}", modelName, prompt); + + GenerateImagesResponse response = client.models.generateImages(modelName, prompt, config); + + return parseResponse(response); + } + + private String extractPrompt(ModelRequest request) { + if (request.getMessages() == null || request.getMessages().isEmpty()) { + return null; + } + + // Get the last user message + for (int i = request.getMessages().size() - 1; i >= 0; i--) { + Message msg = request.getMessages().get(i); + if (msg.getRole() == Role.USER) { + return msg.getText(); + } + } + + return null; + } + + private GenerateImagesConfig buildConfig(ModelRequest request) { + GenerateImagesConfig.Builder configBuilder = GenerateImagesConfig.builder(); + + Map config = request.getConfig(); + if (config == null) { + // Default config + configBuilder.numberOfImages(1); + configBuilder.outputMimeType("image/png"); + return configBuilder.build(); + } + + // Number of images + if (config.containsKey("numberOfImages")) { + configBuilder.numberOfImages(((Number) config.get("numberOfImages")).intValue()); + } else { + configBuilder.numberOfImages(1); + } + + // Aspect ratio + if (config.containsKey("aspectRatio")) { + configBuilder.aspectRatio((String) config.get("aspectRatio")); + } + + // Person generation + if (config.containsKey("personGeneration")) { + String personGen = (String) config.get("personGeneration"); + switch (personGen.toLowerCase()) { + case "dont_allow" : + case "allow_none" : + configBuilder.personGeneration(PersonGeneration.Known.DONT_ALLOW); + break; + case "allow_adult" : + configBuilder.personGeneration(PersonGeneration.Known.ALLOW_ADULT); + break; + case "allow_all" : + configBuilder.personGeneration(PersonGeneration.Known.ALLOW_ALL); + break; + default : + configBuilder.personGeneration(personGen); + } + } + + // Negative prompt + if (config.containsKey("negativePrompt")) { + configBuilder.negativePrompt((String) config.get("negativePrompt")); + } + + // Output MIME type + if (config.containsKey("outputMimeType")) { + configBuilder.outputMimeType((String) config.get("outputMimeType")); + } else { + configBuilder.outputMimeType("image/png"); + } + + // Safety filter level + if (config.containsKey("safetyFilterLevel")) { + configBuilder.safetyFilterLevel((String) config.get("safetyFilterLevel")); + } + + // Include safety attributes + if (config.containsKey("includeSafetyAttributes")) { + configBuilder.includeSafetyAttributes((Boolean) config.get("includeSafetyAttributes")); + } + + // Guidance scale + if (config.containsKey("guidanceScale")) { + configBuilder.guidanceScale(((Number) config.get("guidanceScale")).floatValue()); + } + + // Seed + if (config.containsKey("seed")) { + configBuilder.seed(((Number) config.get("seed")).intValue()); + } + + return configBuilder.build(); + } + + private ModelResponse parseResponse(GenerateImagesResponse response) { + ModelResponse modelResponse = new ModelResponse(); + List candidates = new ArrayList<>(); + com.google.genkit.ai.Candidate candidate = new com.google.genkit.ai.Candidate(); + + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + // Get generated images + if (response.generatedImages().isPresent()) { + List generatedImages = response.generatedImages().get(); + + for (GeneratedImage genImage : generatedImages) { + if (genImage.image().isPresent()) { + Image image = genImage.image().get(); + Part imagePart = createImagePart(image); + if (imagePart != null) { + parts.add(imagePart); + } + } + } + + logger.debug("Generated {} images", generatedImages.size()); + } else { + logger.warn("No images generated in response"); + } + + message.setContent(parts); + candidate.setMessage(message); + candidate.setFinishReason(FinishReason.STOP); + candidate.setIndex(0); + candidates.add(candidate); + + modelResponse.setCandidates(candidates); + modelResponse.setFinishReason(FinishReason.STOP); + + return modelResponse; + } + + private Part createImagePart(Image image) { + Part part = new Part(); + + // Image can have either imageBytes or gcsUri + if (image.imageBytes().isPresent()) { + byte[] imageBytes = image.imageBytes().get(); + String base64 = Base64.getEncoder().encodeToString(imageBytes); + String mimeType = image.mimeType().orElse("image/png"); + + Media media = new Media(); + media.setUrl("data:" + mimeType + ";base64," + base64); + media.setContentType(mimeType); + part.setMedia(media); + + return part; + } else if (image.gcsUri().isPresent()) { + String gcsUri = image.gcsUri().get(); + String mimeType = image.mimeType().orElse("image/png"); + + Media media = new Media(); + media.setUrl(gcsUri); + media.setContentType(mimeType); + part.setMedia(media); + + return part; + } + + return null; + } +} diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/TtsModel.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/TtsModel.java new file mode 100644 index 0000000000..00f863812e --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/TtsModel.java @@ -0,0 +1,321 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genai.Client; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.PrebuiltVoiceConfig; +import com.google.genai.types.SpeechConfig; +import com.google.genai.types.VoiceConfig; +import com.google.genkit.ai.Candidate; +import com.google.genkit.ai.FinishReason; +import com.google.genkit.ai.Media; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.Model; +import com.google.genkit.ai.ModelInfo; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.ModelResponseChunk; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Role; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Text-to-Speech model using Gemini TTS models. + * + *

+ * This model uses Gemini's TTS capabilities via responseModalities=AUDIO and + * speechConfig for voice configuration. + * + *

+ * Supported models: + *

    + *
  • gemini-2.5-flash-preview-tts
  • + *
  • gemini-2.5-pro-preview-tts
  • + *
+ * + *

+ * Configuration options (via custom config): + *

    + *
  • voiceName - Name of the voice to use (e.g., "Zephyr", "Puck", "Charon", + * "Kore", etc.)
  • + *
+ * + *

+ * Available voices: Zephyr, Puck, Charon, Kore, Fenrir, Leda, Orus, Aoede, + * Callirrhoe, Autonoe, Enceladus, Iapetus, Umbriel, Algieba, Despina, Erinome, + * Algenib, Rasalgethi, Laomedeia, Achernar, Alnilam, Schedar, Gacrux, + * Pulcherrima, Achird, Zubenelgenubi, Vindemiatrix, Sadachbia, Sadaltager, + * Sulafat + */ +public class TtsModel implements Model { + + private static final Logger logger = LoggerFactory.getLogger(TtsModel.class); + + private static final Set SUPPORTED_TTS_MODELS = Set.of("gemini-2.5-flash-preview-tts", + "gemini-2.5-pro-preview-tts"); + + private final String modelName; + private final GoogleGenAIPluginOptions options; + private final Client client; + private final ModelInfo info; + + /** + * Creates a TtsModel for the specified model. + * + * @param modelName + * the TTS model name + * @param options + * the plugin options + */ + public TtsModel(String modelName, GoogleGenAIPluginOptions options) { + this.modelName = modelName; + this.options = options; + this.client = createClient(); + this.info = createModelInfo(); + logger.debug("Initialized TTS model: {}", modelName); + } + + private Client createClient() { + Client.Builder builder = Client.builder(); + + if (options.isVertexAI()) { + builder.vertexAI(true); + if (options.getProject() != null) { + builder.project(options.getProject()); + } + if (options.getLocation() != null) { + builder.location(options.getLocation()); + } + if (options.getApiKey() != null) { + builder.apiKey(options.getApiKey()); + } + } else { + builder.apiKey(options.getApiKey()); + } + + HttpOptions httpOptions = options.toHttpOptions(); + if (httpOptions != null) { + builder.httpOptions(httpOptions); + } + + return builder.build(); + } + + private ModelInfo createModelInfo() { + ModelInfo info = new ModelInfo(); + info.setLabel("Google AI TTS " + modelName); + + ModelInfo.ModelCapabilities caps = new ModelInfo.ModelCapabilities(); + caps.setMultiturn(false); + caps.setMedia(false); + caps.setTools(false); + caps.setSystemRole(false); + caps.setOutput(Set.of("media")); + info.setSupports(caps); + + return info; + } + + @Override + public String getName() { + return "googleai/" + modelName; + } + + @Override + public ModelInfo getInfo() { + return info; + } + + @Override + public boolean supportsStreaming() { + return false; // TTS doesn't support streaming in the same way + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request) { + try { + return callTts(request); + } catch (Exception e) { + throw new GenkitException("TTS API call failed: " + e.getMessage(), e); + } + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request, Consumer streamCallback) { + // TTS doesn't support streaming - just return final audio + return run(context, request); + } + + private ModelResponse callTts(ModelRequest request) throws Exception { + String prompt = extractPrompt(request); + GenerateContentConfig config = buildConfig(request); + + logger.debug("Calling TTS model {} with prompt length: {}", modelName, prompt.length()); + + GenerateContentResponse response = client.models.generateContent(modelName, prompt, config); + + return parseResponse(response); + } + + private String extractPrompt(ModelRequest request) { + StringBuilder prompt = new StringBuilder(); + + if (request.getMessages() != null) { + for (Message message : request.getMessages()) { + if (message.getContent() != null) { + for (Part part : message.getContent()) { + if (part.getText() != null) { + if (prompt.length() > 0) { + prompt.append("\n"); + } + prompt.append(part.getText()); + } + } + } + } + } + + return prompt.toString(); + } + + @SuppressWarnings("unchecked") + private GenerateContentConfig buildConfig(ModelRequest request) { + GenerateContentConfig.Builder configBuilder = GenerateContentConfig.builder(); + + // Set response modalities to AUDIO + configBuilder.responseModalities("AUDIO"); + + // Build speech config + SpeechConfig.Builder speechConfigBuilder = SpeechConfig.builder(); + + // Check for voice configuration in custom config + Map config = request.getConfig(); + if (config != null) { + String voiceName = null; + if (config.containsKey("voiceName")) { + voiceName = (String) config.get("voiceName"); + } + + if (voiceName != null) { + VoiceConfig voiceConfig = VoiceConfig.builder() + .prebuiltVoiceConfig(PrebuiltVoiceConfig.builder().voiceName(voiceName).build()).build(); + speechConfigBuilder.voiceConfig(voiceConfig); + } + } + + configBuilder.speechConfig(speechConfigBuilder.build()); + + return configBuilder.build(); + } + + private ModelResponse parseResponse(GenerateContentResponse response) { + ModelResponse modelResponse = new ModelResponse(); + List candidates = new ArrayList<>(); + Candidate candidate = new Candidate(); + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + // Extract audio parts from response + if (response.candidates().isPresent()) { + for (com.google.genai.types.Candidate genaiCandidate : response.candidates().get()) { + if (genaiCandidate.content().isPresent()) { + Content content = genaiCandidate.content().get(); + if (content.parts().isPresent()) { + for (com.google.genai.types.Part genaiPart : content.parts().get()) { + // Check for inline audio data + if (genaiPart.inlineData().isPresent()) { + com.google.genai.types.Blob blob = genaiPart.inlineData().get(); + Part audioPart = createAudioPart(blob); + if (audioPart != null) { + parts.add(audioPart); + } + } + } + } + } + } + } + + if (!parts.isEmpty()) { + logger.debug("Generated {} audio part(s)", parts.size()); + } else { + logger.warn("No audio generated in response"); + } + + message.setContent(parts); + candidate.setMessage(message); + candidate.setFinishReason(FinishReason.STOP); + candidate.setIndex(0); + candidates.add(candidate); + + modelResponse.setCandidates(candidates); + modelResponse.setFinishReason(FinishReason.STOP); + + return modelResponse; + } + + private Part createAudioPart(com.google.genai.types.Blob blob) { + Part part = new Part(); + + if (blob.data().isPresent()) { + byte[] audioBytes = blob.data().get(); + String base64 = Base64.getEncoder().encodeToString(audioBytes); + String mimeType = blob.mimeType().orElse("audio/wav"); + + Media media = new Media(); + media.setContentType(mimeType); + // Create data URL + media.setUrl("data:" + mimeType + ";base64," + base64); + part.setMedia(media); + + logger.debug("Created audio part with {} bytes, mime type: {}", audioBytes.length, mimeType); + return part; + } + + return null; + } + + /** + * Checks if the given model name is a supported TTS model. + * + * @param modelName + * the model name to check + * @return true if the model is a TTS model + */ + public static boolean isTtsModel(String modelName) { + return SUPPORTED_TTS_MODELS.contains(modelName) + || (modelName.startsWith("gemini-") && modelName.endsWith("-tts")); + } +} diff --git a/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/VeoModel.java b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/VeoModel.java new file mode 100644 index 0000000000..46aaec3310 --- /dev/null +++ b/java/plugins/google-genai/src/main/java/com/google/genkit/plugins/googlegenai/VeoModel.java @@ -0,0 +1,527 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.googlegenai; + +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genai.Client; +import com.google.genai.types.GenerateVideosConfig; +import com.google.genai.types.GenerateVideosOperation; +import com.google.genai.types.GenerateVideosResponse; +import com.google.genai.types.GenerateVideosSource; +import com.google.genai.types.GeneratedVideo; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.Video; +import com.google.genkit.ai.Candidate; +import com.google.genkit.ai.FinishReason; +import com.google.genkit.ai.Media; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.Model; +import com.google.genkit.ai.ModelInfo; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.ModelResponseChunk; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Role; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Video generation model using Google Veo. + * + *

+ * This model generates videos using Google's Veo video generation models. Veo + * supports both text-to-video and image-to-video generation. + * + *

+ * Supported models: + *

    + *
  • veo-2.0-generate-001
  • + *
  • veo-3.0-generate-001
  • + *
  • veo-3.0-fast-generate-001
  • + *
  • veo-3.1-generate-preview
  • + *
  • veo-3.1-fast-generate-preview
  • + *
+ * + *

+ * Configuration options (via custom config): + *

    + *
  • numberOfVideos - Number of videos to generate (1-4, default: 1)
  • + *
  • durationSeconds - Video duration (5-8 seconds, default: 5)
  • + *
  • aspectRatio - Aspect ratio (16:9 or 9:16, default: 16:9)
  • + *
  • personGeneration - Allow person generation (allowed/disallowed)
  • + *
  • negativePrompt - Negative prompt for generation
  • + *
  • enhancePrompt - Enable prompt enhancement (default: true)
  • + *
  • seed - Random seed for reproducibility
  • + *
  • outputGcsUri - GCS URI for output
  • + *
  • generateAudio - Generate audio with video (veo-3.0+ only)
  • + *
  • pollIntervalMs - Polling interval in ms (default: 5000)
  • + *
  • timeoutMs - Operation timeout in ms (default: 300000)
  • + *
+ */ +public class VeoModel implements Model { + + private static final Logger logger = LoggerFactory.getLogger(VeoModel.class); + + private static final Set SUPPORTED_VEO_MODELS = Set.of("veo-2.0-generate-001", "veo-3.0-generate-001", + "veo-3.0-fast-generate-001", "veo-3.1-generate-preview", "veo-3.1-fast-generate-preview"); + + private static final long DEFAULT_POLL_INTERVAL_MS = 5000; + private static final long DEFAULT_TIMEOUT_MS = 300000; // 5 minutes + + private final String modelName; + private final GoogleGenAIPluginOptions options; + private final Client client; + private final ModelInfo info; + + /** + * Creates a VeoModel for the specified model. + * + * @param modelName + * the Veo model name + * @param options + * the plugin options + */ + public VeoModel(String modelName, GoogleGenAIPluginOptions options) { + this.modelName = modelName; + this.options = options; + this.client = createClient(); + this.info = createModelInfo(); + logger.debug("Initialized Veo model: {}", modelName); + } + + private Client createClient() { + Client.Builder builder = Client.builder(); + + if (options.isVertexAI()) { + builder.vertexAI(true); + if (options.getProject() != null) { + builder.project(options.getProject()); + } + if (options.getLocation() != null) { + builder.location(options.getLocation()); + } + if (options.getApiKey() != null) { + builder.apiKey(options.getApiKey()); + } + } else { + builder.apiKey(options.getApiKey()); + } + + // Use a longer timeout for video generation operations (10 minutes) + // Video generation involves long-polling operations that can take several + // minutes + HttpOptions.Builder httpBuilder = HttpOptions.builder(); + if (options.getApiVersion() != null) { + httpBuilder.apiVersion(options.getApiVersion()); + } + if (options.getBaseUrl() != null) { + httpBuilder.baseUrl(options.getBaseUrl()); + } + // Set a 10-minute timeout for HTTP operations + httpBuilder.timeout(600000); + builder.httpOptions(httpBuilder.build()); + + return builder.build(); + } + + private ModelInfo createModelInfo() { + ModelInfo info = new ModelInfo(); + info.setLabel("Google Veo " + modelName); + + ModelInfo.ModelCapabilities caps = new ModelInfo.ModelCapabilities(); + caps.setMultiturn(false); + caps.setMedia(true); // Supports image input for image-to-video + caps.setTools(false); + caps.setSystemRole(false); + caps.setOutput(Set.of("media")); + info.setSupports(caps); + + return info; + } + + @Override + public String getName() { + return "googleai/" + modelName; + } + + @Override + public ModelInfo getInfo() { + return info; + } + + @Override + public boolean supportsStreaming() { + return false; // Video generation doesn't support streaming + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request) { + try { + return generateVideo(request); + } catch (Exception e) { + throw new GenkitException("Video generation failed: " + e.getMessage(), e); + } + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request, Consumer streamCallback) { + // Video generation doesn't support streaming + return run(context, request); + } + + private ModelResponse generateVideo(ModelRequest request) throws Exception { + String prompt = extractPrompt(request); + GenerateVideosConfig config = buildConfig(request); + GenerateVideosSource source = buildSource(request, prompt); + + logger.debug("Calling Veo model {} with prompt: {}", modelName, + prompt.substring(0, Math.min(100, prompt.length()))); + + // Start video generation operation + GenerateVideosOperation operation = client.models.generateVideos(modelName, source, config); + + // Poll for completion + Map customConfig = request.getConfig(); + long pollIntervalMs = DEFAULT_POLL_INTERVAL_MS; + long timeoutMs = DEFAULT_TIMEOUT_MS; + + if (customConfig != null) { + if (customConfig.containsKey("pollIntervalMs")) { + pollIntervalMs = ((Number) customConfig.get("pollIntervalMs")).longValue(); + } + if (customConfig.containsKey("timeoutMs")) { + timeoutMs = ((Number) customConfig.get("timeoutMs")).longValue(); + } + } + + GenerateVideosResponse response = pollForCompletion(operation, pollIntervalMs, timeoutMs); + + return parseResponse(response); + } + + private GenerateVideosResponse pollForCompletion(GenerateVideosOperation operation, long pollIntervalMs, + long timeoutMs) throws Exception { + long startTime = System.currentTimeMillis(); + String operationName = operation.name().orElse(""); + + while (true) { + // Check if done + if (operation.done().orElse(false)) { + logger.debug("Video generation completed for operation: {}", operationName); + if (operation.response().isPresent()) { + return operation.response().get(); + } + // Check for error + if (operation.error().isPresent()) { + Map error = operation.error().get(); + String errorMsg = error.containsKey("message") + ? String.valueOf(error.get("message")) + : "Unknown error"; + throw new GenkitException("Video generation failed: " + errorMsg); + } + throw new GenkitException("Video generation completed but no response"); + } + + // Check timeout + if (System.currentTimeMillis() - startTime > timeoutMs) { + throw new GenkitException("Video generation timed out after " + timeoutMs + "ms"); + } + + // Sleep and poll again + Thread.sleep(pollIntervalMs); + operation = client.operations.getVideosOperation(operation, null); + } + } + + private String extractPrompt(ModelRequest request) { + StringBuilder prompt = new StringBuilder(); + + if (request.getMessages() != null) { + for (Message message : request.getMessages()) { + if (message.getContent() != null) { + for (Part part : message.getContent()) { + if (part.getText() != null) { + if (prompt.length() > 0) { + prompt.append("\n"); + } + prompt.append(part.getText()); + } + } + } + } + } + + return prompt.toString(); + } + + private GenerateVideosSource buildSource(ModelRequest request, String prompt) { + GenerateVideosSource.Builder sourceBuilder = GenerateVideosSource.builder(); + sourceBuilder.prompt(prompt); + + // Look for image in the messages for image-to-video + if (request.getMessages() != null) { + for (Message message : request.getMessages()) { + if (message.getContent() != null) { + for (Part part : message.getContent()) { + if (part.getMedia() != null) { + Media media = part.getMedia(); + String contentType = media.getContentType(); + if (contentType != null && contentType.startsWith("image/")) { + com.google.genai.types.Image image = createImage(media); + if (image != null) { + sourceBuilder.image(image); + logger.debug("Added reference image for image-to-video generation"); + } + } + } + } + } + } + } + + return sourceBuilder.build(); + } + + private com.google.genai.types.Image createImage(Media media) { + com.google.genai.types.Image.Builder builder = com.google.genai.types.Image.builder(); + + String url = media.getUrl(); + if (url != null) { + if (url.startsWith("data:")) { + // Parse data URL + int commaIndex = url.indexOf(','); + if (commaIndex > 0) { + String base64Data = url.substring(commaIndex + 1); + byte[] imageBytes = Base64.getDecoder().decode(base64Data); + builder.imageBytes(imageBytes); + + // Extract mime type + String header = url.substring(5, commaIndex); + int semiIndex = header.indexOf(';'); + if (semiIndex > 0) { + builder.mimeType(header.substring(0, semiIndex)); + } + } + } else if (url.startsWith("gs://")) { + builder.gcsUri(url); + } + // Note: HTTP URLs not directly supported by Image.Builder + // Would need to download and use imageBytes instead + } + + return builder.build(); + } + + @SuppressWarnings("unchecked") + private GenerateVideosConfig buildConfig(ModelRequest request) { + GenerateVideosConfig.Builder configBuilder = GenerateVideosConfig.builder(); + + Map config = request.getConfig(); + if (config != null) { + // Number of videos + if (config.containsKey("numberOfVideos")) { + configBuilder.numberOfVideos(((Number) config.get("numberOfVideos")).intValue()); + } + + // Duration (5-8 seconds) + if (config.containsKey("durationSeconds")) { + configBuilder.durationSeconds(((Number) config.get("durationSeconds")).intValue()); + } + + // Aspect ratio + if (config.containsKey("aspectRatio")) { + configBuilder.aspectRatio((String) config.get("aspectRatio")); + } + + // Person generation + if (config.containsKey("personGeneration")) { + configBuilder.personGeneration((String) config.get("personGeneration")); + } + + // Negative prompt + if (config.containsKey("negativePrompt")) { + configBuilder.negativePrompt((String) config.get("negativePrompt")); + } + + // Enhance prompt + if (config.containsKey("enhancePrompt")) { + configBuilder.enhancePrompt((Boolean) config.get("enhancePrompt")); + } + + // Seed + if (config.containsKey("seed")) { + configBuilder.seed(((Number) config.get("seed")).intValue()); + } + + // Output GCS URI + if (config.containsKey("outputGcsUri")) { + configBuilder.outputGcsUri((String) config.get("outputGcsUri")); + } + + // Generate audio (veo-3.0+ only) + if (config.containsKey("generateAudio")) { + configBuilder.generateAudio((Boolean) config.get("generateAudio")); + } + } + + return configBuilder.build(); + } + + private ModelResponse parseResponse(GenerateVideosResponse response) { + ModelResponse modelResponse = new ModelResponse(); + List candidates = new ArrayList<>(); + Candidate candidate = new Candidate(); + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + // Log the raw response for debugging + logger.info("Video response: generatedVideos present={}", response.generatedVideos().isPresent()); + if (response.generatedVideos().isPresent()) { + logger.info("Number of generated videos: {}", response.generatedVideos().get().size()); + } + + // Extract generated videos + if (response.generatedVideos().isPresent()) { + for (GeneratedVideo generatedVideo : response.generatedVideos().get()) { + logger.info("GeneratedVideo: video present={}", generatedVideo.video().isPresent()); + if (generatedVideo.video().isPresent()) { + Video video = generatedVideo.video().get(); + logger.info("Video: uri={}, videoBytes present={}, mimeType={}", video.uri().orElse("none"), + video.videoBytes().isPresent(), video.mimeType().orElse("none")); + Part videoPart = createVideoPart(video); + if (videoPart != null) { + parts.add(videoPart); + } + } + } + } + + if (!parts.isEmpty()) { + logger.debug("Generated {} video(s)", parts.size()); + } else { + logger.warn("No videos generated in response"); + } + + message.setContent(parts); + candidate.setMessage(message); + candidate.setFinishReason(FinishReason.STOP); + candidate.setIndex(0); + candidates.add(candidate); + + modelResponse.setCandidates(candidates); + modelResponse.setFinishReason(FinishReason.STOP); + + return modelResponse; + } + + private Part createVideoPart(Video video) { + Part part = new Part(); + Media media = new Media(); + + // Check for video bytes first + if (video.videoBytes().isPresent()) { + byte[] videoBytes = video.videoBytes().get(); + String base64 = Base64.getEncoder().encodeToString(videoBytes); + String mimeType = video.mimeType().orElse("video/mp4"); + media.setContentType(mimeType); + media.setUrl("data:" + mimeType + ";base64," + base64); + logger.debug("Created video part with {} bytes", videoBytes.length); + } else if (video.uri().isPresent()) { + String uri = video.uri().get(); + + // If it's an HTTP(S) URL, download the video and convert to base64 + if (uri.startsWith("http://") || uri.startsWith("https://")) { + try { + byte[] videoBytes = downloadVideo(uri); + String base64 = Base64.getEncoder().encodeToString(videoBytes); + String mimeType = video.mimeType().orElse("video/mp4"); + media.setContentType(mimeType); + media.setUrl("data:" + mimeType + ";base64," + base64); + logger.info("Downloaded and encoded video from URL, {} bytes", videoBytes.length); + } catch (Exception e) { + logger.warn("Failed to download video from URL: {}, using URL directly", uri, e); + media.setUrl(uri); + media.setContentType(video.mimeType().orElse("video/mp4")); + } + } else { + // Use URI directly (e.g., gs:// URLs) + media.setUrl(uri); + media.setContentType(video.mimeType().orElse("video/mp4")); + logger.debug("Created video part with URI: {}", uri); + } + } else { + logger.warn("Video has neither bytes nor URI"); + return null; + } + + part.setMedia(media); + return part; + } + + private byte[] downloadVideo(String urlString) throws Exception { + // Append API key to URL for authentication + String authenticatedUrl = urlString; + if (options.getApiKey() != null && !options.isVertexAI()) { + String separator = urlString.contains("?") ? "&" : "?"; + authenticatedUrl = urlString + separator + "key=" + options.getApiKey(); + } + + URL url = new URL(authenticatedUrl); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setConnectTimeout(30000); + connection.setReadTimeout(300000); // 5 minutes for large videos + + try (InputStream in = connection.getInputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + return out.toByteArray(); + } finally { + connection.disconnect(); + } + } + + /** + * Checks if the given model name is a supported Veo model. + * + * @param modelName + * the model name to check + * @return true if the model is a Veo model + */ + public static boolean isVeoModel(String modelName) { + return SUPPORTED_VEO_MODELS.contains(modelName) || modelName.startsWith("veo-"); + } +} diff --git a/java/plugins/jetty/pom.xml b/java/plugins/jetty/pom.xml new file mode 100644 index 0000000000..f663223676 --- /dev/null +++ b/java/plugins/jetty/pom.xml @@ -0,0 +1,73 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + genkit-plugin-jetty + jar + Genkit Jetty Plugin + Jetty HTTP server plugin for Genkit + + + + com.google.genkit + genkit + ${project.version} + + + + + org.eclipse.jetty + jetty-server + + + org.eclipse.jetty.ee10 + jetty-ee10-servlet + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + diff --git a/java/plugins/jetty/src/main/java/com/google/genkit/plugins/jetty/JettyPlugin.java b/java/plugins/jetty/src/main/java/com/google/genkit/plugins/jetty/JettyPlugin.java new file mode 100644 index 0000000000..479fddf6a0 --- /dev/null +++ b/java/plugins/jetty/src/main/java/com/google/genkit/plugins/jetty/JettyPlugin.java @@ -0,0 +1,330 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.jetty; + +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.*; + +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.handler.ContextHandler; +import org.eclipse.jetty.server.handler.ContextHandlerCollection; +import org.eclipse.jetty.util.Callback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.core.*; + +/** + * JettyPlugin provides HTTP endpoints for Genkit flows. + * + *

+ * This plugin exposes registered flows as HTTP endpoints, making it easy to + * deploy Genkit applications as web services. + * + *

+ * Example usage: + * + *

{@code
+ * Genkit genkit = Genkit.builder().plugin(new JettyPlugin(JettyPluginOptions.builder().port(8080).build())).build();
+ * 
+ * // Define your flows...
+ * 
+ * // Start the server and block (keeps application running)
+ * genkit.start();
+ * }
+ */ +public class JettyPlugin implements ServerPlugin { + + private static final Logger logger = LoggerFactory.getLogger(JettyPlugin.class); + + private final JettyPluginOptions options; + private Server server; + private Registry registry; + private ObjectMapper objectMapper; + + /** + * Creates a JettyPlugin with default options. + */ + public JettyPlugin() { + this(JettyPluginOptions.builder().build()); + } + + /** + * Creates a JettyPlugin with the specified options. + * + * @param options + * the plugin options + */ + public JettyPlugin(JettyPluginOptions options) { + this.options = options; + this.objectMapper = new ObjectMapper(); + } + + /** + * Creates a JettyPlugin with the specified port. + * + * @param port + * the HTTP port + * @return a new JettyPlugin + */ + public static JettyPlugin create(int port) { + return new JettyPlugin(JettyPluginOptions.builder().port(port).build()); + } + + @Override + public String getName() { + return "jetty"; + } + + @Override + public List> init() { + // Jetty plugin doesn't provide actions itself + return Collections.emptyList(); + } + + @Override + public List> init(Registry registry) { + this.registry = registry; + return Collections.emptyList(); + } + + /** + * Starts the Jetty server and blocks until it is stopped. + * + *

+ * This is the recommended way to start the server in a main() method. Similar + * to Express's app.listen() in JavaScript, this method will keep your + * application running until the server is explicitly stopped. + * + *

+ * Example usage: + * + *

{@code
+   * JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build());
+   * 
+   * Genkit genkit = Genkit.builder().plugin(jetty).build();
+   * 
+   * // Define your flows...
+   * 
+   * // Start and block
+   * jetty.start();
+   * }
+ * + * @throws Exception + * if the server cannot be started or if interrupted while waiting + */ + @Override + public void start() throws Exception { + if (registry == null) { + throw new GenkitException( + "Registry not set. Make sure JettyPlugin is added to Genkit before calling start()."); + } + + startServer(); + server.join(); + } + + /** + * Starts the Jetty server without blocking. + * + * @throws Exception + * if the server cannot be started + */ + private void startServer() throws Exception { + if (server != null) { + return; + } + + if (registry == null) { + throw new GenkitException( + "Registry not set. Make sure JettyPlugin is added to Genkit before calling start()."); + } + + server = new Server(); + + ServerConnector connector = new ServerConnector(server); + connector.setPort(options.getPort()); + connector.setHost(options.getHost()); + server.addConnector(connector); + + // Create handler collection + ContextHandlerCollection handlers = new ContextHandlerCollection(); + + // Add flow endpoints + addFlowHandlers(handlers); + + // Add health endpoint + ContextHandler healthHandler = new ContextHandler("/health"); + healthHandler.setHandler(new HealthHandler()); + handlers.addHandler(healthHandler); + + server.setHandler(handlers); + server.start(); + + logger.info("Jetty server started on {}:{}", options.getHost(), options.getPort()); + } + + /** + * Stops the Jetty server. + * + * @throws Exception + * if the server cannot be stopped + */ + @Override + public void stop() throws Exception { + if (server != null) { + server.stop(); + server = null; + logger.info("Jetty server stopped"); + } + } + + /** + * Returns the port the server is listening on. + * + * @return the configured port + */ + @Override + public int getPort() { + return options.getPort(); + } + + /** + * Returns true if the server is currently running. + * + * @return true if the server is running, false otherwise + */ + @Override + public boolean isRunning() { + return server != null && server.isRunning(); + } + + /** + * Adds HTTP handlers for all registered flows. + */ + private void addFlowHandlers(ContextHandlerCollection handlers) { + List> flows = registry.listActions(ActionType.FLOW); + + for (Action action : flows) { + String path = options.getBasePath() + "/" + action.getName(); + + ContextHandler handler = new ContextHandler(path); + handler.setAllowNullPathInContext(true); + handler.setHandler(new FlowHandler(action)); + handlers.addHandler(handler); + + logger.info("Registered flow endpoint: {}", path); + } + } + + /** + * Handler for health check endpoint. + */ + private class HealthHandler extends Handler.Abstract { + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception { + response.setStatus(200); + response.getHeaders().put("Content-Type", "application/json"); + + String json = "{\"status\":\"ok\"}"; + response.write(true, ByteBuffer.wrap(json.getBytes(StandardCharsets.UTF_8)), callback); + + return true; + } + } + + /** + * Handler for flow endpoints. + */ + private class FlowHandler extends Handler.Abstract { + private final Action action; + + @SuppressWarnings("unchecked") + FlowHandler(Action action) { + this.action = (Action) action; + } + + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception { + try { + // Only accept POST requests + if (!"POST".equals(request.getMethod())) { + response.setStatus(405); + response.getHeaders().put("Content-Type", "application/json"); + String error = "{\"error\":\"Method not allowed\"}"; + response.write(true, ByteBuffer.wrap(error.getBytes(StandardCharsets.UTF_8)), callback); + return true; + } + + // Read request body + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + Request.asInputStream(request).transferTo(baos); + String body = baos.toString(StandardCharsets.UTF_8); + + // Parse input + Object input = null; + if (body != null && !body.isEmpty()) { + input = objectMapper.readValue(body, Object.class); + } + + // Run the action + ActionContext context = new ActionContext(registry); + Object result = action.run(context, input); + + // Send response + response.setStatus(200); + response.getHeaders().put("Content-Type", "application/json"); + + String json = objectMapper.writeValueAsString(result); + response.write(true, ByteBuffer.wrap(json.getBytes(StandardCharsets.UTF_8)), callback); + + return true; + } catch (Exception e) { + logger.error("Error handling flow request", e); + + response.setStatus(500); + response.getHeaders().put("Content-Type", "application/json"); + + // Format error with structured error status for proper UI display + // For HTTP 500, send error status directly (no wrapper) + // Format: {code, message, details: {stack}} + String errorMessage = e.getMessage() != null ? e.getMessage() : "Unknown error"; + java.io.StringWriter sw = new java.io.StringWriter(); + e.printStackTrace(new java.io.PrintWriter(sw)); + String stacktrace = sw.toString(); + + Map errorDetails = Map.of("stack", stacktrace); + Map errorStatus = Map.of("code", 2, // INTERNAL error code + "message", errorMessage, "details", errorDetails); + + String json = objectMapper.writeValueAsString(errorStatus); + response.write(true, ByteBuffer.wrap(json.getBytes(StandardCharsets.UTF_8)), callback); + + return true; + } + } + } +} diff --git a/java/plugins/jetty/src/main/java/com/google/genkit/plugins/jetty/JettyPluginOptions.java b/java/plugins/jetty/src/main/java/com/google/genkit/plugins/jetty/JettyPluginOptions.java new file mode 100644 index 0000000000..1ec399e3c0 --- /dev/null +++ b/java/plugins/jetty/src/main/java/com/google/genkit/plugins/jetty/JettyPluginOptions.java @@ -0,0 +1,99 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.jetty; + +/** + * Options for configuring the Jetty plugin. + */ +public class JettyPluginOptions { + + private final int port; + private final String host; + private final String basePath; + + private JettyPluginOptions(Builder builder) { + this.port = builder.port; + this.host = builder.host; + this.basePath = builder.basePath; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the HTTP port. + * + * @return the port + */ + public int getPort() { + return port; + } + + /** + * Gets the host to bind to. + * + * @return the host + */ + public String getHost() { + return host; + } + + /** + * Gets the base path for flow endpoints. + * + * @return the base path + */ + public String getBasePath() { + return basePath; + } + + /** + * Builder for JettyPluginOptions. + */ + public static class Builder { + private int port = 8080; + private String host = "0.0.0.0"; + private String basePath = "/api/flows"; + + public Builder port(int port) { + this.port = port; + return this; + } + + public Builder host(String host) { + this.host = host; + return this; + } + + public Builder basePath(String basePath) { + this.basePath = basePath; + return this; + } + + public JettyPluginOptions build() { + return new JettyPluginOptions(this); + } + } +} diff --git a/java/plugins/localvec/pom.xml b/java/plugins/localvec/pom.xml new file mode 100644 index 0000000000..0f4cafd950 --- /dev/null +++ b/java/plugins/localvec/pom.xml @@ -0,0 +1,64 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + genkit-plugin-localvec + jar + Genkit Local Vector Store Plugin + Local file-based vector store for development and testing + + + + com.google.genkit + genkit-core + ${project.version} + + + com.google.genkit + genkit-ai + ${project.version} + + + com.fasterxml.jackson.core + jackson-databind + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + diff --git a/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/DbValue.java b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/DbValue.java new file mode 100644 index 0000000000..335101582f --- /dev/null +++ b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/DbValue.java @@ -0,0 +1,95 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.localvec; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.genkit.ai.Document; + +/** + * Represents a stored document value with its embedding. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class DbValue { + + @JsonProperty("doc") + private Document doc; + + @JsonProperty("embedding") + private List embedding; + + /** + * Default constructor for Jackson. + */ + public DbValue() { + } + + /** + * Creates a new DbValue. + * + * @param doc + * the document + * @param embedding + * the embedding vector + */ + public DbValue(Document doc, List embedding) { + this.doc = doc; + this.embedding = embedding; + } + + /** + * Gets the document. + * + * @return the document + */ + public Document getDoc() { + return doc; + } + + /** + * Sets the document. + * + * @param doc + * the document + */ + public void setDoc(Document doc) { + this.doc = doc; + } + + /** + * Gets the embedding. + * + * @return the embedding vector + */ + public List getEmbedding() { + return embedding; + } + + /** + * Sets the embedding. + * + * @param embedding + * the embedding vector + */ + public void setEmbedding(List embedding) { + this.embedding = embedding; + } +} diff --git a/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecConfig.java b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecConfig.java new file mode 100644 index 0000000000..3dd67053a3 --- /dev/null +++ b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecConfig.java @@ -0,0 +1,225 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.localvec; + +import java.nio.file.Path; +import java.nio.file.Paths; + +import com.google.genkit.ai.Embedder; + +/** + * Configuration for a local vector store. + */ +public class LocalVecConfig { + + private final String indexName; + private Embedder embedder; + private final String embedderName; + private final Path directory; + private final Object embedderOptions; + + private LocalVecConfig(Builder builder) { + this.indexName = builder.indexName; + this.embedder = builder.embedder; + this.embedderName = builder.embedderName; + this.directory = builder.directory; + this.embedderOptions = builder.embedderOptions; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the index name. + * + * @return the index name + */ + public String getIndexName() { + return indexName; + } + + /** + * Gets the embedder. + * + * @return the embedder + */ + public Embedder getEmbedder() { + return embedder; + } + + /** + * Gets the embedder name for deferred resolution. + * + * @return the embedder name, or null if embedder was set directly + */ + public String getEmbedderName() { + return embedderName; + } + + /** + * Sets the embedder (used for deferred resolution). + * + * @param embedder + * the embedder + */ + void setEmbedder(Embedder embedder) { + this.embedder = embedder; + } + + /** + * Gets the directory where data is stored. + * + * @return the directory path + */ + public Path getDirectory() { + return directory; + } + + /** + * Gets the embedder options. + * + * @return the embedder options, or null + */ + public Object getEmbedderOptions() { + return embedderOptions; + } + + /** + * Gets the filename for this index. + * + * @return the filename + */ + public String getFilename() { + return "__db_" + indexName + ".json"; + } + + /** + * Gets the full path to the data file. + * + * @return the full file path + */ + public Path getFilePath() { + return directory.resolve(getFilename()); + } + + /** + * Builder for LocalVecConfig. + */ + public static class Builder { + private String indexName; + private Embedder embedder; + private String embedderName; + private Path directory = Paths.get(System.getProperty("java.io.tmpdir")); + private Object embedderOptions; + + /** + * Sets the index name. + * + * @param indexName + * the index name + * @return this builder + */ + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Sets the embedder. + * + * @param embedder + * the embedder + * @return this builder + */ + public Builder embedder(Embedder embedder) { + this.embedder = embedder; + return this; + } + + /** + * Sets the embedder by name for deferred resolution. The embedder will be + * resolved from the registry during plugin initialization. + * + * @param embedderName + * the embedder name (e.g., "openai/text-embedding-3-small") + * @return this builder + */ + public Builder embedderName(String embedderName) { + this.embedderName = embedderName; + return this; + } + + /** + * Sets the directory for storing data. + * + * @param directory + * the directory path + * @return this builder + */ + public Builder directory(Path directory) { + this.directory = directory; + return this; + } + + /** + * Sets the directory for storing data. + * + * @param directory + * the directory path as string + * @return this builder + */ + public Builder directory(String directory) { + this.directory = Paths.get(directory); + return this; + } + + /** + * Sets the embedder options. + * + * @param embedderOptions + * the embedder options + * @return this builder + */ + public Builder embedderOptions(Object embedderOptions) { + this.embedderOptions = embedderOptions; + return this; + } + + /** + * Builds the configuration. + * + * @return the configuration + */ + public LocalVecConfig build() { + if (indexName == null || indexName.isEmpty()) { + throw new IllegalStateException("Index name is required"); + } + if (embedder == null && embedderName == null) { + throw new IllegalStateException("Either embedder or embedderName is required"); + } + return new LocalVecConfig(this); + } + } +} diff --git a/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecDocStore.java b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecDocStore.java new file mode 100644 index 0000000000..8eb6380f1c --- /dev/null +++ b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecDocStore.java @@ -0,0 +1,324 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.localvec; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.ai.*; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** + * Local file-based document store implementation. + * + *

+ * Stores documents and their embeddings in a JSON file for simple similarity + * search. Uses cosine similarity for retrieval. + */ +public class LocalVecDocStore { + + private static final Logger logger = LoggerFactory.getLogger(LocalVecDocStore.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private final LocalVecConfig config; + private final Map data; + + /** + * Creates a new LocalVecDocStore. + * + * @param config + * the configuration + */ + public LocalVecDocStore(LocalVecConfig config) { + this.config = config; + this.data = new ConcurrentHashMap<>(); + loadFromFile(); + } + + /** + * Loads data from the file if it exists. + */ + private void loadFromFile() { + Path filePath = config.getFilePath(); + if (Files.exists(filePath)) { + try { + String content = Files.readString(filePath); + Map loaded = objectMapper.readValue(content, + new TypeReference>() { + }); + if (loaded != null) { + data.putAll(loaded); + logger.info("Loaded {} documents from {}", data.size(), filePath); + } + } catch (IOException e) { + logger.warn("Failed to load data from {}: {}", filePath, e.getMessage()); + } + } + } + + /** + * Saves data to the file. + */ + private synchronized void saveToFile() { + try { + Path directory = config.getDirectory(); + if (!Files.exists(directory)) { + Files.createDirectories(directory); + } + Path filePath = config.getFilePath(); + String json = objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(data); + Files.writeString(filePath, json); + logger.debug("Saved {} documents to {}", data.size(), filePath); + } catch (IOException e) { + throw new GenkitException("Failed to save data to file", e); + } + } + + /** + * Indexes documents with their embeddings. + * + * @param ctx + * the action context + * @param documents + * the documents to index + * @throws GenkitException + * if indexing fails + */ + public void index(ActionContext ctx, List documents) throws GenkitException { + if (documents == null || documents.isEmpty()) { + return; + } + + try { + // Get embeddings for all documents + EmbedRequest embedRequest = new EmbedRequest(documents); + EmbedResponse embedResponse = config.getEmbedder().run(ctx, embedRequest); + + List embeddings = embedResponse.getEmbeddings(); + if (embeddings.size() != documents.size()) { + throw new GenkitException( + "Embedding count mismatch: expected " + documents.size() + ", got " + embeddings.size()); + } + + // Store each document with its embedding + for (int i = 0; i < documents.size(); i++) { + Document doc = documents.get(i); + List embedding = floatArrayToList(embeddings.get(i).getValues()); + + String id = computeDocumentId(doc); + if (!data.containsKey(id)) { + data.put(id, new DbValue(doc, embedding)); + logger.debug("Indexed document: {}", id); + } else { + logger.debug("Skipping duplicate document: {}", id); + } + } + + saveToFile(); + logger.info("Indexed {} documents to {}", documents.size(), config.getIndexName()); + + } catch (Exception e) { + throw new GenkitException("Failed to index documents: " + e.getMessage(), e); + } + } + + /** + * Retrieves documents similar to the query. + * + * @param ctx + * the action context + * @param request + * the retriever request + * @return the retriever response with matched documents + * @throws GenkitException + * if retrieval fails + */ + public RetrieverResponse retrieve(ActionContext ctx, RetrieverRequest request) throws GenkitException { + try { + // Get query document + Document queryDoc = request.getQuery(); + if (queryDoc == null) { + throw new GenkitException("Query document is required"); + } + + // Get embedding for the query + EmbedRequest embedRequest = new EmbedRequest(List.of(queryDoc)); + EmbedResponse embedResponse = config.getEmbedder().run(ctx, embedRequest); + List queryEmbedding = floatArrayToList(embedResponse.getEmbeddings().get(0).getValues()); + + // Get k parameter from options + int k = 3; + if (request.getOptions() != null) { + Object optionsObj = request.getOptions(); + if (optionsObj instanceof Map) { + @SuppressWarnings("unchecked") + Map options = (Map) optionsObj; + if (options.containsKey("k")) { + k = ((Number) options.get("k")).intValue(); + } + } else if (optionsObj instanceof RetrieverOptions) { + k = ((RetrieverOptions) optionsObj).getK(); + } + } + + // Score all documents by similarity + List scoredDocs = new ArrayList<>(); + for (DbValue dbValue : data.values()) { + double score = cosineSimilarity(queryEmbedding, dbValue.getEmbedding()); + scoredDocs.add(new ScoredDocument(score, dbValue.getDoc())); + } + + // Sort by score descending + scoredDocs.sort((a, b) -> Double.compare(b.score, a.score)); + + // Return top k documents + List results = new ArrayList<>(); + for (int i = 0; i < Math.min(k, scoredDocs.size()); i++) { + results.add(scoredDocs.get(i).doc); + } + + logger.debug("Retrieved {} documents for query", results.size()); + return new RetrieverResponse(results); + + } catch (Exception e) { + throw new GenkitException("Failed to retrieve documents: " + e.getMessage(), e); + } + } + + /** + * Computes the MD5 hash of a document for deduplication. + */ + private String computeDocumentId(Document doc) { + try { + MessageDigest md = MessageDigest.getInstance("MD5"); + String content = objectMapper.writeValueAsString(doc); + byte[] digest = md.digest(content.getBytes()); + StringBuilder sb = new StringBuilder(); + for (byte b : digest) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } catch (NoSuchAlgorithmException | IOException e) { + throw new GenkitException("Failed to compute document ID", e); + } + } + + /** + * Converts a float array to a List of Float. + */ + private List floatArrayToList(float[] arr) { + List list = new ArrayList<>(arr.length); + for (float f : arr) { + list.add(f); + } + return list; + } + + /** + * Computes cosine similarity between two vectors. + */ + private double cosineSimilarity(List vec1, List vec2) { + if (vec1.size() != vec2.size()) { + throw new IllegalArgumentException("Vectors must have same length"); + } + + double dotProduct = 0.0; + double norm1 = 0.0; + double norm2 = 0.0; + + for (int i = 0; i < vec1.size(); i++) { + float v1 = vec1.get(i); + float v2 = vec2.get(i); + dotProduct += v1 * v2; + norm1 += v1 * v1; + norm2 += v2 * v2; + } + + double denominator = Math.sqrt(norm1) * Math.sqrt(norm2); + if (denominator == 0) { + return 0.0; + } + + return dotProduct / denominator; + } + + /** + * Creates a retriever action for this document store. + * + * @return the retriever + */ + public Retriever createRetriever() { + String name = LocalVecPlugin.PROVIDER + "/" + config.getIndexName(); + return Retriever.builder().name(name).handler((ctx, request) -> retrieve(ctx, request)).build(); + } + + /** + * Creates an indexer action for this document store. + * + * @return the indexer + */ + public Indexer createIndexer() { + String name = LocalVecPlugin.PROVIDER + "/" + config.getIndexName(); + return Indexer.builder().name(name).handler((ctx, request) -> { + index(ctx, request.getDocuments()); + return new IndexerResponse(); + }).build(); + } + + /** + * Gets the number of documents in the store. + * + * @return the document count + */ + public int size() { + return data.size(); + } + + /** + * Clears all documents from the store. + */ + public void clear() { + data.clear(); + saveToFile(); + } + + /** + * Internal class for scoring documents. + */ + private static class ScoredDocument { + final double score; + final Document doc; + + ScoredDocument(double score, Document doc) { + this.score = score; + this.doc = doc; + } + } +} diff --git a/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecPlugin.java b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecPlugin.java new file mode 100644 index 0000000000..c5dde65732 --- /dev/null +++ b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/LocalVecPlugin.java @@ -0,0 +1,194 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.localvec; + +import java.util.ArrayList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.ai.*; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.Plugin; +import com.google.genkit.core.Registry; + +/** + * Local file-based vector store plugin for development and testing. + * + *

+ * This plugin provides a simple file-based vector store implementation suitable + * for local development and testing. It stores document embeddings in JSON + * files and performs similarity search using cosine similarity. + * + *

+ * NOT INTENDED FOR PRODUCTION USE. + * + *

+ * Example usage with embedder name (recommended): + * + *

{@code
+ * Genkit genkit = Genkit.builder().plugin(OpenAIPlugin.create())
+ * 		.plugin(LocalVecPlugin.builder().addStore(
+ * 				LocalVecConfig.builder().indexName("my-docs").embedderName("openai/text-embedding-3-small").build())
+ * 				.build())
+ * 		.build();
+ * }
+ */ +public class LocalVecPlugin implements Plugin { + + private static final Logger logger = LoggerFactory.getLogger(LocalVecPlugin.class); + public static final String PROVIDER = "devLocalVectorStore"; + + private final List configurations; + + private LocalVecPlugin(List configurations) { + this.configurations = configurations; + } + + /** + * Creates a builder for LocalVecPlugin. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public String getName() { + return PROVIDER; + } + + @Override + public List> init() { + // For backward compatibility, but prefer init(Registry) + return initializeStores(null); + } + + @Override + public List> init(Registry registry) { + return initializeStores(registry); + } + + private List> initializeStores(Registry registry) { + List> actions = new ArrayList<>(); + + for (LocalVecConfig config : configurations) { + // Resolve embedder by name if needed + if (config.getEmbedder() == null && config.getEmbedderName() != null) { + if (registry == null) { + throw new IllegalStateException( + "Registry is required to resolve embedder by name: " + config.getEmbedderName() + + ". Use init(Registry) or provide an Embedder instance directly."); + } + String embedderKey = ActionType.EMBEDDER.keyFromName(config.getEmbedderName()); + Action embedderAction = registry.lookupAction(embedderKey); + if (embedderAction == null) { + throw new IllegalStateException("Embedder not found: " + config.getEmbedderName() + + ". Make sure the embedder plugin is registered before LocalVecPlugin."); + } + if (!(embedderAction instanceof Embedder)) { + throw new IllegalStateException("Action " + config.getEmbedderName() + " is not an Embedder"); + } + config.setEmbedder((Embedder) embedderAction); + logger.info("Resolved embedder: {} for index: {}", config.getEmbedderName(), config.getIndexName()); + } + + logger.info("Initializing local vector store: {}", config.getIndexName()); + + LocalVecDocStore docStore = new LocalVecDocStore(config); + + // Create and add retriever + Retriever retriever = docStore.createRetriever(); + actions.add(retriever); + + // Create and add indexer + Indexer indexer = docStore.createIndexer(); + actions.add(indexer); + + logger.info("Registered local vector store indexer and retriever: {}", config.getIndexName()); + } + + return actions; + } + + /** + * Builder for LocalVecPlugin. + */ + public static class Builder { + private final List configurations = new ArrayList<>(); + + /** + * Adds a vector store configuration. + * + * @param config + * the configuration + * @return this builder + */ + public Builder addStore(LocalVecConfig config) { + configurations.add(config); + return this; + } + + /** + * Convenience method to add a store with minimal configuration using an + * embedder instance. + * + * @param indexName + * the index name + * @param embedder + * the embedder to use + * @return this builder + */ + public Builder addStore(String indexName, Embedder embedder) { + configurations.add(LocalVecConfig.builder().indexName(indexName).embedder(embedder).build()); + return this; + } + + /** + * Convenience method to add a store with minimal configuration using an + * embedder name. The embedder will be resolved from the registry during plugin + * initialization. + * + * @param indexName + * the index name + * @param embedderName + * the embedder name (e.g., "openai/text-embedding-3-small") + * @return this builder + */ + public Builder addStore(String indexName, String embedderName) { + configurations.add(LocalVecConfig.builder().indexName(indexName).embedderName(embedderName).build()); + return this; + } + + /** + * Builds the plugin. + * + * @return the configured plugin + */ + public LocalVecPlugin build() { + if (configurations.isEmpty()) { + throw new IllegalStateException("At least one store configuration is required"); + } + return new LocalVecPlugin(configurations); + } + } +} diff --git a/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/RetrieverOptions.java b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/RetrieverOptions.java new file mode 100644 index 0000000000..3eb71deaf2 --- /dev/null +++ b/java/plugins/localvec/src/main/java/com/google/genkit/plugins/localvec/RetrieverOptions.java @@ -0,0 +1,65 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.localvec; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Options for the local vector store retriever. + */ +public class RetrieverOptions { + + @JsonProperty("k") + private int k = 3; + + /** + * Default constructor. + */ + public RetrieverOptions() { + } + + /** + * Creates options with specified k value. + * + * @param k + * the number of documents to retrieve + */ + public RetrieverOptions(int k) { + this.k = k; + } + + /** + * Gets the number of documents to retrieve. + * + * @return the k value + */ + public int getK() { + return k; + } + + /** + * Sets the number of documents to retrieve. + * + * @param k + * the k value + */ + public void setK(int k) { + this.k = k; + } +} diff --git a/java/plugins/mcp/README.md b/java/plugins/mcp/README.md new file mode 100644 index 0000000000..f91ec410c6 --- /dev/null +++ b/java/plugins/mcp/README.md @@ -0,0 +1,229 @@ +# Genkit MCP Plugin + +This plugin enables Genkit to integrate with [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) servers, allowing you to use MCP tools and resources in your Genkit applications. + +## Features + +- Connect to multiple MCP servers simultaneously +- Support for STDIO transport (local processes) and HTTP/SSE transport (remote servers) +- Automatic conversion of MCP tools to Genkit tools +- Access MCP resources programmatically +- Seamless integration with Genkit's AI model workflows + +## Installation + +Add the MCP plugin dependency to your `pom.xml`: + +```xml + + com.google.genkit + genkit-plugin-mcp + ${genkit.version} + +``` + +## Quick Start + +### Basic Usage with STDIO Server + +```java +import com.google.genkit.Genkit; +import com.google.genkit.plugins.mcp.MCPPlugin; +import com.google.genkit.plugins.mcp.MCPPluginOptions; +import com.google.genkit.plugins.mcp.MCPServerConfig; + +// Create MCP plugin with a filesystem server +MCPPlugin mcpPlugin = MCPPlugin.create(MCPPluginOptions.builder() + .name("my-mcp-host") + .addServer("filesystem", MCPServerConfig.stdio( + "npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp")) + .build()); + +// Create Genkit with the MCP plugin +Genkit genkit = Genkit.builder() + .plugin(mcpPlugin) + .build(); + +// MCP tools are now available as Genkit tools +List> mcpTools = mcpPlugin.getTools(); +``` + +### Using MCP Tools with AI Models + +```java +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.ModelResponse; + +// Use MCP tools in AI-powered flows +Flow assistantFlow = genkit.defineFlow( + "fileAssistant", String.class, String.class, + (ctx, userRequest) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder() + .model("openai/gpt-4o") + .system("You are a helpful file assistant.") + .prompt(userRequest) + .tools(mcpPlugin.getTools()) + .build()); + return response.getText(); + }); +``` + +### HTTP Server Connection + +```java +MCPPlugin mcpPlugin = MCPPlugin.create(MCPPluginOptions.builder() + .addServer("weather", MCPServerConfig.http("http://localhost:3001/mcp")) + .build()); +``` + +## Server Configuration + +### STDIO Transport + +Used for running local MCP servers as child processes: + +```java +MCPServerConfig config = MCPServerConfig.builder() + .command("npx") + .args("-y", "@modelcontextprotocol/server-filesystem", "/allowed/path") + .env("SOME_VAR", "value") // Optional environment variables + .build(); +``` + +### HTTP Transport + +Used for connecting to remote MCP servers: + +```java +MCPServerConfig config = MCPServerConfig.http("http://localhost:3001/mcp"); + +// Or with builder for more options +MCPServerConfig config = MCPServerConfig.builder() + .url("http://localhost:3001/mcp") + .transportType(MCPServerConfig.TransportType.HTTP) + .build(); +``` + +### Streamable HTTP Transport + +```java +MCPServerConfig config = MCPServerConfig.streamableHttp("http://localhost:3001/mcp"); +``` + +## Plugin Options + +```java +MCPPluginOptions options = MCPPluginOptions.builder() + .name("my-mcp-host") // Host name for identification + .addServer("server1", config1) // Add servers + .addServer("server2", config2) + .requestTimeout(Duration.ofSeconds(30)) // Request timeout + .rawToolResponses(false) // Process tool responses + .build(); +``` + +## Direct Tool Invocation + +Call MCP tools directly without going through AI: + +```java +// Call a specific tool +Object result = mcpPlugin.callTool("filesystem", "read_file", + Map.of("path", "/tmp/myfile.txt")); + +// Get tools from a specific server +List> filesystemTools = mcpPlugin.getTools("filesystem"); +``` + +## Resource Access + +Access MCP resources programmatically: + +```java +// List resources from a server +List resources = mcpPlugin.getResources("filesystem"); + +// Read a resource +MCPResourceContent content = mcpPlugin.readResource("filesystem", "file:///tmp/data.txt"); +String text = content.getText(); +``` + +## Popular MCP Servers + +Here are some commonly used MCP servers you can connect to: + +| Server | Package | Description | +|--------|---------|-------------| +| Filesystem | `@modelcontextprotocol/server-filesystem` | File operations (read, write, list) | +| Everything | `@modelcontextprotocol/server-everything` | Demo server with various tools | +| Git | `@modelcontextprotocol/server-git` | Git repository operations | +| GitHub | `@modelcontextprotocol/server-github` | GitHub API access | +| Postgres | `@modelcontextprotocol/server-postgres` | PostgreSQL database access | +| Slack | `@modelcontextprotocol/server-slack` | Slack messaging | +| Memory | `@modelcontextprotocol/server-memory` | Knowledge graph memory | + +## Example: Multi-Server Setup + +```java +MCPPluginOptions options = MCPPluginOptions.builder() + .name("multi-server-host") + .addServer("files", MCPServerConfig.stdio( + "npx", "-y", "@modelcontextprotocol/server-filesystem", "/data")) + .addServer("git", MCPServerConfig.builder() + .command("npx") + .args("-y", "@modelcontextprotocol/server-git") + .env("GIT_AUTHOR_NAME", "Genkit User") + .build()) + .addServer("github", MCPServerConfig.builder() + .command("npx") + .args("-y", "@modelcontextprotocol/server-github") + .env("GITHUB_TOKEN", System.getenv("GITHUB_TOKEN")) + .build()) + .build(); +``` + +## Cleanup + +The plugin manages connections automatically, but you can manually disconnect: + +```java +// Disconnect all servers +mcpPlugin.disconnect(); + +// Or add a shutdown hook +Runtime.getRuntime().addShutdownHook(new Thread(() -> { + mcpPlugin.disconnect(); +})); +``` + +## Error Handling + +```java +try { + Object result = mcpPlugin.callTool("filesystem", "read_file", + Map.of("path", "/nonexistent/file")); +} catch (GenkitException e) { + logger.error("MCP tool call failed: {}", e.getMessage()); +} +``` + +## Logging + +The plugin uses SLF4J for logging. Configure your logging framework to see MCP-related logs: + +```xml + + +``` + +## Requirements + +- Java 17 or later +- Node.js and npm (for running MCP server packages via npx) +- Network access for HTTP-based MCP servers + +## See Also + +- [MCP Protocol Documentation](https://modelcontextprotocol.io/) +- [MCP Java SDK](https://github.com/modelcontextprotocol/java-sdk) +- [Available MCP Servers](https://github.com/modelcontextprotocol/servers) diff --git a/java/plugins/mcp/pom.xml b/java/plugins/mcp/pom.xml new file mode 100644 index 0000000000..4f70565c61 --- /dev/null +++ b/java/plugins/mcp/pom.xml @@ -0,0 +1,80 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + genkit-plugin-mcp + jar + Genkit MCP Plugin + Model Context Protocol (MCP) plugin for Genkit - enables integration with MCP servers and tools + + + 0.17.0 + + + + + + com.google.genkit + genkit + ${project.version} + + + + + io.modelcontextprotocol.sdk + mcp + ${mcp.sdk.version} + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + org.mockito + mockito-core + test + + + diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPClient.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPClient.java new file mode 100644 index 0000000000..f97c205a4d --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPClient.java @@ -0,0 +1,427 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.ai.Tool; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.Registry; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ListResourcesResult; +import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; + +/** + * MCP Client that manages connections to MCP servers and provides access to + * their tools and resources. + * + *

+ * This client wraps the MCP Java SDK and converts MCP tools to Genkit tools, + * allowing them to be used seamlessly in Genkit applications. + * + *

+ * Example usage: + * + *

{@code
+ * MCPClient client = new MCPClient("filesystem",
+ * 		MCPServerConfig.stdio("npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp"),
+ * 		Duration.ofSeconds(30), false);
+ *
+ * client.connect();
+ * List> tools = client.getTools(registry);
+ * client.disconnect();
+ * }
+ */ +public class MCPClient { + + private static final Logger logger = LoggerFactory.getLogger(MCPClient.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private final String serverName; + private final MCPServerConfig config; + private final Duration requestTimeout; + private final boolean rawToolResponses; + + private McpSyncClient client; + private McpClientTransport transport; + private boolean connected = false; + + // Cache for tools and resources + private final Map> toolCache = new ConcurrentHashMap<>(); + + /** + * Creates a new MCP client. + * + * @param serverName + * the name to identify this server + * @param config + * the server configuration + * @param requestTimeout + * timeout for requests + * @param rawToolResponses + * whether to return raw MCP responses + */ + public MCPClient(String serverName, MCPServerConfig config, Duration requestTimeout, boolean rawToolResponses) { + this.serverName = serverName; + this.config = config; + this.requestTimeout = requestTimeout; + this.rawToolResponses = rawToolResponses; + } + + /** + * Connects to the MCP server. + * + * @throws GenkitException + * if connection fails + */ + public void connect() throws GenkitException { + if (connected) { + logger.debug("Already connected to MCP server: {}", serverName); + return; + } + + if (config.isDisabled()) { + logger.info("MCP server {} is disabled, skipping connection", serverName); + return; + } + + try { + logger.info("Connecting to MCP server: {}", serverName); + transport = createTransport(); + + client = McpClient.sync(transport).requestTimeout(requestTimeout) + .capabilities(ClientCapabilities.builder().roots(true).build()).build(); + + client.initialize(); + connected = true; + logger.info("Successfully connected to MCP server: {}", serverName); + } catch (Exception e) { + throw new GenkitException("Failed to connect to MCP server: " + serverName, e); + } + } + + /** + * Disconnects from the MCP server. + */ + public void disconnect() { + if (!connected || client == null) { + return; + } + + try { + logger.info("Disconnecting from MCP server: {}", serverName); + client.closeGracefully(); + connected = false; + toolCache.clear(); + logger.info("Disconnected from MCP server: {}", serverName); + } catch (Exception e) { + logger.warn("Error disconnecting from MCP server {}: {}", serverName, e.getMessage()); + } + } + + /** + * Gets tools from the MCP server as Genkit tools. + * + * @param registry + * the Genkit registry for tool registration + * @return list of Genkit tools + * @throws GenkitException + * if listing tools fails + */ + public List> getTools(Registry registry) throws GenkitException { + if (!connected) { + throw new GenkitException("Not connected to MCP server: " + serverName); + } + + List> tools = new ArrayList<>(); + + try { + ListToolsResult result = client.listTools(); + + for (McpSchema.Tool mcpTool : result.tools()) { + Tool tool = createGenkitTool(mcpTool, registry); + tools.add(tool); + toolCache.put(mcpTool.name(), tool); + } + + logger.info("Loaded {} tools from MCP server: {}", tools.size(), serverName); + } catch (Exception e) { + throw new GenkitException("Failed to list tools from MCP server: " + serverName, e); + } + + return tools; + } + + /** + * Gets resources from the MCP server. + * + * @return list of MCP resources + * @throws GenkitException + * if listing resources fails + */ + public List getResources() throws GenkitException { + if (!connected) { + throw new GenkitException("Not connected to MCP server: " + serverName); + } + + List resources = new ArrayList<>(); + + try { + ListResourcesResult result = client.listResources(); + + for (McpSchema.Resource mcpResource : result.resources()) { + MCPResource resource = new MCPResource(mcpResource.uri(), mcpResource.name(), + mcpResource.description() != null ? mcpResource.description() : "", mcpResource.mimeType()); + resources.add(resource); + } + + logger.info("Loaded {} resources from MCP server: {}", resources.size(), serverName); + } catch (Exception e) { + throw new GenkitException("Failed to list resources from MCP server: " + serverName, e); + } + + return resources; + } + + /** + * Reads a resource by URI. + * + * @param uri + * the resource URI + * @return the resource content + * @throws GenkitException + * if reading fails + */ + public MCPResourceContent readResource(String uri) throws GenkitException { + if (!connected) { + throw new GenkitException("Not connected to MCP server: " + serverName); + } + + try { + ReadResourceResult result = client.readResource(new McpSchema.ReadResourceRequest(uri)); + + List parts = new ArrayList<>(); + for (McpSchema.ResourceContents content : result.contents()) { + if (content instanceof McpSchema.TextResourceContents textContent) { + parts.add(new MCPResourceContent.ContentPart(textContent.text(), null, content.mimeType())); + } else if (content instanceof McpSchema.BlobResourceContents blobContent) { + parts.add(new MCPResourceContent.ContentPart(null, blobContent.blob(), content.mimeType())); + } + } + + return new MCPResourceContent(uri, parts); + } catch (Exception e) { + throw new GenkitException("Failed to read resource: " + uri, e); + } + } + + /** + * Calls an MCP tool directly. + * + * @param toolName + * the tool name + * @param arguments + * the tool arguments + * @return the tool result + * @throws GenkitException + * if the call fails + */ + @SuppressWarnings("unchecked") + public Object callTool(String toolName, Map arguments) throws GenkitException { + if (!connected) { + throw new GenkitException("Not connected to MCP server: " + serverName); + } + + try { + logger.debug("Calling MCP tool {}/{} with arguments: {}", serverName, toolName, arguments); + + CallToolResult result = client.callTool(new McpSchema.CallToolRequest(toolName, arguments)); + + if (rawToolResponses) { + return result; + } + + return processToolResult(result); + } catch (Exception e) { + throw new GenkitException("Failed to call MCP tool: " + toolName, e); + } + } + + /** + * Gets the server name. + * + * @return the server name + */ + public String getServerName() { + return serverName; + } + + /** + * Checks if connected to the MCP server. + * + * @return true if connected + */ + public boolean isConnected() { + return connected; + } + + // Private methods + + private McpClientTransport createTransport() { + McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + + switch (config.getTransportType()) { + case STDIO : + ServerParameters.Builder paramsBuilder = ServerParameters.builder(config.getCommand()); + if (!config.getArgs().isEmpty()) { + paramsBuilder.args(config.getArgs().toArray(new String[0])); + } + if (!config.getEnv().isEmpty()) { + paramsBuilder.env(config.getEnv()); + } + return new StdioClientTransport(paramsBuilder.build(), jsonMapper); + + case HTTP : + return HttpClientSseClientTransport.builder(config.getUrl()).build(); + + case STREAMABLE_HTTP : + return HttpClientStreamableHttpTransport.builder(config.getUrl()).build(); + + default : + throw new IllegalArgumentException("Unsupported transport type: " + config.getTransportType()); + } + } + + @SuppressWarnings("unchecked") + private Tool, Object> createGenkitTool(McpSchema.Tool mcpTool, Registry registry) { + String toolName = serverName + "/" + mcpTool.name(); + + // Convert MCP input schema to Map + Map inputSchema = convertJsonSchema(mcpTool.inputSchema()); + + Tool, Object> tool = Tool., Object>builder().name(toolName) + .description(mcpTool.description() != null ? mcpTool.description() : "") + .inputSchema(inputSchema != null ? inputSchema : new HashMap<>()) + .inputClass((Class>) (Class) Map.class).handler((ctx, input) -> { + return callTool(mcpTool.name(), input); + }).build(); + + // Register the tool + tool.register(registry); + + logger.debug("Created Genkit tool: {} from MCP server: {}", toolName, serverName); + return tool; + } + + @SuppressWarnings("unchecked") + private Map convertJsonSchema(Object schema) { + if (schema == null) { + return new HashMap<>(); + } + if (schema instanceof Map) { + return (Map) schema; + } + try { + JsonNode node = objectMapper.valueToTree(schema); + return objectMapper.convertValue(node, Map.class); + } catch (Exception e) { + logger.warn("Failed to convert schema: {}", e.getMessage()); + return new HashMap<>(); + } + } + + private Object processToolResult(CallToolResult result) { + if (result.isError() != null && result.isError()) { + StringBuilder errorText = new StringBuilder(); + for (McpSchema.Content content : result.content()) { + if (content instanceof McpSchema.TextContent textContent) { + errorText.append(textContent.text()); + } + } + return Map.of("error", errorText.toString()); + } + + // Check if all content is text + boolean allText = result.content().stream().allMatch(c -> c instanceof McpSchema.TextContent); + + if (allText) { + StringBuilder text = new StringBuilder(); + for (McpSchema.Content content : result.content()) { + if (content instanceof McpSchema.TextContent textContent) { + text.append(textContent.text()); + } + } + String textResult = text.toString(); + + // Try to parse as JSON + if (textResult.trim().startsWith("{") || textResult.trim().startsWith("[")) { + try { + return objectMapper.readValue(textResult, Object.class); + } catch (Exception e) { + // Return as plain text + } + } + return textResult; + } + + // Return first content item or the whole result + if (result.content().size() == 1) { + McpSchema.Content content = result.content().get(0); + if (content instanceof McpSchema.TextContent textContent) { + return textContent.text(); + } else if (content instanceof McpSchema.ImageContent imageContent) { + return Map.of("type", "image", "data", imageContent.data(), "mimeType", imageContent.mimeType()); + } + } + + // Return raw content list + List> contentList = new ArrayList<>(); + for (McpSchema.Content content : result.content()) { + if (content instanceof McpSchema.TextContent textContent) { + contentList.add(Map.of("type", "text", "text", textContent.text())); + } else if (content instanceof McpSchema.ImageContent imageContent) { + contentList + .add(Map.of("type", "image", "data", imageContent.data(), "mimeType", imageContent.mimeType())); + } + } + return contentList; + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPPlugin.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPPlugin.java new file mode 100644 index 0000000000..777b747302 --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPPlugin.java @@ -0,0 +1,330 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.ai.Tool; +import com.google.genkit.core.Action; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.Plugin; +import com.google.genkit.core.Registry; + +/** + * MCP Plugin for Genkit. + * + *

+ * This plugin enables Genkit to connect to MCP (Model Context Protocol) servers + * and use their tools. MCP provides a standardized way for AI applications to + * interact with external tools and data sources. + * + *

+ * Features: + *

    + *
  • Connect to multiple MCP servers (STDIO or HTTP transports)
  • + *
  • Automatic conversion of MCP tools to Genkit tools
  • + *
  • Access MCP resources programmatically
  • + *
  • Support for tool caching and lazy loading
  • + *
+ * + *

+ * Example usage: + * + *

{@code
+ * // Create the MCP plugin with server configurations
+ * MCPPlugin mcpPlugin = MCPPlugin.create(MCPPluginOptions.builder().name("my-mcp-host")
+ * 		.addServer("filesystem",
+ * 				MCPServerConfig.stdio("npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp"))
+ * 		.addServer("weather", MCPServerConfig.http("http://localhost:3001/mcp")).build());
+ *
+ * // Create Genkit with the MCP plugin
+ * Genkit genkit = Genkit.builder().plugin(mcpPlugin).build();
+ *
+ * // Use MCP tools in flows
+ * Flow myFlow = genkit.defineFlow("myFlow", String.class, String.class, (ctx, input) -> {
+ * 	// MCP tools are available as: "serverName/toolName"
+ * 	// e.g., "filesystem/readFile", "weather/getForecast"
+ * 	ModelResponse response = genkit.generate(
+ * 			GenerateOptions.builder().model("openai/gpt-4o").prompt(input).tools(mcpPlugin.getTools()).build());
+ * 	return response.getText();
+ * });
+ * }
+ * + * @see MCPPluginOptions + * @see MCPServerConfig + * @see MCPClient + */ +public class MCPPlugin implements Plugin { + + private static final Logger logger = LoggerFactory.getLogger(MCPPlugin.class); + + private final MCPPluginOptions options; + private final Map clients = new ConcurrentHashMap<>(); + private final List> allTools = new ArrayList<>(); + private Registry registry; + private boolean initialized = false; + + /** + * Creates a new MCP plugin with the given options. + * + * @param options + * the plugin options + */ + public MCPPlugin(MCPPluginOptions options) { + this.options = options; + } + + /** + * Creates an MCP plugin with the given options. + * + * @param options + * the plugin options + * @return a new MCPPlugin + */ + public static MCPPlugin create(MCPPluginOptions options) { + return new MCPPlugin(options); + } + + /** + * Creates an MCP plugin with a single STDIO server. + * + * @param serverName + * the name for the server + * @param command + * the command to execute + * @param args + * the command arguments + * @return a new MCPPlugin + */ + public static MCPPlugin stdio(String serverName, String command, String... args) { + return create(MCPPluginOptions.builder().addServer(serverName, MCPServerConfig.stdio(command, args)).build()); + } + + /** + * Creates an MCP plugin with a single HTTP server. + * + * @param serverName + * the name for the server + * @param url + * the server URL + * @return a new MCPPlugin + */ + public static MCPPlugin http(String serverName, String url) { + return create(MCPPluginOptions.builder().addServer(serverName, MCPServerConfig.http(url)).build()); + } + + @Override + public String getName() { + return "mcp"; + } + + @Override + public List> init() { + // This method doesn't have access to the registry, so we return empty + // The actual initialization happens in init(Registry) + return new ArrayList<>(); + } + + @Override + public List> init(Registry registry) { + this.registry = registry; + List> actions = new ArrayList<>(); + + logger.info("Initializing MCP plugin: {}", options.getName()); + + // Connect to all configured servers + for (Map.Entry entry : options.getServers().entrySet()) { + String serverName = entry.getKey(); + MCPServerConfig config = entry.getValue(); + + if (config.isDisabled()) { + logger.info("MCP server {} is disabled, skipping", serverName); + continue; + } + + try { + MCPClient client = new MCPClient(serverName, config, options.getRequestTimeout(), + options.isRawToolResponses()); + + client.connect(); + clients.put(serverName, client); + + // Load tools from this server + List> tools = client.getTools(registry); + allTools.addAll(tools); + actions.addAll(tools); + + logger.info("Connected to MCP server {} with {} tools", serverName, tools.size()); + } catch (Exception e) { + logger.error("Failed to connect to MCP server {}: {}", serverName, e.getMessage()); + // Continue with other servers + } + } + + initialized = true; + logger.info("MCP plugin initialized with {} servers and {} total tools", clients.size(), allTools.size()); + + return actions; + } + + /** + * Gets all tools from all connected MCP servers. + * + * @return list of tools + */ + public List> getTools() { + return new ArrayList<>(allTools); + } + + /** + * Gets tools from a specific MCP server. + * + * @param serverName + * the server name + * @return list of tools from that server + * @throws GenkitException + * if the server is not found or not connected + */ + public List> getTools(String serverName) throws GenkitException { + MCPClient client = clients.get(serverName); + if (client == null) { + throw new GenkitException("MCP server not found: " + serverName); + } + return client.getTools(registry); + } + + /** + * Gets resources from a specific MCP server. + * + * @param serverName + * the server name + * @return list of resources + * @throws GenkitException + * if the server is not found or not connected + */ + public List getResources(String serverName) throws GenkitException { + MCPClient client = clients.get(serverName); + if (client == null) { + throw new GenkitException("MCP server not found: " + serverName); + } + return client.getResources(); + } + + /** + * Reads a resource from an MCP server. + * + * @param serverName + * the server name + * @param uri + * the resource URI + * @return the resource content + * @throws GenkitException + * if reading fails + */ + public MCPResourceContent readResource(String serverName, String uri) throws GenkitException { + MCPClient client = clients.get(serverName); + if (client == null) { + throw new GenkitException("MCP server not found: " + serverName); + } + return client.readResource(uri); + } + + /** + * Calls an MCP tool directly. + * + * @param serverName + * the server name + * @param toolName + * the tool name (without server prefix) + * @param arguments + * the tool arguments + * @return the tool result + * @throws GenkitException + * if the call fails + */ + public Object callTool(String serverName, String toolName, Map arguments) throws GenkitException { + MCPClient client = clients.get(serverName); + if (client == null) { + throw new GenkitException("MCP server not found: " + serverName); + } + return client.callTool(toolName, arguments); + } + + /** + * Gets the client for a specific server. + * + * @param serverName + * the server name + * @return the client, or null if not found + */ + public MCPClient getClient(String serverName) { + return clients.get(serverName); + } + + /** + * Gets all connected clients. + * + * @return map of server name to client + */ + public Map getClients() { + return new HashMap<>(clients); + } + + /** + * Disconnects all MCP clients. + */ + public void disconnect() { + logger.info("Disconnecting all MCP clients"); + for (MCPClient client : clients.values()) { + try { + client.disconnect(); + } catch (Exception e) { + logger.warn("Error disconnecting MCP client {}: {}", client.getServerName(), e.getMessage()); + } + } + clients.clear(); + allTools.clear(); + initialized = false; + } + + /** + * Checks if the plugin is initialized. + * + * @return true if initialized + */ + public boolean isInitialized() { + return initialized; + } + + /** + * Gets the plugin options. + * + * @return the options + */ + public MCPPluginOptions getOptions() { + return options; + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPPluginOptions.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPPluginOptions.java new file mode 100644 index 0000000000..0c3b42087c --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPPluginOptions.java @@ -0,0 +1,192 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Configuration options for the MCP plugin. + * + *

+ * This class allows configuration of MCP server connections, including: + *

    + *
  • Multiple MCP servers with different transports (STDIO, HTTP)
  • + *
  • Connection timeouts and retry settings
  • + *
  • Raw tool response handling
  • + *
+ * + *

+ * Example usage: + * + *

{@code
+ * MCPPluginOptions options = MCPPluginOptions.builder().name("my-mcp-host")
+ * 		.addServer("filesystem",
+ * 				MCPServerConfig.stdio("npx", "-y", "@modelcontextprotocol/server-filesystem", "/tmp"))
+ * 		.addServer("weather", MCPServerConfig.http("http://localhost:3001/mcp"))
+ * 		.requestTimeout(Duration.ofSeconds(30)).build();
+ * }
+ */ +public class MCPPluginOptions { + + private final String name; + private final Map servers; + private final Duration requestTimeout; + private final boolean rawToolResponses; + + private MCPPluginOptions(Builder builder) { + this.name = builder.name; + this.servers = Collections.unmodifiableMap(new HashMap<>(builder.servers)); + this.requestTimeout = builder.requestTimeout; + this.rawToolResponses = builder.rawToolResponses; + } + + /** + * Creates a new builder for MCPPluginOptions. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the name of the MCP host. + * + * @return the host name + */ + public String getName() { + return name; + } + + /** + * Gets the configured MCP servers. + * + * @return map of server name to configuration + */ + public Map getServers() { + return servers; + } + + /** + * Gets the request timeout. + * + * @return the request timeout + */ + public Duration getRequestTimeout() { + return requestTimeout; + } + + /** + * Whether to return raw MCP tool responses. + * + * @return true if raw responses should be returned + */ + public boolean isRawToolResponses() { + return rawToolResponses; + } + + /** + * Builder for MCPPluginOptions. + */ + public static class Builder { + + private String name = "genkit-mcp"; + private Map servers = new HashMap<>(); + private Duration requestTimeout = Duration.ofSeconds(30); + private boolean rawToolResponses = false; + + /** + * Sets the name of the MCP host. + * + * @param name + * the host name + * @return this builder + */ + public Builder name(String name) { + this.name = name; + return this; + } + + /** + * Adds an MCP server configuration. + * + * @param serverName + * the name to identify this server + * @param config + * the server configuration + * @return this builder + */ + public Builder addServer(String serverName, MCPServerConfig config) { + this.servers.put(serverName, config); + return this; + } + + /** + * Sets all MCP server configurations at once. + * + * @param servers + * map of server name to configuration + * @return this builder + */ + public Builder servers(Map servers) { + this.servers = new HashMap<>(servers); + return this; + } + + /** + * Sets the request timeout for MCP operations. + * + * @param timeout + * the timeout duration + * @return this builder + */ + public Builder requestTimeout(Duration timeout) { + this.requestTimeout = timeout; + return this; + } + + /** + * Sets whether to return raw MCP tool responses. + * + *

+ * When true, tool responses are returned in their raw MCP format. When false + * (default), responses are processed for better Genkit compatibility. + * + * @param rawToolResponses + * true to return raw responses + * @return this builder + */ + public Builder rawToolResponses(boolean rawToolResponses) { + this.rawToolResponses = rawToolResponses; + return this; + } + + /** + * Builds the MCPPluginOptions. + * + * @return the built options + */ + public MCPPluginOptions build() { + return new MCPPluginOptions(this); + } + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPResource.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPResource.java new file mode 100644 index 0000000000..7a32dcd102 --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPResource.java @@ -0,0 +1,95 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +/** + * Represents an MCP resource. + * + *

+ * Resources in MCP are data sources that can be read by clients. They have a + * URI, name, optional description, and MIME type. + */ +public class MCPResource { + + private final String uri; + private final String name; + private final String description; + private final String mimeType; + + /** + * Creates a new MCP resource. + * + * @param uri + * the resource URI + * @param name + * the resource name + * @param description + * the resource description + * @param mimeType + * the MIME type + */ + public MCPResource(String uri, String name, String description, String mimeType) { + this.uri = uri; + this.name = name; + this.description = description; + this.mimeType = mimeType; + } + + /** + * Gets the resource URI. + * + * @return the URI + */ + public String getUri() { + return uri; + } + + /** + * Gets the resource name. + * + * @return the name + */ + public String getName() { + return name; + } + + /** + * Gets the resource description. + * + * @return the description + */ + public String getDescription() { + return description; + } + + /** + * Gets the MIME type. + * + * @return the MIME type + */ + public String getMimeType() { + return mimeType; + } + + @Override + public String toString() { + return "MCPResource{" + "uri='" + uri + '\'' + ", name='" + name + '\'' + ", description='" + description + '\'' + + ", mimeType='" + mimeType + '\'' + '}'; + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPResourceContent.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPResourceContent.java new file mode 100644 index 0000000000..a7eb9eba50 --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPResourceContent.java @@ -0,0 +1,154 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +import java.util.List; + +/** + * Represents the content of an MCP resource. + * + *

+ * Resource content can contain multiple parts, each of which can be text or + * binary data. + */ +public class MCPResourceContent { + + private final String uri; + private final List parts; + + /** + * Creates a new resource content. + * + * @param uri + * the resource URI + * @param parts + * the content parts + */ + public MCPResourceContent(String uri, List parts) { + this.uri = uri; + this.parts = parts; + } + + /** + * Gets the resource URI. + * + * @return the URI + */ + public String getUri() { + return uri; + } + + /** + * Gets the content parts. + * + * @return the parts + */ + public List getParts() { + return parts; + } + + /** + * Gets the text content if this resource has a single text part. + * + * @return the text content, or null if not available + */ + public String getText() { + if (parts.isEmpty()) { + return null; + } + StringBuilder text = new StringBuilder(); + for (ContentPart part : parts) { + if (part.getText() != null) { + text.append(part.getText()); + } + } + return text.length() > 0 ? text.toString() : null; + } + + /** + * Represents a single part of resource content. + */ + public static class ContentPart { + + private final String text; + private final String blob; + private final String mimeType; + + /** + * Creates a new content part. + * + * @param text + * the text content (or null for binary) + * @param blob + * the base64-encoded binary content (or null for text) + * @param mimeType + * the MIME type + */ + public ContentPart(String text, String blob, String mimeType) { + this.text = text; + this.blob = blob; + this.mimeType = mimeType; + } + + /** + * Gets the text content. + * + * @return the text, or null if binary + */ + public String getText() { + return text; + } + + /** + * Gets the base64-encoded binary content. + * + * @return the blob, or null if text + */ + public String getBlob() { + return blob; + } + + /** + * Gets the MIME type. + * + * @return the MIME type + */ + public String getMimeType() { + return mimeType; + } + + /** + * Checks if this part is text. + * + * @return true if text + */ + public boolean isText() { + return text != null; + } + + /** + * Checks if this part is binary. + * + * @return true if binary + */ + public boolean isBinary() { + return blob != null; + } + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServer.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServer.java new file mode 100644 index 0000000000..a93a03c79e --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServer.java @@ -0,0 +1,287 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.ai.Tool; +import com.google.genkit.core.Action; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.ActionType; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.Registry; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * MCP Server that exposes Genkit tools, prompts, and flows as MCP endpoints. + * + *

+ * This server allows external MCP clients (like Claude Desktop, or other AI + * agents) to discover and invoke Genkit tools. + * + *

+ * Example usage: + * + *

{@code
+ * // Create Genkit and define some tools
+ * Genkit genkit = Genkit.builder().build();
+ *
+ * genkit.defineTool("calculator", "Performs basic math",
+ *     Map.of("type", "object", ...),
+ *     Map.class,
+ *     (ctx, input) -> {
+ *         // Tool implementation
+ *         return result;
+ *     });
+ *
+ * // Create and start MCP server
+ * MCPServer mcpServer = new MCPServer(genkit.getRegistry(),
+ *     MCPServerOptions.builder()
+ *         .name("my-tools-server")
+ *         .version("1.0.0")
+ *         .build());
+ *
+ * // Start with STDIO transport (for use with Claude Desktop, etc.)
+ * mcpServer.start();
+ * }
+ * + * @see MCPServerOptions + */ +public class MCPServer { + + private static final Logger logger = LoggerFactory.getLogger(MCPServer.class); + private static final McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + + private final Registry registry; + private final MCPServerOptions options; + private McpSyncServer server; + private boolean running = false; + + /** + * Creates a new MCP server. + * + * @param registry + * the Genkit registry containing tools to expose + * @param options + * the server options + */ + public MCPServer(Registry registry, MCPServerOptions options) { + this.registry = registry; + this.options = options; + } + + /** + * Creates a new MCP server with default options. + * + * @param registry + * the Genkit registry containing tools to expose + */ + public MCPServer(Registry registry) { + this(registry, MCPServerOptions.builder().build()); + } + + /** + * Starts the MCP server with STDIO transport. + * + *

+ * This is the standard transport for use with Claude Desktop and other MCP + * clients that launch the server as a subprocess. + * + * @throws GenkitException + * if the server fails to start + */ + public void start() throws GenkitException { + start(new StdioServerTransportProvider(jsonMapper)); + } + + /** + * Starts the MCP server with a custom transport provider. + * + * @param transportProvider + * the transport provider to use + * @throws GenkitException + * if the server fails to start + */ + public void start(McpServerTransportProvider transportProvider) throws GenkitException { + if (running) { + logger.warn("MCP server is already running"); + return; + } + + try { + logger.info("Starting MCP server: {} v{}", options.getName(), options.getVersion()); + + // Build the server with capabilities + server = McpServer.sync(transportProvider).serverInfo(options.getName(), options.getVersion()) + .capabilities(ServerCapabilities.builder().tools(true) // Enable tool support + .resources(false, false) // Resources not yet supported + .prompts(false) // Prompts not yet fully supported + .logging() // Enable logging + .build()) + .build(); + + // Register all tools from the registry + registerTools(); + + running = true; + logger.info("MCP server started successfully with {} tools", getToolCount()); + + } catch (Exception e) { + throw new GenkitException("Failed to start MCP server", e); + } + } + + /** + * Stops the MCP server. + */ + public void stop() { + if (!running || server == null) { + return; + } + + try { + logger.info("Stopping MCP server: {}", options.getName()); + server.close(); + running = false; + logger.info("MCP server stopped"); + } catch (Exception e) { + logger.error("Error stopping MCP server: {}", e.getMessage()); + } + } + + /** + * Checks if the server is running. + * + * @return true if running + */ + public boolean isRunning() { + return running; + } + + /** + * Gets the server options. + * + * @return the options + */ + public MCPServerOptions getOptions() { + return options; + } + + // Private methods + + @SuppressWarnings("unchecked") + private void registerTools() { + // Get all tool actions from the registry + List> actions = registry.listActions(); + + for (Action action : actions) { + // Only register tool actions + if (action.getType() == ActionType.TOOL) { + try { + registerTool((Tool) action); + } catch (Exception e) { + logger.error("Failed to register tool {}: {}", action.getName(), e.getMessage()); + } + } + } + } + + private void registerTool(Tool tool) { + String name = tool.getName(); + String description = tool.getDescription() != null ? tool.getDescription() : ""; + + // Convert input schema to JsonSchema + McpSchema.JsonSchema inputSchema; + try { + Map schema = tool.getInputSchema(); + if (schema == null || schema.isEmpty()) { + inputSchema = new McpSchema.JsonSchema("object", Collections.emptyMap(), null, null, null, null); + } else { + String schemaJson = jsonMapper.writeValueAsString(schema); + inputSchema = jsonMapper.readValue(schemaJson, McpSchema.JsonSchema.class); + } + } catch (Exception e) { + logger.warn("Failed to serialize input schema for tool {}, using empty schema", name); + inputSchema = new McpSchema.JsonSchema("object", Collections.emptyMap(), null, null, null, null); + } + + // Create MCP tool using the builder + McpSchema.Tool mcpTool = McpSchema.Tool.builder().name(name).description(description).inputSchema(inputSchema) + .build(); + + // Create MCP tool specification + McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(mcpTool, + (exchange, arguments) -> { + try { + logger.debug("Executing tool: {} with arguments: {}", name, arguments); + + // Create action context with registry + ActionContext ctx = new ActionContext(registry); + + // Execute the tool + Object result = tool.run(ctx, arguments); + + // Convert result to text content + String resultText; + if (result instanceof String) { + resultText = (String) result; + } else { + resultText = jsonMapper.writeValueAsString(result); + } + + logger.debug("Tool {} completed successfully", name); + + return CallToolResult.builder().addTextContent(resultText).isError(false).build(); + + } catch (Exception e) { + logger.error("Tool {} failed: {}", name, e.getMessage()); + return CallToolResult.builder().addTextContent("Error: " + e.getMessage()).isError(true) + .build(); + } + }); + + server.addTool(toolSpec); + logger.debug("Registered MCP tool: {}", name); + } + + private int getToolCount() { + int count = 0; + List> actions = registry.listActions(); + for (Action action : actions) { + if (action.getType() == ActionType.TOOL) { + count++; + } + } + return count; + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServerConfig.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServerConfig.java new file mode 100644 index 0000000000..5afb9d1f52 --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServerConfig.java @@ -0,0 +1,319 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Configuration for connecting to an MCP server. + * + *

+ * Supports two transport types: + *

    + *
  • STDIO: Launches a local process and communicates via standard I/O
  • + *
  • HTTP: Connects to a remote MCP server via HTTP/SSE
  • + *
+ * + *

+ * Example usage: + * + *

{@code
+ * // STDIO transport - launch a local MCP server
+ * MCPServerConfig filesystemServer = MCPServerConfig.stdio("npx", "-y", "@modelcontextprotocol/server-filesystem",
+ * 		"/tmp");
+ *
+ * // HTTP transport - connect to remote server
+ * MCPServerConfig remoteServer = MCPServerConfig.http("http://localhost:3001/mcp");
+ *
+ * // With environment variables
+ * MCPServerConfig serverWithEnv = MCPServerConfig.builder().command("npx")
+ * 		.args("-y", "@modelcontextprotocol/server-github").env("GITHUB_TOKEN", System.getenv("GITHUB_TOKEN"))
+ * 		.build();
+ * }
+ */ +public class MCPServerConfig { + + /** + * Transport type for MCP communication. + */ + public enum TransportType { + /** + * Standard I/O transport - launches a local process. + */ + STDIO, + /** + * HTTP transport with SSE for streaming. + */ + HTTP, + /** + * Streamable HTTP transport. + */ + STREAMABLE_HTTP + } + + private final TransportType transportType; + private final String command; + private final List args; + private final Map env; + private final String url; + private final boolean disabled; + + private MCPServerConfig(Builder builder) { + this.transportType = builder.transportType; + this.command = builder.command; + this.args = Collections.unmodifiableList(new ArrayList<>(builder.args)); + this.env = Collections.unmodifiableMap(new HashMap<>(builder.env)); + this.url = builder.url; + this.disabled = builder.disabled; + } + + /** + * Creates a STDIO server configuration. + * + * @param command + * the command to execute + * @param args + * arguments for the command + * @return the server configuration + */ + public static MCPServerConfig stdio(String command, String... args) { + return builder().command(command).args(args).build(); + } + + /** + * Creates an HTTP server configuration. + * + * @param url + * the server URL + * @return the server configuration + */ + public static MCPServerConfig http(String url) { + return builder().url(url).transportType(TransportType.HTTP).build(); + } + + /** + * Creates a Streamable HTTP server configuration. + * + * @param url + * the server URL + * @return the server configuration + */ + public static MCPServerConfig streamableHttp(String url) { + return builder().url(url).transportType(TransportType.STREAMABLE_HTTP).build(); + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the transport type. + * + * @return the transport type + */ + public TransportType getTransportType() { + return transportType; + } + + /** + * Gets the command for STDIO transport. + * + * @return the command, or null for HTTP transport + */ + public String getCommand() { + return command; + } + + /** + * Gets the command arguments. + * + * @return the arguments + */ + public List getArgs() { + return args; + } + + /** + * Gets the environment variables. + * + * @return the environment variables + */ + public Map getEnv() { + return env; + } + + /** + * Gets the URL for HTTP transport. + * + * @return the URL, or null for STDIO transport + */ + public String getUrl() { + return url; + } + + /** + * Whether this server is disabled. + * + * @return true if disabled + */ + public boolean isDisabled() { + return disabled; + } + + /** + * Builder for MCPServerConfig. + */ + public static class Builder { + + private TransportType transportType = TransportType.STDIO; + private String command; + private List args = new ArrayList<>(); + private Map env = new HashMap<>(); + private String url; + private boolean disabled = false; + + /** + * Sets the transport type. + * + * @param transportType + * the transport type + * @return this builder + */ + public Builder transportType(TransportType transportType) { + this.transportType = transportType; + return this; + } + + /** + * Sets the command for STDIO transport. + * + * @param command + * the command to execute + * @return this builder + */ + public Builder command(String command) { + this.command = command; + this.transportType = TransportType.STDIO; + return this; + } + + /** + * Sets the command arguments. + * + * @param args + * the arguments + * @return this builder + */ + public Builder args(String... args) { + this.args = new ArrayList<>(Arrays.asList(args)); + return this; + } + + /** + * Sets the command arguments. + * + * @param args + * the arguments + * @return this builder + */ + public Builder args(List args) { + this.args = new ArrayList<>(args); + return this; + } + + /** + * Adds an environment variable. + * + * @param key + * the variable name + * @param value + * the variable value + * @return this builder + */ + public Builder env(String key, String value) { + this.env.put(key, value); + return this; + } + + /** + * Sets all environment variables. + * + * @param env + * the environment variables + * @return this builder + */ + public Builder env(Map env) { + this.env = new HashMap<>(env); + return this; + } + + /** + * Sets the URL for HTTP transport. + * + * @param url + * the server URL + * @return this builder + */ + public Builder url(String url) { + this.url = url; + if (this.transportType == TransportType.STDIO) { + this.transportType = TransportType.HTTP; + } + return this; + } + + /** + * Sets whether this server is disabled. + * + * @param disabled + * true to disable + * @return this builder + */ + public Builder disabled(boolean disabled) { + this.disabled = disabled; + return this; + } + + /** + * Builds the MCPServerConfig. + * + * @return the built configuration + */ + public MCPServerConfig build() { + if (transportType == TransportType.STDIO && command == null) { + throw new IllegalStateException("Command is required for STDIO transport"); + } + if ((transportType == TransportType.HTTP || transportType == TransportType.STREAMABLE_HTTP) + && url == null) { + throw new IllegalStateException("URL is required for HTTP transport"); + } + return new MCPServerConfig(this); + } + } +} diff --git a/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServerOptions.java b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServerOptions.java new file mode 100644 index 0000000000..cd34501e52 --- /dev/null +++ b/java/plugins/mcp/src/main/java/com/google/genkit/plugins/mcp/MCPServerOptions.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.mcp; + +/** + * Configuration options for an MCP server. + * + *

+ * Example usage: + * + *

{@code
+ * MCPServerOptions options = MCPServerOptions.builder().name("my-genkit-server").version("1.0.0").build();
+ * }
+ */ +public class MCPServerOptions { + + private final String name; + private final String version; + + private MCPServerOptions(Builder builder) { + this.name = builder.name; + this.version = builder.version; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the server name. + * + * @return the server name + */ + public String getName() { + return name; + } + + /** + * Gets the server version. + * + * @return the server version + */ + public String getVersion() { + return version; + } + + /** + * Builder for MCPServerOptions. + */ + public static class Builder { + + private String name = "genkit-mcp-server"; + private String version = "1.0.0"; + + /** + * Sets the server name. + * + * @param name + * the server name + * @return this builder + */ + public Builder name(String name) { + this.name = name; + return this; + } + + /** + * Sets the server version. + * + * @param version + * the server version + * @return this builder + */ + public Builder version(String version) { + this.version = version; + return this; + } + + /** + * Builds the MCPServerOptions. + * + * @return the built options + */ + public MCPServerOptions build() { + return new MCPServerOptions(this); + } + } +} diff --git a/java/plugins/openai/pom.xml b/java/plugins/openai/pom.xml new file mode 100644 index 0000000000..cfcec92fdd --- /dev/null +++ b/java/plugins/openai/pom.xml @@ -0,0 +1,80 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + genkit-plugin-openai + jar + Genkit OpenAI Plugin + OpenAI model plugin for Genkit + + + + com.google.genkit + genkit + ${project.version} + + + + + com.squareup.okhttp3 + okhttp + + + + + com.squareup.okhttp3 + okhttp-sse + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + com.squareup.okhttp3 + mockwebserver + test + + + diff --git a/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/ImageGenerationConfig.java b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/ImageGenerationConfig.java new file mode 100644 index 0000000000..ede939c048 --- /dev/null +++ b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/ImageGenerationConfig.java @@ -0,0 +1,250 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.openai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Configuration options for OpenAI image generation models (DALL-E, gpt-image). + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ImageGenerationConfig { + + /** + * Size of the generated image. For DALL-E 3: "1024x1024", "1792x1024", + * "1024x1792" For DALL-E 2: "256x256", "512x512", "1024x1024" For gpt-image-1: + * "1024x1024", "1536x1024", "1024x1536", "auto" + */ + @JsonProperty("size") + private String size; + + /** + * Image quality: "standard" or "hd" (DALL-E 3 only). For gpt-image-1: "low", + * "medium", "high" + */ + @JsonProperty("quality") + private String quality; + + /** + * Image style: "vivid" or "natural" (DALL-E 3 only). + */ + @JsonProperty("style") + private String style; + + /** + * Number of images to generate (1-10). Default is 1. + */ + @JsonProperty("n") + private Integer n; + + /** + * Response format: "url" or "b64_json". Default is "b64_json". + */ + @JsonProperty("responseFormat") + private String responseFormat; + + /** + * User identifier for abuse monitoring. + */ + @JsonProperty("user") + private String user; + + /** + * Background setting for gpt-image-1: "transparent", "opaque", "auto". + */ + @JsonProperty("background") + private String background; + + /** + * Output format for gpt-image-1: "png", "jpeg", "webp". + */ + @JsonProperty("outputFormat") + private String outputFormat; + + /** + * Output compression for gpt-image-1 (1-100). + */ + @JsonProperty("outputCompression") + private Integer outputCompression; + + /** + * Moderation level for gpt-image-1: "low", "auto". + */ + @JsonProperty("moderation") + private String moderation; + + /** + * Default constructor. + */ + public ImageGenerationConfig() { + } + + /** + * Creates a builder for ImageGenerationConfig. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + // Getters and setters + + public String getSize() { + return size; + } + + public void setSize(String size) { + this.size = size; + } + + public String getQuality() { + return quality; + } + + public void setQuality(String quality) { + this.quality = quality; + } + + public String getStyle() { + return style; + } + + public void setStyle(String style) { + this.style = style; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + public String getBackground() { + return background; + } + + public void setBackground(String background) { + this.background = background; + } + + public String getOutputFormat() { + return outputFormat; + } + + public void setOutputFormat(String outputFormat) { + this.outputFormat = outputFormat; + } + + public Integer getOutputCompression() { + return outputCompression; + } + + public void setOutputCompression(Integer outputCompression) { + this.outputCompression = outputCompression; + } + + public String getModeration() { + return moderation; + } + + public void setModeration(String moderation) { + this.moderation = moderation; + } + + /** + * Builder for ImageGenerationConfig. + */ + public static class Builder { + private final ImageGenerationConfig config = new ImageGenerationConfig(); + + public Builder size(String size) { + config.size = size; + return this; + } + + public Builder quality(String quality) { + config.quality = quality; + return this; + } + + public Builder style(String style) { + config.style = style; + return this; + } + + public Builder n(Integer n) { + config.n = n; + return this; + } + + public Builder responseFormat(String responseFormat) { + config.responseFormat = responseFormat; + return this; + } + + public Builder user(String user) { + config.user = user; + return this; + } + + public Builder background(String background) { + config.background = background; + return this; + } + + public Builder outputFormat(String outputFormat) { + config.outputFormat = outputFormat; + return this; + } + + public Builder outputCompression(Integer outputCompression) { + config.outputCompression = outputCompression; + return this; + } + + public Builder moderation(String moderation) { + config.moderation = moderation; + return this; + } + + public ImageGenerationConfig build() { + return config; + } + } +} diff --git a/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIEmbedder.java b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIEmbedder.java new file mode 100644 index 0000000000..916dc6dfca --- /dev/null +++ b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIEmbedder.java @@ -0,0 +1,175 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.openai; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.genkit.ai.*; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +import okhttp3.*; + +/** + * OpenAI embedder implementation for Genkit. + */ +public class OpenAIEmbedder extends Embedder { + + private static final Logger logger = LoggerFactory.getLogger(OpenAIEmbedder.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json"); + + private final String modelName; + private final OpenAIPluginOptions options; + private final OkHttpClient client; + private final ObjectMapper objectMapper; + + /** + * Creates a new OpenAIEmbedder. + * + * @param modelName + * the model name + * @param options + * the plugin options + */ + public OpenAIEmbedder(String modelName, OpenAIPluginOptions options) { + super("openai/" + modelName, createEmbedderInfo(modelName), (ctx, req) -> { + // This will be overridden by the actual implementation + throw new GenkitException("Handler not initialized"); + }); + this.modelName = modelName; + this.options = options; + this.objectMapper = new ObjectMapper(); + this.client = new OkHttpClient.Builder().connectTimeout(options.getTimeout(), TimeUnit.SECONDS) + .readTimeout(options.getTimeout(), TimeUnit.SECONDS) + .writeTimeout(options.getTimeout(), TimeUnit.SECONDS).build(); + } + + private static EmbedderInfo createEmbedderInfo(String modelName) { + EmbedderInfo info = new EmbedderInfo(); + info.setLabel("OpenAI " + modelName); + + // Set dimensions based on model + switch (modelName) { + case "text-embedding-3-small" : + info.setDimensions(1536); + break; + case "text-embedding-3-large" : + info.setDimensions(3072); + break; + case "text-embedding-ada-002" : + info.setDimensions(1536); + break; + } + + return info; + } + + @Override + public EmbedResponse run(ActionContext context, EmbedRequest request) { + if (request == null) { + throw new GenkitException("Embed request is required. Please provide an input with documents to embed."); + } + if (request.getDocuments() == null || request.getDocuments().isEmpty()) { + throw new GenkitException("Embed request must contain at least one document to embed."); + } + try { + return callOpenAI(request); + } catch (IOException e) { + throw new GenkitException("OpenAI Embedding API call failed", e); + } + } + + private EmbedResponse callOpenAI(EmbedRequest request) throws IOException { + ObjectNode requestBody = objectMapper.createObjectNode(); + requestBody.put("model", modelName); + + // Convert documents to text array + ArrayNode input = requestBody.putArray("input"); + for (Document doc : request.getDocuments()) { + String text = doc.text(); + logger.debug("Document text: '{}' (length: {})", + text != null ? text.substring(0, Math.min(50, text.length())) : "null", + text != null ? text.length() : 0); + if (text == null || text.isEmpty()) { + logger.warn("Document has empty text, skipping"); + continue; + } + input.add(text); + } + + // Validate that we have at least one input + if (input.isEmpty()) { + throw new GenkitException("No valid documents to embed - all documents had empty text"); + } + + // Log the request for debugging + String requestJson = requestBody.toString(); + logger.info("OpenAI Embedding request body: {}", requestJson); + + Request httpRequest = new Request.Builder().url(options.getBaseUrl() + "/embeddings") + .header("Authorization", "Bearer " + options.getApiKey()).header("Content-Type", "application/json") + .post(RequestBody.create(requestBody.toString(), JSON_MEDIA_TYPE)).build(); + + if (options.getOrganization() != null) { + httpRequest = httpRequest.newBuilder().header("OpenAI-Organization", options.getOrganization()).build(); + } + + try (Response response = client.newCall(httpRequest).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : "No error body"; + throw new GenkitException("OpenAI Embedding API error: " + response.code() + " - " + errorBody); + } + + String responseBody = response.body().string(); + return parseResponse(responseBody); + } + } + + private EmbedResponse parseResponse(String responseBody) throws IOException { + JsonNode root = objectMapper.readTree(responseBody); + + List embeddings = new ArrayList<>(); + + JsonNode dataNode = root.get("data"); + if (dataNode != null && dataNode.isArray()) { + for (JsonNode item : dataNode) { + JsonNode embeddingNode = item.get("embedding"); + if (embeddingNode != null && embeddingNode.isArray()) { + float[] values = new float[embeddingNode.size()]; + for (int i = 0; i < embeddingNode.size(); i++) { + values[i] = (float) embeddingNode.get(i).asDouble(); + } + embeddings.add(new EmbedResponse.Embedding(values)); + } + } + } + + return new EmbedResponse(embeddings); + } +} diff --git a/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIImageModel.java b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIImageModel.java new file mode 100644 index 0000000000..60e4fc193b --- /dev/null +++ b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIImageModel.java @@ -0,0 +1,295 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.openai; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.genkit.ai.*; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; + +/** + * OpenAI image generation model implementation for Genkit. Supports DALL-E 2, + * DALL-E 3, and gpt-image-1 models. + */ +public class OpenAIImageModel implements Model { + + private static final Logger logger = LoggerFactory.getLogger(OpenAIImageModel.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json"); + + private final String modelName; + private final OpenAIPluginOptions options; + private final OkHttpClient client; + private final ObjectMapper objectMapper; + private final ModelInfo info; + + /** + * Creates a new OpenAIImageModel. + * + * @param modelName + * the model name (dall-e-2, dall-e-3, or gpt-image-1) + * @param options + * the plugin options + */ + public OpenAIImageModel(String modelName, OpenAIPluginOptions options) { + this.modelName = modelName; + this.options = options; + this.objectMapper = new ObjectMapper(); + this.client = new OkHttpClient.Builder().connectTimeout(options.getTimeout() * 2, TimeUnit.SECONDS) + .readTimeout(options.getTimeout() * 2, TimeUnit.SECONDS) + .writeTimeout(options.getTimeout() * 2, TimeUnit.SECONDS).build(); + this.info = createModelInfo(); + } + + private ModelInfo createModelInfo() { + ModelInfo info = new ModelInfo(); + info.setLabel("OpenAI " + modelName); + + ModelInfo.ModelCapabilities caps = new ModelInfo.ModelCapabilities(); + caps.setMultiturn(false); + caps.setMedia(false); // Image generation models don't accept image inputs + caps.setTools(false); + caps.setSystemRole(false); + caps.setOutput(Set.of("media")); // Outputs media (images) + info.setSupports(caps); + + return info; + } + + @Override + public String getName() { + return "openai/" + modelName; + } + + @Override + public ModelInfo getInfo() { + return info; + } + + @Override + public boolean supportsStreaming() { + return false; // Image generation doesn't support streaming + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request) { + try { + return callOpenAIImages(request); + } catch (IOException e) { + throw new GenkitException("OpenAI Images API call failed", e); + } + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request, Consumer streamCallback) { + // Image generation doesn't support streaming, just call the non-streaming + // version + return run(context, request); + } + + private ModelResponse callOpenAIImages(ModelRequest request) throws IOException { + ObjectNode requestBody = buildRequestBody(request); + + Request httpRequest = new Request.Builder().url(options.getBaseUrl() + "/images/generations") + .header("Authorization", "Bearer " + options.getApiKey()).header("Content-Type", "application/json") + .post(RequestBody.create(requestBody.toString(), JSON_MEDIA_TYPE)).build(); + + if (options.getOrganization() != null) { + httpRequest = httpRequest.newBuilder().header("OpenAI-Organization", options.getOrganization()).build(); + } + + logger.debug("Calling OpenAI Images API with model: {}", modelName); + + try (Response response = client.newCall(httpRequest).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : "No error body"; + throw new GenkitException("OpenAI Images API error: " + response.code() + " - " + errorBody); + } + + String responseBody = response.body().string(); + return parseResponse(responseBody); + } + } + + private ObjectNode buildRequestBody(ModelRequest request) { + ObjectNode body = objectMapper.createObjectNode(); + body.put("model", modelName); + + // Extract prompt from messages + String prompt = extractPrompt(request); + body.put("prompt", prompt); + + // Get config - check both Map config and custom config + Map config = request.getConfig(); + + // Default response format to b64_json for data URI support + String responseFormat = "b64_json"; + + if (config != null) { + // Size + if (config.containsKey("size")) { + body.put("size", (String) config.get("size")); + } + + // Quality (DALL-E 3 or gpt-image-1) + if (config.containsKey("quality")) { + body.put("quality", (String) config.get("quality")); + } + + // Style (DALL-E 3) + if (config.containsKey("style")) { + body.put("style", (String) config.get("style")); + } + + // Number of images + if (config.containsKey("n")) { + body.put("n", ((Number) config.get("n")).intValue()); + } + + // Response format + if (config.containsKey("responseFormat")) { + responseFormat = (String) config.get("responseFormat"); + } + + // User + if (config.containsKey("user")) { + body.put("user", (String) config.get("user")); + } + + // gpt-image-1 specific options + if (modelName.contains("gpt-image")) { + if (config.containsKey("background")) { + body.put("background", (String) config.get("background")); + } + if (config.containsKey("outputFormat")) { + body.put("output_format", (String) config.get("outputFormat")); + } + if (config.containsKey("outputCompression")) { + body.put("output_compression", ((Number) config.get("outputCompression")).intValue()); + } + if (config.containsKey("moderation")) { + body.put("moderation", (String) config.get("moderation")); + } + } + } + + body.put("response_format", responseFormat); + + return body; + } + + private String extractPrompt(ModelRequest request) { + // Get the prompt from the messages + List messages = request.getMessages(); + if (messages == null || messages.isEmpty()) { + throw new GenkitException("No messages provided for image generation"); + } + + // Find user message with text content + for (Message message : messages) { + if (message.getRole() == Role.USER || message.getRole() == null) { + List content = message.getContent(); + if (content != null) { + for (Part part : content) { + if (part.getText() != null) { + return part.getText(); + } + } + } + } + } + + throw new GenkitException("No text prompt found in messages for image generation"); + } + + private ModelResponse parseResponse(String responseBody) throws IOException { + JsonNode root = objectMapper.readTree(responseBody); + + ModelResponse response = new ModelResponse(); + List candidates = new ArrayList<>(); + Candidate candidate = new Candidate(); + + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + // Parse image data + JsonNode dataNode = root.get("data"); + if (dataNode != null && dataNode.isArray()) { + for (JsonNode imageNode : dataNode) { + Part part = new Part(); + + // Determine content type based on model + String contentType = "image/png"; + if (modelName.contains("gpt-image")) { + // gpt-image-1 might return different formats + contentType = "image/png"; // Default, could be detected from config + } + + // Get URL or base64 data + String url = null; + if (imageNode.has("url") && !imageNode.get("url").isNull()) { + url = imageNode.get("url").asText(); + } else if (imageNode.has("b64_json") && !imageNode.get("b64_json").isNull()) { + String b64Data = imageNode.get("b64_json").asText(); + url = "data:" + contentType + ";base64," + b64Data; + } + + if (url != null) { + part.setMedia(new Media(contentType, url)); + parts.add(part); + } + } + } + + // Track image count in usage + Usage usage = new Usage(); + usage.setOutputImages(parts.size()); + response.setUsage(usage); + + message.setContent(parts); + candidate.setMessage(message); + candidate.setFinishReason(FinishReason.STOP); + + candidates.add(candidate); + response.setCandidates(candidates); + + logger.debug("Generated {} images", parts.size()); + + return response; + } +} diff --git a/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIModel.java b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIModel.java new file mode 100644 index 0000000000..ec1530258a --- /dev/null +++ b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIModel.java @@ -0,0 +1,624 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.openai; + +import java.io.IOException; +import java.util.*; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.genkit.ai.*; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +import okhttp3.*; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; + +/** + * OpenAI model implementation for Genkit. + */ +public class OpenAIModel implements Model { + + private static final Logger logger = LoggerFactory.getLogger(OpenAIModel.class); + private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json"); + + private final String modelName; + private final OpenAIPluginOptions options; + private final OkHttpClient client; + private final ObjectMapper objectMapper; + private final ModelInfo info; + + /** + * Creates a new OpenAIModel. + * + * @param modelName + * the model name + * @param options + * the plugin options + */ + public OpenAIModel(String modelName, OpenAIPluginOptions options) { + this.modelName = modelName; + this.options = options; + this.objectMapper = new ObjectMapper(); + this.client = new OkHttpClient.Builder().connectTimeout(options.getTimeout(), TimeUnit.SECONDS) + .readTimeout(options.getTimeout(), TimeUnit.SECONDS) + .writeTimeout(options.getTimeout(), TimeUnit.SECONDS).build(); + this.info = createModelInfo(); + } + + private ModelInfo createModelInfo() { + ModelInfo info = new ModelInfo(); + info.setLabel("OpenAI " + modelName); + + ModelInfo.ModelCapabilities caps = new ModelInfo.ModelCapabilities(); + caps.setMultiturn(true); + caps.setMedia(modelName.contains("gpt-4o") || modelName.contains("gpt-4-vision")); + caps.setTools(!modelName.startsWith("o1")); + caps.setSystemRole(!modelName.startsWith("o1")); + caps.setOutput(Set.of("text", "json")); + info.setSupports(caps); + + return info; + } + + @Override + public String getName() { + return "openai/" + modelName; + } + + @Override + public ModelInfo getInfo() { + return info; + } + + @Override + public boolean supportsStreaming() { + return true; + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request) { + try { + return callOpenAI(request); + } catch (IOException e) { + throw new GenkitException("OpenAI API call failed", e); + } + } + + @Override + public ModelResponse run(ActionContext context, ModelRequest request, Consumer streamCallback) { + if (streamCallback == null) { + return run(context, request); + } + try { + return callOpenAIStreaming(request, streamCallback); + } catch (Exception e) { + throw new GenkitException("OpenAI streaming API call failed", e); + } + } + + private ModelResponse callOpenAI(ModelRequest request) throws IOException { + ObjectNode requestBody = buildRequestBody(request); + + Request httpRequest = new Request.Builder().url(options.getBaseUrl() + "/chat/completions") + .header("Authorization", "Bearer " + options.getApiKey()).header("Content-Type", "application/json") + .post(RequestBody.create(requestBody.toString(), JSON_MEDIA_TYPE)).build(); + + if (options.getOrganization() != null) { + httpRequest = httpRequest.newBuilder().header("OpenAI-Organization", options.getOrganization()).build(); + } + + try (Response response = client.newCall(httpRequest).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : "No error body"; + throw new GenkitException("OpenAI API error: " + response.code() + " - " + errorBody); + } + + String responseBody = response.body().string(); + return parseResponse(responseBody); + } + } + + private ModelResponse callOpenAIStreaming(ModelRequest request, Consumer streamCallback) + throws IOException, InterruptedException { + ObjectNode requestBody = buildRequestBody(request); + requestBody.put("stream", true); + + Request httpRequest = new Request.Builder().url(options.getBaseUrl() + "/chat/completions") + .header("Authorization", "Bearer " + options.getApiKey()).header("Content-Type", "application/json") + .header("Accept", "text/event-stream").post(RequestBody.create(requestBody.toString(), JSON_MEDIA_TYPE)) + .build(); + + if (options.getOrganization() != null) { + httpRequest = httpRequest.newBuilder().header("OpenAI-Organization", options.getOrganization()).build(); + } + + StringBuilder fullContent = new StringBuilder(); + AtomicReference finishReason = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + // Track tool calls being streamed (tool calls come in chunks) + List> toolCallsInProgress = new ArrayList<>(); + + EventSourceListener listener = new EventSourceListener() { + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + if ("[DONE]".equals(data)) { + latch.countDown(); + return; + } + + try { + JsonNode chunk = objectMapper.readTree(data); + JsonNode choices = chunk.get("choices"); + if (choices != null && choices.isArray() && choices.size() > 0) { + JsonNode choice = choices.get(0); + JsonNode delta = choice.get("delta"); + + if (delta != null) { + // Handle text content + JsonNode contentNode = delta.get("content"); + if (contentNode != null && !contentNode.isNull()) { + String text = contentNode.asText(); + fullContent.append(text); + + // Create and send chunk + ModelResponseChunk responseChunk = ModelResponseChunk.text(text); + responseChunk.setIndex(choice.has("index") ? choice.get("index").asInt() : 0); + streamCallback.accept(responseChunk); + } + + // Handle tool calls (streamed incrementally) + JsonNode toolCallsNode = delta.get("tool_calls"); + if (toolCallsNode != null && toolCallsNode.isArray()) { + for (JsonNode toolCallDelta : toolCallsNode) { + int index = toolCallDelta.has("index") ? toolCallDelta.get("index").asInt() : 0; + + // Expand list if needed + while (toolCallsInProgress.size() <= index) { + Map newToolCall = new java.util.HashMap<>(); + newToolCall.put("arguments", new StringBuilder()); + toolCallsInProgress.add(newToolCall); + } + + Map toolCall = toolCallsInProgress.get(index); + + // Capture id if present + if (toolCallDelta.has("id")) { + toolCall.put("id", toolCallDelta.get("id").asText()); + } + + // Capture function name and arguments + JsonNode functionNode = toolCallDelta.get("function"); + if (functionNode != null) { + if (functionNode.has("name")) { + toolCall.put("name", functionNode.get("name").asText()); + } + if (functionNode.has("arguments")) { + StringBuilder args = (StringBuilder) toolCall.get("arguments"); + args.append(functionNode.get("arguments").asText()); + } + } + } + } + } + + JsonNode finishReasonNode = choice.get("finish_reason"); + if (finishReasonNode != null && !finishReasonNode.isNull()) { + finishReason.set(finishReasonNode.asText()); + } + } + } catch (Exception e) { + logger.error("Error parsing streaming chunk", e); + } + } + + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + String errorMsg = "Streaming failed"; + if (response != null) { + try { + errorMsg = "Streaming failed: " + response.code(); + if (response.body() != null) { + errorMsg += " - " + response.body().string(); + } + } catch (IOException e) { + // Ignore + } + } + error.set(new GenkitException(errorMsg, t)); + latch.countDown(); + } + + @Override + public void onClosed(EventSource eventSource) { + latch.countDown(); + } + }; + + EventSource.Factory factory = EventSources.createFactory(client); + EventSource eventSource = factory.newEventSource(httpRequest, listener); + + // Wait for streaming to complete + boolean completed = latch.await(options.getTimeout(), TimeUnit.SECONDS); + if (!completed) { + eventSource.cancel(); + throw new GenkitException("Streaming request timed out"); + } + + if (error.get() != null) { + throw error.get(); + } + + // Build the final response + ModelResponse response = new ModelResponse(); + List candidates = new ArrayList<>(); + Candidate candidate = new Candidate(); + + Message message = new Message(); + message.setRole(Role.MODEL); + List parts = new ArrayList<>(); + + // Add text content if present + if (fullContent.length() > 0) { + Part textPart = new Part(); + textPart.setText(fullContent.toString()); + parts.add(textPart); + } + + // Add tool calls if present + for (Map toolCall : toolCallsInProgress) { + String toolId = (String) toolCall.get("id"); + String toolName = (String) toolCall.get("name"); + StringBuilder argsBuilder = (StringBuilder) toolCall.get("arguments"); + + if (toolId != null && toolName != null) { + Part toolPart = new Part(); + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setRef(toolId); + toolRequest.setName(toolName); + + // Parse arguments JSON + String argsJson = argsBuilder.toString(); + if (argsJson != null && !argsJson.isEmpty()) { + try { + @SuppressWarnings("unchecked") + Map args = objectMapper.readValue(argsJson, Map.class); + toolRequest.setInput(args); + } catch (Exception e) { + logger.warn("Failed to parse tool arguments: {}", argsJson, e); + toolRequest.setInput(new java.util.HashMap<>()); + } + } + + toolPart.setToolRequest(toolRequest); + parts.add(toolPart); + } + } + + message.setContent(parts); + candidate.setMessage(message); + + // Set finish reason + String reason = finishReason.get(); + if (reason != null) { + switch (reason) { + case "stop" : + candidate.setFinishReason(FinishReason.STOP); + break; + case "length" : + candidate.setFinishReason(FinishReason.LENGTH); + break; + case "tool_calls" : + candidate.setFinishReason(FinishReason.OTHER); + break; + default : + candidate.setFinishReason(FinishReason.OTHER); + } + } + + candidates.add(candidate); + response.setCandidates(candidates); + + return response; + } + + private ObjectNode buildRequestBody(ModelRequest request) { + ObjectNode body = objectMapper.createObjectNode(); + body.put("model", modelName); + + // Convert messages + ArrayNode messages = body.putArray("messages"); + for (Message message : request.getMessages()) { + ObjectNode msg = messages.addObject(); + String role = convertRole(message.getRole()); + msg.put("role", role); + + // Handle content + List content = message.getContent(); + + // Check if this message contains tool requests (assistant with tool_calls) + boolean hasToolRequests = content.stream().anyMatch(p -> p.getToolRequest() != null); + // Check if this message contains tool responses + boolean hasToolResponses = content.stream().anyMatch(p -> p.getToolResponse() != null); + + if (hasToolRequests) { + // Assistant message with tool calls + // Add text content if present + String textContent = content.stream().filter(p -> p.getText() != null).map(Part::getText).findFirst() + .orElse(null); + if (textContent != null) { + msg.put("content", textContent); + } else { + msg.putNull("content"); + } + + // Add tool_calls array + ArrayNode toolCallsArray = msg.putArray("tool_calls"); + for (Part part : content) { + if (part.getToolRequest() != null) { + ToolRequest toolReq = part.getToolRequest(); + ObjectNode toolCall = toolCallsArray.addObject(); + toolCall.put("id", toolReq.getRef()); + toolCall.put("type", "function"); + ObjectNode function = toolCall.putObject("function"); + function.put("name", toolReq.getName()); + if (toolReq.getInput() != null) { + try { + function.put("arguments", objectMapper.writeValueAsString(toolReq.getInput())); + } catch (Exception e) { + function.put("arguments", "{}"); + } + } else { + function.put("arguments", "{}"); + } + } + } + } else if (hasToolResponses) { + // Tool response messages - each tool response is a separate message + // Remove the current message from array and add individual tool responses + messages.remove(messages.size() - 1); + + for (Part part : content) { + if (part.getToolResponse() != null) { + ToolResponse toolResp = part.getToolResponse(); + ObjectNode toolMsg = messages.addObject(); + toolMsg.put("role", "tool"); + toolMsg.put("tool_call_id", toolResp.getRef()); + + // Convert output to string + String outputStr; + if (toolResp.getOutput() instanceof String) { + outputStr = (String) toolResp.getOutput(); + } else { + try { + outputStr = objectMapper.writeValueAsString(toolResp.getOutput()); + } catch (Exception e) { + outputStr = String.valueOf(toolResp.getOutput()); + } + } + toolMsg.put("content", outputStr); + } + } + } else if (content.size() == 1 && content.get(0).getText() != null) { + // Simple text message + msg.put("content", content.get(0).getText()); + } else { + // Multi-part message + ArrayNode contentArray = msg.putArray("content"); + for (Part part : content) { + ObjectNode partNode = contentArray.addObject(); + if (part.getText() != null) { + partNode.put("type", "text"); + partNode.put("text", part.getText()); + } else if (part.getMedia() != null) { + partNode.put("type", "image_url"); + ObjectNode imageUrl = partNode.putObject("image_url"); + imageUrl.put("url", part.getMedia().getUrl()); + } + } + } + } + + // Add tools if present + if (request.getTools() != null && !request.getTools().isEmpty()) { + ArrayNode tools = body.putArray("tools"); + for (ToolDefinition tool : request.getTools()) { + ObjectNode toolNode = tools.addObject(); + toolNode.put("type", "function"); + ObjectNode function = toolNode.putObject("function"); + function.put("name", tool.getName()); + if (tool.getDescription() != null) { + function.put("description", tool.getDescription()); + } + if (tool.getInputSchema() != null) { + function.set("parameters", objectMapper.valueToTree(tool.getInputSchema())); + } + } + } + + // Add generation config + Map config = request.getConfig(); + if (config != null) { + if (config.containsKey("temperature")) { + body.put("temperature", ((Number) config.get("temperature")).doubleValue()); + } + if (config.containsKey("maxOutputTokens")) { + body.put("max_tokens", ((Number) config.get("maxOutputTokens")).intValue()); + } + if (config.containsKey("topP")) { + body.put("top_p", ((Number) config.get("topP")).doubleValue()); + } + if (config.containsKey("presencePenalty")) { + body.put("presence_penalty", ((Number) config.get("presencePenalty")).doubleValue()); + } + if (config.containsKey("frequencyPenalty")) { + body.put("frequency_penalty", ((Number) config.get("frequencyPenalty")).doubleValue()); + } + if (config.containsKey("stopSequences")) { + ArrayNode stop = body.putArray("stop"); + @SuppressWarnings("unchecked") + List stopSequences = (List) config.get("stopSequences"); + for (String seq : stopSequences) { + stop.add(seq); + } + } + if (config.containsKey("seed")) { + body.put("seed", ((Number) config.get("seed")).intValue()); + } + } + + // Handle output format + OutputConfig output = request.getOutput(); + if (output != null && output.getFormat() == OutputFormat.JSON) { + ObjectNode responseFormat = body.putObject("response_format"); + responseFormat.put("type", "json_object"); + } + + return body; + } + + private String convertRole(Role role) { + switch (role) { + case SYSTEM : + return "system"; + case USER : + return "user"; + case MODEL : + return "assistant"; + case TOOL : + return "tool"; + default : + return "user"; + } + } + + private ModelResponse parseResponse(String responseBody) throws IOException { + JsonNode root = objectMapper.readTree(responseBody); + + ModelResponse response = new ModelResponse(); + List candidates = new ArrayList<>(); + + JsonNode choices = root.get("choices"); + if (choices != null && choices.isArray()) { + for (JsonNode choice : choices) { + Candidate candidate = new Candidate(); + + // Parse message + JsonNode messageNode = choice.get("message"); + if (messageNode != null) { + Message message = new Message(); + message.setRole(Role.MODEL); + + List parts = new ArrayList<>(); + + // Text content + JsonNode contentNode = messageNode.get("content"); + if (contentNode != null && !contentNode.isNull()) { + Part part = new Part(); + part.setText(contentNode.asText()); + parts.add(part); + } + + // Tool calls + JsonNode toolCallsNode = messageNode.get("tool_calls"); + if (toolCallsNode != null && toolCallsNode.isArray()) { + for (JsonNode toolCallNode : toolCallsNode) { + Part part = new Part(); + ToolRequest toolRequest = new ToolRequest(); + toolRequest.setRef(toolCallNode.get("id").asText()); + + JsonNode functionNode = toolCallNode.get("function"); + if (functionNode != null) { + toolRequest.setName(functionNode.get("name").asText()); + JsonNode argsNode = functionNode.get("arguments"); + if (argsNode != null) { + toolRequest.setInput(objectMapper.readValue(argsNode.asText(), Map.class)); + } + } + + part.setToolRequest(toolRequest); + parts.add(part); + } + } + + message.setContent(parts); + candidate.setMessage(message); + } + + // Parse finish reason + JsonNode finishReasonNode = choice.get("finish_reason"); + if (finishReasonNode != null) { + String reason = finishReasonNode.asText(); + switch (reason) { + case "stop" : + candidate.setFinishReason(FinishReason.STOP); + break; + case "length" : + candidate.setFinishReason(FinishReason.LENGTH); + break; + case "tool_calls" : + candidate.setFinishReason(FinishReason.STOP); + break; + case "content_filter" : + candidate.setFinishReason(FinishReason.BLOCKED); + break; + default : + candidate.setFinishReason(FinishReason.OTHER); + } + } + + candidates.add(candidate); + } + } + + response.setCandidates(candidates); + + // Parse usage + JsonNode usageNode = root.get("usage"); + if (usageNode != null) { + Usage usage = new Usage(); + if (usageNode.has("prompt_tokens")) { + usage.setInputTokens(usageNode.get("prompt_tokens").asInt()); + } + if (usageNode.has("completion_tokens")) { + usage.setOutputTokens(usageNode.get("completion_tokens").asInt()); + } + if (usageNode.has("total_tokens")) { + usage.setTotalTokens(usageNode.get("total_tokens").asInt()); + } + response.setUsage(usage); + } + + return response; + } +} diff --git a/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIPlugin.java b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIPlugin.java new file mode 100644 index 0000000000..ffdd5c919b --- /dev/null +++ b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIPlugin.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.openai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.core.Action; +import com.google.genkit.core.Plugin; + +/** + * OpenAIPlugin provides OpenAI model integrations for Genkit. + * + * This plugin registers OpenAI models (GPT-4, GPT-3.5-turbo, etc.), embeddings + * (text-embedding-ada-002, etc.), and image generation models (DALL-E, + * gpt-image-1) as Genkit actions. + */ +public class OpenAIPlugin implements Plugin { + + private static final Logger logger = LoggerFactory.getLogger(OpenAIPlugin.class); + + /** + * Supported GPT models. + */ + public static final List SUPPORTED_MODELS = Arrays.asList("gpt-5.2", "gpt-5.1", "gpt-5", "gpt-4o", + "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", "o1-preview", "o1-mini"); + + /** + * Supported embedding models. + */ + public static final List SUPPORTED_EMBEDDING_MODELS = Arrays.asList("text-embedding-3-small", + "text-embedding-3-large", "text-embedding-ada-002"); + + /** + * Supported image generation models. + */ + public static final List SUPPORTED_IMAGE_MODELS = Arrays.asList("dall-e-3", "dall-e-2", "gpt-image-1"); + + private final OpenAIPluginOptions options; + + /** + * Creates an OpenAIPlugin with default options. + */ + public OpenAIPlugin() { + this(OpenAIPluginOptions.builder().build()); + } + + /** + * Creates an OpenAIPlugin with the specified options. + * + * @param options + * the plugin options + */ + public OpenAIPlugin(OpenAIPluginOptions options) { + this.options = options; + } + + /** + * Creates an OpenAIPlugin with the specified API key. + * + * @param apiKey + * the OpenAI API key + * @return a new OpenAIPlugin + */ + public static OpenAIPlugin create(String apiKey) { + return new OpenAIPlugin(OpenAIPluginOptions.builder().apiKey(apiKey).build()); + } + + /** + * Creates an OpenAIPlugin using the OPENAI_API_KEY environment variable. + * + * @return a new OpenAIPlugin + */ + public static OpenAIPlugin create() { + return new OpenAIPlugin(); + } + + @Override + public String getName() { + return "openai"; + } + + @Override + public List> init() { + List> actions = new ArrayList<>(); + + // Register chat models + for (String modelName : SUPPORTED_MODELS) { + OpenAIModel model = new OpenAIModel(modelName, options); + actions.add(model); + logger.debug("Created OpenAI model: {}", modelName); + } + + // Register embedding models + for (String modelName : SUPPORTED_EMBEDDING_MODELS) { + OpenAIEmbedder embedder = new OpenAIEmbedder(modelName, options); + actions.add(embedder); + logger.debug("Created OpenAI embedder: {}", modelName); + } + + // Register image generation models + for (String modelName : SUPPORTED_IMAGE_MODELS) { + OpenAIImageModel imageModel = new OpenAIImageModel(modelName, options); + actions.add(imageModel); + logger.debug("Created OpenAI image model: {}", modelName); + } + + logger.info("OpenAI plugin initialized with {} models, {} embedders, and {} image models", + SUPPORTED_MODELS.size(), SUPPORTED_EMBEDDING_MODELS.size(), SUPPORTED_IMAGE_MODELS.size()); + + return actions; + } + + /** + * Gets the plugin options. + * + * @return the options + */ + public OpenAIPluginOptions getOptions() { + return options; + } +} diff --git a/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIPluginOptions.java b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIPluginOptions.java new file mode 100644 index 0000000000..b05d621a88 --- /dev/null +++ b/java/plugins/openai/src/main/java/com/google/genkit/plugins/openai/OpenAIPluginOptions.java @@ -0,0 +1,124 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.plugins.openai; + +/** + * Options for configuring the OpenAI plugin. + */ +public class OpenAIPluginOptions { + + private final String apiKey; + private final String baseUrl; + private final String organization; + private final int timeout; + + private OpenAIPluginOptions(Builder builder) { + this.apiKey = builder.apiKey; + this.baseUrl = builder.baseUrl; + this.organization = builder.organization; + this.timeout = builder.timeout; + } + + /** + * Creates a new builder. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the API key. + * + * @return the API key + */ + public String getApiKey() { + return apiKey; + } + + /** + * Gets the base URL for API requests. + * + * @return the base URL + */ + public String getBaseUrl() { + return baseUrl; + } + + /** + * Gets the organization ID. + * + * @return the organization ID + */ + public String getOrganization() { + return organization; + } + + /** + * Gets the request timeout in seconds. + * + * @return the timeout + */ + public int getTimeout() { + return timeout; + } + + /** + * Builder for OpenAIPluginOptions. + */ + public static class Builder { + private String apiKey = getApiKeyFromEnv(); + private String baseUrl = "https://api.openai.com/v1"; + private String organization; + private int timeout = 60; + + private static String getApiKeyFromEnv() { + return System.getenv("OPENAI_API_KEY"); + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder organization(String organization) { + this.organization = organization; + return this; + } + + public Builder timeout(int timeout) { + this.timeout = timeout; + return this; + } + + public OpenAIPluginOptions build() { + if (apiKey == null || apiKey.isEmpty()) { + throw new IllegalStateException( + "OpenAI API key is required. Set OPENAI_API_KEY environment variable or provide it in options."); + } + return new OpenAIPluginOptions(this); + } + } +} diff --git a/java/pom.xml b/java/pom.xml new file mode 100644 index 0000000000..2b2f674459 --- /dev/null +++ b/java/pom.xml @@ -0,0 +1,350 @@ + + + + 4.0.0 + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + pom + + Genkit for Java + Genkit - A framework for building AI-powered applications in Java + https://github.com/firebase/genkit + + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + + Genkit Team + Google + https://firebase.google.com + + + + + scm:git:git://github.com/firebase/genkit.git + scm:git:ssh://github.com:firebase/genkit.git + https://github.com/firebase/genkit/tree/main + + + + core + ai + genkit + plugins/jetty + plugins/openai + plugins/google-genai + plugins/localvec + plugins/mcp + samples/openai + samples/google-genai + samples/rag + samples/evaluations + samples/complex-io + samples/middleware + samples/mcp + samples/chat-session + samples/multi-agent + samples/interrupts + + + + UTF-8 + 17 + 17 + 17 + + + 2.17.0 + 2.0.12 + 1.5.3 + 12.0.7 + 1.36.0 + 5.10.2 + 5.11.0 + 4.36.0 + 4.4.0 + 4.12.0 + 2.44.4 + 10.14.0 + + + + + + + com.google.genkit + genkit-core + ${project.version} + + + com.google.genkit + genkit-ai + ${project.version} + + + com.google.genkit + genkit + ${project.version} + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson.version} + + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + ch.qos.logback + logback-classic + ${logback.version} + + + + + io.opentelemetry + opentelemetry-api + ${opentelemetry.version} + + + io.opentelemetry + opentelemetry-sdk + ${opentelemetry.version} + + + io.opentelemetry + opentelemetry-sdk-trace + ${opentelemetry.version} + + + io.opentelemetry + opentelemetry-exporter-otlp + ${opentelemetry.version} + + + + + org.eclipse.jetty + jetty-server + ${jetty.version} + + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + + + org.eclipse.jetty.ee10 + jetty-ee10-servlet + ${jetty.version} + + + + + com.squareup.okhttp3 + okhttp + ${okhttp.version} + + + com.squareup.okhttp3 + okhttp-sse + ${okhttp.version} + + + + + com.github.victools + jsonschema-generator + ${jsonschema.version} + + + + + com.github.jknack + handlebars + ${handlebars.version} + + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + + + com.squareup.okhttp3 + mockwebserver + ${okhttp.version} + test + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + ${java.version} + ${java.version} + + -parameters + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.5 + + + org.apache.maven.plugins + maven-jar-plugin + 3.3.0 + + + org.apache.maven.plugins + maven-source-plugin + 3.3.0 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.6.3 + + + attach-javadocs + + jar + + + + + + + com.diffplug.spotless + spotless-maven-plugin + ${spotless.version} + + + + 4.29 + + + true + 2 + + + + java,javax,org,com, + + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.3.1 + + google_checks.xml + true + false + error + + + + com.puppycrawl.tools + checkstyle + ${checkstyle.version} + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + diff --git a/java/samples/README.md b/java/samples/README.md new file mode 100644 index 0000000000..501e50ffbc --- /dev/null +++ b/java/samples/README.md @@ -0,0 +1,265 @@ +# Genkit Java Samples + +This directory contains sample applications demonstrating various features of Genkit Java SDK. + +## Prerequisites + +All samples require: + +- **Java 17+** +- **Maven 3.6+** +- **API Key** for the model provider (OpenAI or Google GenAI) + +## Quick Start + +Each sample can be run with: + +```bash +# 1. Set your API key (OpenAI samples) +export OPENAI_API_KEY=your-api-key-here + +# Or for Google GenAI samples +export GOOGLE_GENAI_API_KEY=your-api-key-here + +# 2. Navigate to the sample directory +cd java/samples/ + +# 3. Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +## Running with Genkit Dev UI + +For the best development experience, use the Genkit CLI to run samples with the Dev UI: + +```bash +# Install Genkit CLI (if not already installed) +npm install -g genkit + +# Run sample with Dev UI +cd java/samples/ +genkit start -- ./run.sh +# Or: genkit start -- mvn exec:java +``` + +The Dev UI will be available at `http://localhost:4000` and allows you to: +- View all registered actions (flows, models, tools, prompts) +- Run flows with test inputs +- Inspect traces and execution logs +- Manage datasets and run evaluations + +## Available Samples + +| Sample | Description | API Key Required | +|--------|-------------|------------------| +| [openai](./openai) | Basic OpenAI integration with flows and tools | `OPENAI_API_KEY` | +| [google-genai](./google-genai) | Google Gemini integration with image generation | `GOOGLE_GENAI_API_KEY` | +| [dotprompt](./dotprompt) | DotPrompt files with complex inputs/outputs, variants, and partials | `OPENAI_API_KEY` | +| [rag](./rag) | RAG application with local vector store | `OPENAI_API_KEY` | +| [chat-session](./chat-session) | Multi-turn chat with session persistence | `OPENAI_API_KEY` | +| [evaluations](./evaluations) | Custom evaluators and evaluation workflows | `OPENAI_API_KEY` | +| [complex-io](./complex-io) | Complex nested types, arrays, maps in flow inputs/outputs | `OPENAI_API_KEY` | +| [middleware](./middleware) | Middleware patterns for logging, caching, rate limiting | `OPENAI_API_KEY` | +| [multi-agent](./multi-agent) | Multi-agent orchestration patterns | `OPENAI_API_KEY` | +| [interrupts](./interrupts) | Flow interrupts and human-in-the-loop patterns | `OPENAI_API_KEY` | +| [mcp](./mcp) | Model Context Protocol (MCP) integration | `OPENAI_API_KEY` | + +## Sample Details + +### OpenAI Sample + +Basic integration with OpenAI models demonstrating: +- Text generation with GPT-4o +- Tool usage +- Streaming responses +- Flow definitions + +```bash +cd java/samples/openai +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Google GenAI Sample + +Integration with Google Gemini models demonstrating: +- Text generation with Gemini +- Image generation with Imagen +- Multi-modal inputs + +```bash +cd java/samples/google-genai +export GOOGLE_GENAI_API_KEY=your-key +./run.sh +``` + +### DotPrompt Sample + +Template-based prompts with Handlebars demonstrating: +- Loading `.prompt` files +- Complex input/output schemas +- Prompt variants (e.g., `recipe.robot.prompt`) +- Partials for reusable templates + +```bash +cd java/samples/dotprompt +export OPENAI_API_KEY=your-key +./run.sh +``` + +### RAG Sample + +Retrieval-Augmented Generation demonstrating: +- Local vector store for development +- Document indexing and retrieval +- Semantic search with embeddings +- Context-aware generation + +```bash +cd java/samples/rag +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Chat Session Sample + +Multi-turn conversations demonstrating: +- Conversation history management +- Session state persistence +- Tool integration within sessions +- Multiple chat personas + +```bash +cd java/samples/chat-session +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Evaluations Sample + +AI output evaluation demonstrating: +- Custom evaluator definitions +- Dataset management +- Evaluation workflows +- Quality metrics + +```bash +cd java/samples/evaluations +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Complex I/O Sample + +Complex type handling demonstrating: +- Deeply nested object types +- Arrays and collections +- Optional fields and maps +- Domain objects (e-commerce, analytics) + +```bash +cd java/samples/complex-io +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Middleware Sample + +Cross-cutting concerns demonstrating: +- Logging middleware +- Caching middleware +- Rate limiting +- Request/response transformation +- Error handling + +```bash +cd java/samples/middleware +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Multi-Agent Sample + +Multi-agent orchestration demonstrating: +- Agent coordination patterns +- Task delegation +- Inter-agent communication + +```bash +cd java/samples/multi-agent +export OPENAI_API_KEY=your-key +./run.sh +``` + +### Interrupts Sample + +Flow control demonstrating: +- Human-in-the-loop patterns +- Flow interrupts and resumption +- External input handling + +```bash +cd java/samples/interrupts +export OPENAI_API_KEY=your-key +./run.sh +``` + +### MCP Sample + +Model Context Protocol integration demonstrating: +- MCP server connections +- Tool discovery and usage +- Resource management +- File operations + +```bash +cd java/samples/mcp +export OPENAI_API_KEY=your-key +./run.sh +``` + +## Building All Samples + +From the Java root directory: + +```bash +cd java +mvn clean install +``` + +## Common Issues + +### API Key Not Set + +``` +Error: OPENAI_API_KEY environment variable is not set +``` + +**Solution**: Set the required API key for the sample you're running. + +### Port Already in Use + +``` +Error: Address already in use (Bind failed) +``` + +**Solution**: The default port (8080 or 3100) is in use. Either stop the other process or configure a different port. + +### Maven Dependencies Not Found + +``` +Error: Could not find artifact com.google.genkit:genkit +``` + +**Solution**: Build the parent project first: +```bash +cd java +mvn clean install -DskipTests +``` + +## Additional Resources + +- [Genkit Java README](../README.md) - Main documentation +- [Genkit Documentation](https://firebase.google.com/docs/genkit) - Official docs +- [Genkit GitHub](https://github.com/firebase/genkit) - Source code diff --git a/java/samples/chat-session/README.md b/java/samples/chat-session/README.md new file mode 100644 index 0000000000..01d9790172 --- /dev/null +++ b/java/samples/chat-session/README.md @@ -0,0 +1,91 @@ +# Chat Session Sample + +This sample demonstrates session-based multi-turn chat with persistence in Genkit Java. + +## Features + +- **Multi-turn conversations** - Automatic conversation history management +- **Session state** - Track user preferences and conversation context +- **Session persistence** - Save and load sessions across interactions +- **Tool integration** - Using tools (note-taking) within chat sessions +- **Multiple personas** - Choose between assistant, tutor, and creative modes + +## Prerequisites + +1. Java 17 or later +2. Maven +3. OpenAI API key + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run (Interactive Mode) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/chat-session + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: Demo Mode + +Run the automated demo to see all features: + +```bash +cd java/samples/chat-session +mvn exec:java -Dexec.args="--demo" +``` + +### Option 3: With Genkit Dev UI + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/chat-session + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Commands + +During interactive chat, you can use these commands: + +| Command | Description | +|---------|-------------| +| `/history` | Show conversation history | +| `/notes` | Show saved notes | +| `/state` | Show session state | +| `/topic X` | Set conversation topic to X | +| `/quit` | Exit the chat | + +## Example Session + +``` +What's your name? Alice + +Choose a chat persona: + 1. Assistant (general help) + 2. Tutor (learning & education) + 3. Creative (storytelling & ideas) +Enter choice (1-3): 2 + +✓ Session created: a1b2c3d4-e5f6-... +✓ Persona: tutor + +You: What is photosynthesis? \ No newline at end of file diff --git a/java/samples/chat-session/pom.xml b/java/samples/chat-session/pom.xml new file mode 100644 index 0000000000..ed4aad9664 --- /dev/null +++ b/java/samples/chat-session/pom.xml @@ -0,0 +1,80 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-chat-session + jar + Genkit Chat Session Sample + Sample application demonstrating session-based multi-turn chat with persistence + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.4.14 + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.0 + + com.google.genkit.samples.ChatSessionApp + + + + + diff --git a/java/samples/chat-session/run.sh b/java/samples/chat-session/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/chat-session/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/chat-session/src/main/java/com/google/genkit/samples/ChatSessionApp.java b/java/samples/chat-session/src/main/java/com/google/genkit/samples/ChatSessionApp.java new file mode 100644 index 0000000000..99c7149299 --- /dev/null +++ b/java/samples/chat-session/src/main/java/com/google/genkit/samples/ChatSessionApp.java @@ -0,0 +1,421 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Scanner; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.session.Chat; +import com.google.genkit.ai.session.ChatOptions; +import com.google.genkit.ai.session.InMemorySessionStore; +import com.google.genkit.ai.session.Session; +import com.google.genkit.ai.session.SessionOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Interactive Chat Application with Session Persistence. + * + *

+ * This sample demonstrates: + *

    + *
  • Creating and managing chat sessions
  • + *
  • Multi-turn conversations with automatic history
  • + *
  • Session state management
  • + *
  • Using tools in chat sessions
  • + *
  • Persisting and loading sessions
  • + *
+ * + *

+ * To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java -pl samples/chat-session
  4. + *
+ */ +public class ChatSessionApp { + + /** Session state to track conversation context and user preferences. */ + public static class ConversationState { + private String userName; + private String topic; + private int messageCount; + + public ConversationState() { + this.messageCount = 0; + } + + public ConversationState(String userName) { + this.userName = userName; + this.messageCount = 0; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } + + public String getTopic() { + return topic; + } + + public void setTopic(String topic) { + this.topic = topic; + } + + public int getMessageCount() { + return messageCount; + } + + public void incrementMessageCount() { + this.messageCount++; + } + + @Override + public String toString() { + return String.format("User: %s, Topic: %s, Messages: %d", userName != null ? userName : "Anonymous", + topic != null ? topic : "General", messageCount); + } + } + + private final Genkit genkit; + private final InMemorySessionStore sessionStore; + private final Tool noteTool; + private final Map notes; + + public ChatSessionApp() { + // Initialize notes storage + this.notes = new HashMap<>(); + + // Create Genkit with OpenAI plugin + this.genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).build(); + + // Create a shared session store + this.sessionStore = new InMemorySessionStore<>(); + + // Define a note-taking tool + this.noteTool = createNoteTool(); + } + + @SuppressWarnings("unchecked") + private Tool createNoteTool() { + return genkit.defineTool("saveNote", + "Saves a note for the user. Use this when the user wants to remember something.", + Map.of("type", "object", "properties", + Map.of("title", Map.of("type", "string", "description", "Title of the note"), "content", + Map.of("type", "string", "description", "Content of the note")), + "required", new String[]{"title", "content"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String title = (String) input.get("title"); + String content = (String) input.get("content"); + notes.put(title, content); + Map result = new HashMap<>(); + result.put("status", "saved"); + result.put("message", "Note '" + title + "' has been saved."); + return result; + }); + } + + /** Creates a new chat session with the given user name. */ + public Session createSession(String userName) { + return genkit.createSession(SessionOptions.builder().store(sessionStore) + .initialState(new ConversationState(userName)).build()); + } + + /** Loads an existing session by ID. */ + public Session loadSession(String sessionId) { + try { + return genkit + .loadSession(sessionId, SessionOptions.builder().store(sessionStore).build()) + .get(); + } catch (Exception e) { + System.err.println("Failed to load session: " + e.getMessage()); + return null; + } + } + + /** Creates a chat instance for a session. */ + @SuppressWarnings("unchecked") + public Chat createChat(Session session, String persona) { + String systemPrompt = buildSystemPrompt(session, persona); + + return session.chat(ChatOptions.builder().model("openai/gpt-4o-mini").system(systemPrompt) + .tools(List.of((Tool) noteTool)).build()); + } + + private String buildSystemPrompt(Session session, String persona) { + ConversationState state = session.getState(); + StringBuilder prompt = new StringBuilder(); + + // Base persona + switch (persona.toLowerCase()) { + case "assistant" : + prompt.append("You are a helpful, friendly assistant. "); + break; + case "tutor" : + prompt.append("You are a patient and knowledgeable tutor. Explain concepts clearly and encourage" + + " learning. "); + break; + case "creative" : + prompt.append("You are a creative writing partner. Be imaginative and help with storytelling. "); + break; + default : + prompt.append("You are a helpful assistant. "); + } + + // Add user context if available + if (state.getUserName() != null) { + prompt.append("The user's name is ").append(state.getUserName()).append(". "); + } + + // Add topic context if set + if (state.getTopic() != null) { + prompt.append("The current topic of discussion is: ").append(state.getTopic()).append(". "); + } + + prompt.append("You can save notes for the user using the saveNote tool when they want to remember" + + " something important."); + + return prompt.toString(); + } + + /** Sends a message and updates session state. */ + public String chat(Chat chat, String userMessage) { + try { + // Update message count in state + Session session = chat.getSession(); + ConversationState state = session.getState(); + state.incrementMessageCount(); + session.updateState(state).join(); + + // Send message and get response + ModelResponse response = chat.send(userMessage); + return response.getText(); + } catch (Exception e) { + return "Error: " + e.getMessage(); + } + } + + /** Displays conversation history. */ + public void showHistory(Chat chat) { + System.out.println("\n--- Conversation History ---"); + List history = chat.getHistory(); + for (Message msg : history) { + String role = msg.getRole().toString(); + String text = msg.getText(); + if (text.length() > 100) { + text = text.substring(0, 100) + "..."; + } + System.out.printf("[%s]: %s%n", role, text); + } + System.out.println("--- End History ---\n"); + } + + /** Displays saved notes. */ + public void showNotes() { + System.out.println("\n--- Saved Notes ---"); + if (notes.isEmpty()) { + System.out.println("No notes saved yet."); + } else { + notes.forEach((title, content) -> System.out.printf("• %s: %s%n", title, content)); + } + System.out.println("--- End Notes ---\n"); + } + + /** Interactive chat loop. */ + public void runInteractive() { + Scanner scanner = new Scanner(System.in); + + System.out.println("╔════════════════════════════════════════════════════════════╗"); + System.out.println("║ Genkit Chat Session Demo - Interactive Chat App ║"); + System.out.println("╚════════════════════════════════════════════════════════════╝"); + System.out.println(); + + // Get user name + System.out.print("What's your name? "); + String userName = scanner.nextLine().trim(); + if (userName.isEmpty()) { + userName = "User"; + } + + // Choose persona + System.out.println("\nChoose a chat persona:"); + System.out.println(" 1. Assistant (general help)"); + System.out.println(" 2. Tutor (learning & education)"); + System.out.println(" 3. Creative (storytelling & ideas)"); + System.out.print("Enter choice (1-3): "); + String choice = scanner.nextLine().trim(); + String persona = switch (choice) { + case "2" -> "tutor"; + case "3" -> "creative"; + default -> "assistant"; + }; + + // Create session and chat + Session session = createSession(userName); + Chat chat = createChat(session, persona); + + System.out.println("\n✓ Session created: " + session.getId()); + System.out.println("✓ Persona: " + persona); + System.out.println("\nCommands:"); + System.out.println(" /history - Show conversation history"); + System.out.println(" /notes - Show saved notes"); + System.out.println(" /state - Show session state"); + System.out.println(" /topic X - Set conversation topic to X"); + System.out.println(" /quit - Exit the chat"); + System.out.println("\nStart chatting!\n"); + + // Chat loop + while (true) { + System.out.print("You: "); + String input = scanner.nextLine().trim(); + + if (input.isEmpty()) { + continue; + } + + // Handle commands + if (input.startsWith("/")) { + if (input.equals("/quit") || input.equals("/exit")) { + System.out.println("\nGoodbye, " + userName + "! Session saved."); + break; + } else if (input.equals("/history")) { + showHistory(chat); + continue; + } else if (input.equals("/notes")) { + showNotes(); + continue; + } else if (input.equals("/state")) { + System.out.println("\nSession State: " + session.getState()); + continue; + } else if (input.startsWith("/topic ")) { + String topic = input.substring(7).trim(); + ConversationState state = session.getState(); + state.setTopic(topic); + session.updateState(state).join(); + System.out.println("✓ Topic set to: " + topic); + // Recreate chat with updated system prompt + chat = createChat(session, persona); + continue; + } else { + System.out.println("Unknown command: " + input); + continue; + } + } + + // Send message + String response = chat(chat, input); + System.out.println("\nAssistant: " + response + "\n"); + } + + scanner.close(); + } + + /** Demo mode showing various session features. */ + public void runDemo() { + System.out.println("╔════════════════════════════════════════════════════════════╗"); + System.out.println("║ Genkit Chat Session Demo - Automated Demo ║"); + System.out.println("╚════════════════════════════════════════════════════════════╝"); + System.out.println(); + + // Demo 1: Basic multi-turn conversation + System.out.println("=== Demo 1: Multi-turn Conversation ===\n"); + Session session1 = createSession("Alice"); + Chat chat1 = createChat(session1, "assistant"); + + String[] questions = {"What are the three laws of thermodynamics?", + "Can you explain the second one in simpler terms?", "How does this relate to entropy?"}; + + for (String question : questions) { + System.out.println("User: " + question); + String response = chat(chat1, question); + System.out.println("Assistant: " + truncate(response, 200) + "\n"); + } + + // Demo 2: Session state + System.out.println("\n=== Demo 2: Session State ===\n"); + System.out.println("Session ID: " + session1.getId()); + System.out.println("State: " + session1.getState()); + + // Demo 3: Save and load session + System.out.println("\n=== Demo 3: Session Persistence ===\n"); + String sessionId = session1.getId(); + System.out.println("Saving session: " + sessionId); + + // Load the session + Session loadedSession = loadSession(sessionId); + if (loadedSession != null) { + System.out.println("✓ Session loaded successfully!"); + System.out.println(" Messages in history: " + loadedSession.getMessages().size()); + System.out.println(" State: " + loadedSession.getState()); + + // Continue the conversation + Chat continuedChat = createChat(loadedSession, "assistant"); + System.out.println("\nContinuing conversation..."); + System.out.println("User: Can you summarize what we discussed?"); + String summary = chat(continuedChat, "Can you summarize what we discussed?"); + System.out.println("Assistant: " + truncate(summary, 300)); + } + + // Demo 4: Using tools + System.out.println("\n\n=== Demo 4: Using Tools (Note Taking) ===\n"); + Session session2 = createSession("Bob"); + Chat chat2 = createChat(session2, "assistant"); + + System.out.println("User: Please save a note titled 'Meeting' with content 'Review Q4 goals'"); + String noteResponse = chat(chat2, "Please save a note titled 'Meeting' with content 'Review Q4 goals'"); + System.out.println("Assistant: " + noteResponse); + showNotes(); + + System.out.println("\n=== Demo Complete ==="); + } + + private String truncate(String text, int maxLength) { + if (text == null) { + return ""; + } + if (text.length() <= maxLength) { + return text; + } + return text.substring(0, maxLength) + "..."; + } + + public static void main(String[] args) { + ChatSessionApp app = new ChatSessionApp(); + + // Check for demo mode flag + boolean demoMode = args.length > 0 && args[0].equals("--demo"); + + if (demoMode) { + app.runDemo(); + } else { + app.runInteractive(); + } + } +} diff --git a/java/samples/complex-io/README.md b/java/samples/complex-io/README.md new file mode 100644 index 0000000000..a7b352b83b --- /dev/null +++ b/java/samples/complex-io/README.md @@ -0,0 +1,282 @@ +# Complex I/O Sample + +This sample demonstrates working with complex input and output types in Genkit Java flows. + +## Overview + +The sample showcases: + +- **Deeply nested object types** - OrderRequest with Customer, Address, PaymentMethod nested objects +- **Arrays and collections** - Lists of OrderItems, TimelineEvents, Recommendations +- **Optional fields** - Various nullable fields throughout the types +- **Maps and generic types** - Metadata maps, dynamic aggregation results +- **Complex domain objects** - E-commerce orders, analytics dashboards + +## Complex Types + +### OrderRequest / OrderResponse +Simulates a complex e-commerce order with: +- Customer information with preferences +- Multiple order items with customizations +- Shipping and billing addresses with coordinates +- Payment method details +- Metadata maps + +### AnalyticsRequest / AnalyticsResponse +Simulates a complex analytics dashboard query with: +- Date ranges with timezone support +- Nested filters with logical operators +- Multiple grouping and aggregation options +- Sorting and pagination +- Visualizations with chart configurations +- AI-generated insights with suggested actions + +### ValidationResult +Demonstrates validation output with: +- List of errors with field paths +- List of warnings +- Severity levels + +## Available Flows + +| Flow | Input | Output | Description | +|------|-------|--------|-------------| +| `processOrder` | OrderRequest | OrderResponse | Process complex order and return detailed response | +| `generateAnalytics` | AnalyticsRequest | AnalyticsResponse | Generate analytics dashboard data | +| `processBatch` | List | Map | Process batch of items | +| `simplifyOrder` | OrderRequest | Map | Transform order to simplified format | +| `validateOrder` | OrderRequest | ValidationResult | Validate order structure | + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/complex-io + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/complex-io + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Example Requests + +### Process Order + +```bash +curl -X POST http://localhost:8080/api/flow/processOrder \ + -H "Content-Type: application/json" \ + -d '{ + "customerId": "CUST-001", + "customer": { + "id": "CUST-001", + "firstName": "John", + "lastName": "Doe", + "email": "john@example.com", + "phone": "+1-555-123-4567", + "preferences": { + "communicationChannel": "email", + "marketingOptIn": true, + "language": "en" + } + }, + "items": [ + { + "productId": "PROD-001", + "name": "Wireless Mouse", + "quantity": 2, + "unitPrice": 29.99, + "customizations": [ + { + "type": "color", + "value": "black", + "additionalCost": 0 + } + ] + }, + { + "productId": "PROD-002", + "name": "USB-C Hub", + "quantity": 1, + "unitPrice": 49.99 + } + ], + "shippingAddress": { + "street1": "123 Main St", + "street2": "Apt 4B", + "city": "San Francisco", + "state": "CA", + "postalCode": "94102", + "country": "USA", + "coordinates": { + "latitude": 37.7749, + "longitude": -122.4194 + } + }, + "billingAddress": { + "street1": "123 Main St", + "city": "San Francisco", + "state": "CA", + "postalCode": "94102", + "country": "USA" + }, + "paymentMethod": { + "type": "credit_card", + "details": { + "lastFourDigits": "4242", + "cardType": "visa", + "expirationMonth": 12, + "expirationYear": 2026 + } + }, + "orderNotes": "Please leave at door", + "metadata": { + "source": "web", + "campaign": "winter-sale" + } + }' +``` + +### Generate Analytics + +```bash +curl -X POST http://localhost:8080/api/flow/generateAnalytics \ + -H "Content-Type: application/json" \ + -d '{ + "dashboardId": "sales-overview", + "dateRange": { + "start": "2025-01-01", + "end": "2025-02-10", + "timezone": "America/Los_Angeles", + "preset": "last_30_days" + }, + "filters": [ + { + "field": "status", + "operator": "eq", + "value": "completed" + }, + { + "field": "amount", + "operator": "gte", + "value": 100 + } + ], + "groupBy": [ + { + "field": "created_at", + "interval": "day", + "alias": "date" + } + ], + "metrics": [ + { + "name": "Total Revenue", + "field": "amount", + "aggregation": "sum", + "format": "currency" + }, + { + "name": "Order Count", + "field": "id", + "aggregation": "count", + "format": "number" + } + ], + "sorting": [ + { + "field": "date", + "direction": "desc" + } + ], + "pagination": { + "page": 1, + "pageSize": 50 + } + }' +``` + +### Validate Order + +```bash +curl -X POST http://localhost:8080/api/flow/validateOrder \ + -H "Content-Type: application/json" \ + -d '{ + "items": [] + }' +``` + +This will return validation errors because required fields are missing. + +## Type Hierarchy + +``` +OrderRequest +├── Customer +│ └── CustomerPreferences +├── List +│ ├── Discount +│ └── List +├── Address (shipping) +│ └── Coordinates +├── Address (billing) +├── PaymentMethod +│ └── PaymentDetails +└── Map (metadata) + +OrderResponse +├── CustomerSummary +├── OrderSummary +│ ├── List +│ ├── MoneyAmount (subtotal, shipping, total) +│ ├── List +│ │ └── MoneyAmount +│ └── TaxInfo +│ └── List +│ └── MoneyAmount +├── ShippingInfo +│ ├── DateRange +│ └── FormattedAddress +├── PaymentInfo +│ └── FormattedAddress +├── List +├── List +│ └── MoneyAmount +└── OrderAnalytics + └── MoneyAmount +``` + +## Use Cases + +This sample is useful for testing: + +1. **Schema Generation** - How Genkit generates JSON schemas for complex nested types +2. **Serialization** - Jackson serialization/deserialization of deeply nested objects +3. **Type Safety** - Java generics handling with flows +4. **Validation** - Input validation with complex structures +5. **UI Rendering** - Testing Genkit Developer UI with complex types diff --git a/java/samples/complex-io/pom.xml b/java/samples/complex-io/pom.xml new file mode 100644 index 0000000000..247411d6c4 --- /dev/null +++ b/java/samples/complex-io/pom.xml @@ -0,0 +1,89 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-complex-io + jar + Genkit Complex I/O Sample + Sample application demonstrating Genkit with complex input/output types + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.ComplexIOSample + + + + + diff --git a/java/samples/complex-io/run.sh b/java/samples/complex-io/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/complex-io/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/complex-io/src/main/java/com/google/genkit/samples/ComplexIOSample.java b/java/samples/complex-io/src/main/java/com/google/genkit/samples/ComplexIOSample.java new file mode 100644 index 0000000000..9d33835815 --- /dev/null +++ b/java/samples/complex-io/src/main/java/com/google/genkit/samples/ComplexIOSample.java @@ -0,0 +1,424 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.core.Flow; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import com.google.genkit.samples.types.*; + +/** + * Sample application demonstrating complex input/output types with Genkit + * flows. + * + * This sample shows how to: - Use deeply nested object types - Handle arrays + * and collections - Work with optional fields - Process maps and generic types + * - Handle complex domain objects + */ +public class ComplexIOSample { + + public static void main(String[] args) throws Exception { + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // Initialize Genkit with plugins + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).plugin(jetty).build(); + + // Flow 1: Process complex order with nested types + Flow processOrder = genkit.defineFlow("processOrder", OrderRequest.class, + OrderResponse.class, (request) -> { + // This flow demonstrates processing complex nested input + // and generating a complex nested response + + // In a real implementation, you would: + // 1. Validate the order + // 2. Calculate pricing with discounts + // 3. Process payment + // 4. Generate shipping details + // 5. Create timeline and recommendations + + OrderResponse response = new OrderResponse(); + response.setOrderId("ORD-" + System.currentTimeMillis()); + response.setStatus("CONFIRMED"); + + // Set customer summary + OrderResponse.CustomerSummary customerSummary = new OrderResponse.CustomerSummary(); + if (request.getCustomer() != null) { + customerSummary.setId(request.getCustomer().getId()); + customerSummary.setFullName( + request.getCustomer().getFirstName() + " " + request.getCustomer().getLastName()); + customerSummary.setEmail(request.getCustomer().getEmail()); + } + customerSummary.setLoyaltyTier("Gold"); + customerSummary.setTotalOrders(15); + response.setCustomer(customerSummary); + + // Set order summary with calculated values + OrderResponse.OrderSummary orderSummary = new OrderResponse.OrderSummary(); + orderSummary.setItemCount(request.getItems() != null ? request.getItems().size() : 0); + + // Calculate totals + double subtotal = 0.0; + if (request.getItems() != null) { + for (OrderRequest.OrderItem item : request.getItems()) { + subtotal += item.getQuantity() * item.getUnitPrice(); + } + } + + OrderResponse.MoneyAmount subtotalAmount = new OrderResponse.MoneyAmount(); + subtotalAmount.setAmount(subtotal); + subtotalAmount.setCurrency("USD"); + subtotalAmount.setFormatted(String.format("$%.2f", subtotal)); + orderSummary.setSubtotal(subtotalAmount); + + OrderResponse.MoneyAmount totalAmount = new OrderResponse.MoneyAmount(); + double tax = subtotal * 0.08; + double total = subtotal + tax + 9.99; // shipping + totalAmount.setAmount(total); + totalAmount.setCurrency("USD"); + totalAmount.setFormatted(String.format("$%.2f", total)); + orderSummary.setTotal(totalAmount); + + response.setOrderSummary(orderSummary); + + // Set shipping info + OrderResponse.ShippingInfo shippingInfo = new OrderResponse.ShippingInfo(); + shippingInfo.setMethod("Standard"); + shippingInfo.setCarrier("USPS"); + shippingInfo.setTrackingNumber("1Z999AA10123456784"); + + OrderResponse.DateRange deliveryRange = new OrderResponse.DateRange(); + deliveryRange.setEarliest("2025-02-15"); + deliveryRange.setLatest("2025-02-18"); + shippingInfo.setEstimatedDelivery(deliveryRange); + + response.setShipping(shippingInfo); + + // Set timeline events + OrderResponse.TimelineEvent event1 = new OrderResponse.TimelineEvent(); + event1.setTimestamp("2025-02-10T10:30:00Z"); + event1.setEvent("ORDER_PLACED"); + event1.setDescription("Order was placed"); + event1.setActor("customer"); + + OrderResponse.TimelineEvent event2 = new OrderResponse.TimelineEvent(); + event2.setTimestamp("2025-02-10T10:30:05Z"); + event2.setEvent("PAYMENT_CONFIRMED"); + event2.setDescription("Payment was confirmed"); + event2.setActor("system"); + + response.setTimeline(Arrays.asList(event1, event2)); + + // Set analytics + OrderResponse.OrderAnalytics analytics = new OrderResponse.OrderAnalytics(); + analytics.setProcessingTime("1.2s"); + analytics.setFraudRiskScore(0.05); + analytics.setTags(Arrays.asList("new_customer", "high_value", "rush_shipping")); + response.setAnalytics(analytics); + + return response; + }); + + // Flow 2: Generate analytics from complex request + Flow generateAnalytics = genkit.defineFlow("generateAnalytics", + AnalyticsRequest.class, AnalyticsResponse.class, (request) -> { + // This flow demonstrates processing complex analytics requests + // with filters, grouping, and metrics + + AnalyticsResponse response = new AnalyticsResponse(); + response.setRequestId("REQ-" + System.currentTimeMillis()); + response.setExecutionTime("245ms"); + + // Build summary + AnalyticsResponse.Summary summary = new AnalyticsResponse.Summary(); + summary.setTotalRecords(15234L); + + Map metrics = new HashMap<>(); + + AnalyticsResponse.MetricValue revenueMetric = new AnalyticsResponse.MetricValue(); + revenueMetric.setValue(1234567.89); + revenueMetric.setFormatted("$1,234,567.89"); + AnalyticsResponse.ChangeIndicator revenueChange = new AnalyticsResponse.ChangeIndicator(); + revenueChange.setAbsolute(123456.78); + revenueChange.setPercentage(11.1); + revenueChange.setDirection("up"); + revenueChange.setComparisonPeriod("previous_month"); + revenueMetric.setChange(revenueChange); + metrics.put("revenue", revenueMetric); + + AnalyticsResponse.MetricValue ordersMetric = new AnalyticsResponse.MetricValue(); + ordersMetric.setValue(5678); + ordersMetric.setFormatted("5,678 orders"); + metrics.put("orders", ordersMetric); + + summary.setMetrics(metrics); + + // Add trend indicators + AnalyticsResponse.TrendIndicator revenueTrend = new AnalyticsResponse.TrendIndicator(); + revenueTrend.setMetric("revenue"); + revenueTrend.setTrend("increasing"); + revenueTrend.setConfidence(0.92); + + AnalyticsResponse.ForecastPoint fp1 = new AnalyticsResponse.ForecastPoint(); + fp1.setDate("2025-03-01"); + fp1.setValue(1350000.0); + fp1.setLower(1250000.0); + fp1.setUpper(1450000.0); + + AnalyticsResponse.ForecastPoint fp2 = new AnalyticsResponse.ForecastPoint(); + fp2.setDate("2025-04-01"); + fp2.setValue(1450000.0); + fp2.setLower(1300000.0); + fp2.setUpper(1600000.0); + + revenueTrend.setForecast(Arrays.asList(fp1, fp2)); + summary.setTrends(Arrays.asList(revenueTrend)); + + response.setSummary(summary); + + // Build data result + AnalyticsResponse.DataResult dataResult = new AnalyticsResponse.DataResult(); + + AnalyticsResponse.ColumnDefinition col1 = new AnalyticsResponse.ColumnDefinition(); + col1.setName("date"); + col1.setType("date"); + col1.setFormat("YYYY-MM-DD"); + + AnalyticsResponse.ColumnDefinition col2 = new AnalyticsResponse.ColumnDefinition(); + col2.setName("revenue"); + col2.setType("currency"); + col2.setFormat("$#,##0.00"); + col2.setAggregation("sum"); + + dataResult.setColumns(Arrays.asList(col1, col2)); + + // Sample rows + Map row1 = new HashMap<>(); + row1.put("date", "2025-02-01"); + row1.put("revenue", 45678.90); + + Map row2 = new HashMap<>(); + row2.put("date", "2025-02-02"); + row2.put("revenue", 52341.23); + + dataResult.setRows(Arrays.asList(row1, row2)); + response.setData(dataResult); + + // Build visualizations + AnalyticsResponse.Visualization lineChart = new AnalyticsResponse.Visualization(); + lineChart.setId("viz-001"); + lineChart.setType("line"); + lineChart.setTitle("Revenue Over Time"); + + AnalyticsResponse.VisualizationData vizData = new AnalyticsResponse.VisualizationData(); + vizData.setLabels(Arrays.asList("Jan", "Feb", "Mar", "Apr", "May")); + + AnalyticsResponse.Dataset dataset = new AnalyticsResponse.Dataset(); + dataset.setLabel("Revenue"); + dataset.setData(Arrays.asList(120000.0, 135000.0, 142000.0, 155000.0, 168000.0)); + dataset.setColor("#4285F4"); + vizData.setDatasets(Arrays.asList(dataset)); + + lineChart.setData(vizData); + + AnalyticsResponse.VisualizationConfig vizConfig = new AnalyticsResponse.VisualizationConfig(); + AnalyticsResponse.AxisConfig xAxis = new AnalyticsResponse.AxisConfig(); + xAxis.setLabel("Month"); + xAxis.setType("category"); + vizConfig.setXAxis(xAxis); + + AnalyticsResponse.AxisConfig yAxis = new AnalyticsResponse.AxisConfig(); + yAxis.setLabel("Revenue ($)"); + yAxis.setType("linear"); + yAxis.setFormat("currency"); + vizConfig.setYAxis(yAxis); + + lineChart.setConfig(vizConfig); + response.setVisualizations(Arrays.asList(lineChart)); + + // Add insights + AnalyticsResponse.Insight insight = new AnalyticsResponse.Insight(); + insight.setType("trend"); + insight.setSeverity("info"); + insight.setTitle("Strong Revenue Growth"); + insight.setDescription("Revenue has grown 11.1% compared to the previous month"); + insight.setMetrics(Arrays.asList("revenue", "orders")); + + AnalyticsResponse.SuggestedAction action = new AnalyticsResponse.SuggestedAction(); + action.setAction("Consider increasing marketing spend"); + action.setImpact("high"); + action.setEffort("medium"); + insight.setActions(Arrays.asList(action)); + + response.setInsights(Arrays.asList(insight)); + + // Set metadata + AnalyticsResponse.ResponseMetadata metadata = new AnalyticsResponse.ResponseMetadata(); + metadata.setQueryId("QRY-" + System.currentTimeMillis()); + metadata.setCacheHit(false); + metadata.setDataFreshness("2025-02-10T10:30:00Z"); + metadata.setWarnings(Arrays.asList()); + response.setMetadata(metadata); + + return response; + }); + + // Flow 3: Process batch of items with arrays + Flow, Map, Void> processBatch = genkit.defineFlow("processBatch", + (Class>) (Class) List.class, (Class>) (Class) Map.class, + (items) -> { + Map result = new HashMap<>(); + result.put("processed", items.size()); + result.put("items", items); + result.put("timestamp", System.currentTimeMillis()); + result.put("success", true); + + Map counts = new HashMap<>(); + for (String item : items) { + counts.merge(item, 1, Integer::sum); + } + result.put("itemCounts", counts); + + return result; + }); + + // Flow 4: Transform order to simplified format + Flow, Void> simplifyOrder = genkit.defineFlow("simplifyOrder", + OrderRequest.class, (Class>) (Class) Map.class, (order) -> { + Map simplified = new HashMap<>(); + + if (order.getCustomer() != null) { + simplified.put("customerName", + order.getCustomer().getFirstName() + " " + order.getCustomer().getLastName()); + simplified.put("email", order.getCustomer().getEmail()); + } + + if (order.getItems() != null) { + simplified.put("itemCount", order.getItems().size()); + + double total = 0.0; + for (OrderRequest.OrderItem item : order.getItems()) { + total += item.getQuantity() * item.getUnitPrice(); + } + simplified.put("orderTotal", total); + } + + if (order.getShippingAddress() != null) { + simplified.put("shippingCity", order.getShippingAddress().getCity()); + simplified.put("shippingCountry", order.getShippingAddress().getCountry()); + } + + return simplified; + }); + + // Flow 5: Validate order structure + Flow validateOrder = genkit.defineFlow("validateOrder", + OrderRequest.class, ValidationResult.class, (order) -> { + ValidationResult result = new ValidationResult(); + result.setValid(true); + + List errors = new java.util.ArrayList<>(); + List warnings = new java.util.ArrayList<>(); + + // Validate customer + if (order.getCustomer() == null) { + ValidationResult.ValidationError error = new ValidationResult.ValidationError(); + error.setField("customer"); + error.setMessage("Customer information is required"); + error.setSeverity("error"); + errors.add(error); + result.setValid(false); + } else { + if (order.getCustomer().getEmail() == null || order.getCustomer().getEmail().isEmpty()) { + ValidationResult.ValidationError error = new ValidationResult.ValidationError(); + error.setField("customer.email"); + error.setMessage("Customer email is required"); + error.setSeverity("error"); + errors.add(error); + result.setValid(false); + } + } + + // Validate items + if (order.getItems() == null || order.getItems().isEmpty()) { + ValidationResult.ValidationError error = new ValidationResult.ValidationError(); + error.setField("items"); + error.setMessage("Order must have at least one item"); + error.setSeverity("error"); + errors.add(error); + result.setValid(false); + } else { + for (int i = 0; i < order.getItems().size(); i++) { + OrderRequest.OrderItem item = order.getItems().get(i); + if (item.getQuantity() == null || item.getQuantity() <= 0) { + ValidationResult.ValidationError error = new ValidationResult.ValidationError(); + error.setField("items[" + i + "].quantity"); + error.setMessage("Item quantity must be greater than 0"); + error.setSeverity("error"); + errors.add(error); + result.setValid(false); + } + if (item.getQuantity() != null && item.getQuantity() > 100) { + ValidationResult.ValidationWarning warning = new ValidationResult.ValidationWarning(); + warning.setField("items[" + i + "].quantity"); + warning.setMessage("Large quantity order - manual review recommended"); + warnings.add(warning); + } + } + } + + // Validate shipping address + if (order.getShippingAddress() == null) { + ValidationResult.ValidationError error = new ValidationResult.ValidationError(); + error.setField("shippingAddress"); + error.setMessage("Shipping address is required"); + error.setSeverity("error"); + errors.add(error); + result.setValid(false); + } + + result.setErrors(errors); + result.setWarnings(warnings); + result.setErrorCount(errors.size()); + result.setWarningCount(warnings.size()); + + return result; + }); + + System.out.println("Complex I/O Sample started!"); + System.out.println("Available flows:"); + System.out.println(" - processOrder: Process complex nested order request"); + System.out.println(" - generateAnalytics: Generate analytics from complex request"); + System.out.println(" - processBatch: Process batch of items"); + System.out.println(" - simplifyOrder: Transform order to simplified format"); + System.out.println(" - validateOrder: Validate order structure"); + + // Start the server to expose the flows + jetty.start(); + } +} diff --git a/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/AnalyticsRequest.java b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/AnalyticsRequest.java new file mode 100644 index 0000000000..a9f637ceae --- /dev/null +++ b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/AnalyticsRequest.java @@ -0,0 +1,375 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples.types; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Analytics dashboard request with complex filtering options. + */ +public class AnalyticsRequest { + + @JsonProperty("dashboardId") + private String dashboardId; + + @JsonProperty("dateRange") + private DateRange dateRange; + + @JsonProperty("filters") + private List filters; + + @JsonProperty("groupBy") + private List groupBy; + + @JsonProperty("metrics") + private List metrics; + + @JsonProperty("sorting") + private List sorting; + + @JsonProperty("pagination") + private Pagination pagination; + + @JsonProperty("exportOptions") + private ExportOptions exportOptions; + + // Nested types + public static class DateRange { + @JsonProperty("start") + private String start; + + @JsonProperty("end") + private String end; + + @JsonProperty("timezone") + private String timezone; + + @JsonProperty("preset") + private String preset; // last_7_days, last_30_days, this_month, etc. + + // Getters and setters + public String getStart() { + return start; + } + public void setStart(String start) { + this.start = start; + } + public String getEnd() { + return end; + } + public void setEnd(String end) { + this.end = end; + } + public String getTimezone() { + return timezone; + } + public void setTimezone(String timezone) { + this.timezone = timezone; + } + public String getPreset() { + return preset; + } + public void setPreset(String preset) { + this.preset = preset; + } + } + + public static class Filter { + @JsonProperty("field") + private String field; + + @JsonProperty("operator") + private String operator; // eq, ne, gt, gte, lt, lte, in, contains, between + + @JsonProperty("value") + private Object value; + + @JsonProperty("values") + private List values; + + @JsonProperty("logicalOperator") + private String logicalOperator; // and, or + + @JsonProperty("nestedFilters") + private List nestedFilters; + + // Getters and setters + public String getField() { + return field; + } + public void setField(String field) { + this.field = field; + } + public String getOperator() { + return operator; + } + public void setOperator(String operator) { + this.operator = operator; + } + public Object getValue() { + return value; + } + public void setValue(Object value) { + this.value = value; + } + public List getValues() { + return values; + } + public void setValues(List values) { + this.values = values; + } + public String getLogicalOperator() { + return logicalOperator; + } + public void setLogicalOperator(String logicalOperator) { + this.logicalOperator = logicalOperator; + } + public List getNestedFilters() { + return nestedFilters; + } + public void setNestedFilters(List nestedFilters) { + this.nestedFilters = nestedFilters; + } + } + + public static class GroupByOption { + @JsonProperty("field") + private String field; + + @JsonProperty("interval") + private String interval; // hour, day, week, month, year + + @JsonProperty("alias") + private String alias; + + // Getters and setters + public String getField() { + return field; + } + public void setField(String field) { + this.field = field; + } + public String getInterval() { + return interval; + } + public void setInterval(String interval) { + this.interval = interval; + } + public String getAlias() { + return alias; + } + public void setAlias(String alias) { + this.alias = alias; + } + } + + public static class MetricDefinition { + @JsonProperty("name") + private String name; + + @JsonProperty("field") + private String field; + + @JsonProperty("aggregation") + private String aggregation; // sum, avg, count, min, max, distinct + + @JsonProperty("format") + private String format; // number, percentage, currency + + @JsonProperty("customFormula") + private String customFormula; + + // Getters and setters + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } + public String getField() { + return field; + } + public void setField(String field) { + this.field = field; + } + public String getAggregation() { + return aggregation; + } + public void setAggregation(String aggregation) { + this.aggregation = aggregation; + } + public String getFormat() { + return format; + } + public void setFormat(String format) { + this.format = format; + } + public String getCustomFormula() { + return customFormula; + } + public void setCustomFormula(String customFormula) { + this.customFormula = customFormula; + } + } + + public static class SortOption { + @JsonProperty("field") + private String field; + + @JsonProperty("direction") + private String direction; // asc, desc + + // Getters and setters + public String getField() { + return field; + } + public void setField(String field) { + this.field = field; + } + public String getDirection() { + return direction; + } + public void setDirection(String direction) { + this.direction = direction; + } + } + + public static class Pagination { + @JsonProperty("page") + private Integer page; + + @JsonProperty("pageSize") + private Integer pageSize; + + @JsonProperty("offset") + private Integer offset; + + // Getters and setters + public Integer getPage() { + return page; + } + public void setPage(Integer page) { + this.page = page; + } + public Integer getPageSize() { + return pageSize; + } + public void setPageSize(Integer pageSize) { + this.pageSize = pageSize; + } + public Integer getOffset() { + return offset; + } + public void setOffset(Integer offset) { + this.offset = offset; + } + } + + public static class ExportOptions { + @JsonProperty("format") + private String format; // csv, json, excel, pdf + + @JsonProperty("includeHeaders") + private Boolean includeHeaders; + + @JsonProperty("filename") + private String filename; + + @JsonProperty("compression") + private String compression; // none, gzip, zip + + // Getters and setters + public String getFormat() { + return format; + } + public void setFormat(String format) { + this.format = format; + } + public Boolean getIncludeHeaders() { + return includeHeaders; + } + public void setIncludeHeaders(Boolean includeHeaders) { + this.includeHeaders = includeHeaders; + } + public String getFilename() { + return filename; + } + public void setFilename(String filename) { + this.filename = filename; + } + public String getCompression() { + return compression; + } + public void setCompression(String compression) { + this.compression = compression; + } + } + + // Main class getters and setters + public String getDashboardId() { + return dashboardId; + } + public void setDashboardId(String dashboardId) { + this.dashboardId = dashboardId; + } + public DateRange getDateRange() { + return dateRange; + } + public void setDateRange(DateRange dateRange) { + this.dateRange = dateRange; + } + public List getFilters() { + return filters; + } + public void setFilters(List filters) { + this.filters = filters; + } + public List getGroupBy() { + return groupBy; + } + public void setGroupBy(List groupBy) { + this.groupBy = groupBy; + } + public List getMetrics() { + return metrics; + } + public void setMetrics(List metrics) { + this.metrics = metrics; + } + public List getSorting() { + return sorting; + } + public void setSorting(List sorting) { + this.sorting = sorting; + } + public Pagination getPagination() { + return pagination; + } + public void setPagination(Pagination pagination) { + this.pagination = pagination; + } + public ExportOptions getExportOptions() { + return exportOptions; + } + public void setExportOptions(ExportOptions exportOptions) { + this.exportOptions = exportOptions; + } +} diff --git a/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/AnalyticsResponse.java b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/AnalyticsResponse.java new file mode 100644 index 0000000000..9ac16820c8 --- /dev/null +++ b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/AnalyticsResponse.java @@ -0,0 +1,664 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples.types; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Complex analytics response with multiple data visualizations. + */ +public class AnalyticsResponse { + + @JsonProperty("requestId") + private String requestId; + + @JsonProperty("executionTime") + private String executionTime; + + @JsonProperty("summary") + private Summary summary; + + @JsonProperty("data") + private DataResult data; + + @JsonProperty("visualizations") + private List visualizations; + + @JsonProperty("insights") + private List insights; + + @JsonProperty("metadata") + private ResponseMetadata metadata; + + // Nested types + public static class Summary { + @JsonProperty("totalRecords") + private Long totalRecords; + + @JsonProperty("metrics") + private Map metrics; + + @JsonProperty("trends") + private List trends; + + // Getters and setters + public Long getTotalRecords() { + return totalRecords; + } + public void setTotalRecords(Long totalRecords) { + this.totalRecords = totalRecords; + } + public Map getMetrics() { + return metrics; + } + public void setMetrics(Map metrics) { + this.metrics = metrics; + } + public List getTrends() { + return trends; + } + public void setTrends(List trends) { + this.trends = trends; + } + } + + public static class MetricValue { + @JsonProperty("value") + private Object value; + + @JsonProperty("formatted") + private String formatted; + + @JsonProperty("change") + private ChangeIndicator change; + + // Getters and setters + public Object getValue() { + return value; + } + public void setValue(Object value) { + this.value = value; + } + public String getFormatted() { + return formatted; + } + public void setFormatted(String formatted) { + this.formatted = formatted; + } + public ChangeIndicator getChange() { + return change; + } + public void setChange(ChangeIndicator change) { + this.change = change; + } + } + + public static class ChangeIndicator { + @JsonProperty("absolute") + private Double absolute; + + @JsonProperty("percentage") + private Double percentage; + + @JsonProperty("direction") + private String direction; // up, down, stable + + @JsonProperty("comparisonPeriod") + private String comparisonPeriod; + + // Getters and setters + public Double getAbsolute() { + return absolute; + } + public void setAbsolute(Double absolute) { + this.absolute = absolute; + } + public Double getPercentage() { + return percentage; + } + public void setPercentage(Double percentage) { + this.percentage = percentage; + } + public String getDirection() { + return direction; + } + public void setDirection(String direction) { + this.direction = direction; + } + public String getComparisonPeriod() { + return comparisonPeriod; + } + public void setComparisonPeriod(String comparisonPeriod) { + this.comparisonPeriod = comparisonPeriod; + } + } + + public static class TrendIndicator { + @JsonProperty("metric") + private String metric; + + @JsonProperty("trend") + private String trend; + + @JsonProperty("confidence") + private Double confidence; + + @JsonProperty("forecast") + private List forecast; + + // Getters and setters + public String getMetric() { + return metric; + } + public void setMetric(String metric) { + this.metric = metric; + } + public String getTrend() { + return trend; + } + public void setTrend(String trend) { + this.trend = trend; + } + public Double getConfidence() { + return confidence; + } + public void setConfidence(Double confidence) { + this.confidence = confidence; + } + public List getForecast() { + return forecast; + } + public void setForecast(List forecast) { + this.forecast = forecast; + } + } + + public static class ForecastPoint { + @JsonProperty("date") + private String date; + + @JsonProperty("value") + private Double value; + + @JsonProperty("lower") + private Double lower; + + @JsonProperty("upper") + private Double upper; + + // Getters and setters + public String getDate() { + return date; + } + public void setDate(String date) { + this.date = date; + } + public Double getValue() { + return value; + } + public void setValue(Double value) { + this.value = value; + } + public Double getLower() { + return lower; + } + public void setLower(Double lower) { + this.lower = lower; + } + public Double getUpper() { + return upper; + } + public void setUpper(Double upper) { + this.upper = upper; + } + } + + public static class DataResult { + @JsonProperty("columns") + private List columns; + + @JsonProperty("rows") + private List> rows; + + @JsonProperty("aggregations") + private Map aggregations; + + // Getters and setters + public List getColumns() { + return columns; + } + public void setColumns(List columns) { + this.columns = columns; + } + public List> getRows() { + return rows; + } + public void setRows(List> rows) { + this.rows = rows; + } + public Map getAggregations() { + return aggregations; + } + public void setAggregations(Map aggregations) { + this.aggregations = aggregations; + } + } + + public static class ColumnDefinition { + @JsonProperty("name") + private String name; + + @JsonProperty("type") + private String type; + + @JsonProperty("format") + private String format; + + @JsonProperty("aggregation") + private String aggregation; + + // Getters and setters + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public String getFormat() { + return format; + } + public void setFormat(String format) { + this.format = format; + } + public String getAggregation() { + return aggregation; + } + public void setAggregation(String aggregation) { + this.aggregation = aggregation; + } + } + + public static class Visualization { + @JsonProperty("id") + private String id; + + @JsonProperty("type") + private String type; // line, bar, pie, scatter, heatmap, table + + @JsonProperty("title") + private String title; + + @JsonProperty("data") + private VisualizationData data; + + @JsonProperty("config") + private VisualizationConfig config; + + // Getters and setters + public String getId() { + return id; + } + public void setId(String id) { + this.id = id; + } + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public String getTitle() { + return title; + } + public void setTitle(String title) { + this.title = title; + } + public VisualizationData getData() { + return data; + } + public void setData(VisualizationData data) { + this.data = data; + } + public VisualizationConfig getConfig() { + return config; + } + public void setConfig(VisualizationConfig config) { + this.config = config; + } + } + + public static class VisualizationData { + @JsonProperty("labels") + private List labels; + + @JsonProperty("datasets") + private List datasets; + + // Getters and setters + public List getLabels() { + return labels; + } + public void setLabels(List labels) { + this.labels = labels; + } + public List getDatasets() { + return datasets; + } + public void setDatasets(List datasets) { + this.datasets = datasets; + } + } + + public static class Dataset { + @JsonProperty("label") + private String label; + + @JsonProperty("data") + private List data; + + @JsonProperty("color") + private String color; + + // Getters and setters + public String getLabel() { + return label; + } + public void setLabel(String label) { + this.label = label; + } + public List getData() { + return data; + } + public void setData(List data) { + this.data = data; + } + public String getColor() { + return color; + } + public void setColor(String color) { + this.color = color; + } + } + + public static class VisualizationConfig { + @JsonProperty("xAxis") + private AxisConfig xAxis; + + @JsonProperty("yAxis") + private AxisConfig yAxis; + + @JsonProperty("legend") + private LegendConfig legend; + + // Getters and setters + public AxisConfig getXAxis() { + return xAxis; + } + public void setXAxis(AxisConfig xAxis) { + this.xAxis = xAxis; + } + public AxisConfig getYAxis() { + return yAxis; + } + public void setYAxis(AxisConfig yAxis) { + this.yAxis = yAxis; + } + public LegendConfig getLegend() { + return legend; + } + public void setLegend(LegendConfig legend) { + this.legend = legend; + } + } + + public static class AxisConfig { + @JsonProperty("label") + private String label; + + @JsonProperty("type") + private String type; + + @JsonProperty("format") + private String format; + + // Getters and setters + public String getLabel() { + return label; + } + public void setLabel(String label) { + this.label = label; + } + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public String getFormat() { + return format; + } + public void setFormat(String format) { + this.format = format; + } + } + + public static class LegendConfig { + @JsonProperty("position") + private String position; + + @JsonProperty("visible") + private Boolean visible; + + // Getters and setters + public String getPosition() { + return position; + } + public void setPosition(String position) { + this.position = position; + } + public Boolean getVisible() { + return visible; + } + public void setVisible(Boolean visible) { + this.visible = visible; + } + } + + public static class Insight { + @JsonProperty("type") + private String type; // anomaly, trend, correlation, recommendation + + @JsonProperty("severity") + private String severity; // info, warning, critical + + @JsonProperty("title") + private String title; + + @JsonProperty("description") + private String description; + + @JsonProperty("metrics") + private List metrics; + + @JsonProperty("actions") + private List actions; + + // Getters and setters + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public String getSeverity() { + return severity; + } + public void setSeverity(String severity) { + this.severity = severity; + } + public String getTitle() { + return title; + } + public void setTitle(String title) { + this.title = title; + } + public String getDescription() { + return description; + } + public void setDescription(String description) { + this.description = description; + } + public List getMetrics() { + return metrics; + } + public void setMetrics(List metrics) { + this.metrics = metrics; + } + public List getActions() { + return actions; + } + public void setActions(List actions) { + this.actions = actions; + } + } + + public static class SuggestedAction { + @JsonProperty("action") + private String action; + + @JsonProperty("impact") + private String impact; + + @JsonProperty("effort") + private String effort; + + // Getters and setters + public String getAction() { + return action; + } + public void setAction(String action) { + this.action = action; + } + public String getImpact() { + return impact; + } + public void setImpact(String impact) { + this.impact = impact; + } + public String getEffort() { + return effort; + } + public void setEffort(String effort) { + this.effort = effort; + } + } + + public static class ResponseMetadata { + @JsonProperty("queryId") + private String queryId; + + @JsonProperty("cacheHit") + private Boolean cacheHit; + + @JsonProperty("dataFreshness") + private String dataFreshness; + + @JsonProperty("warnings") + private List warnings; + + // Getters and setters + public String getQueryId() { + return queryId; + } + public void setQueryId(String queryId) { + this.queryId = queryId; + } + public Boolean getCacheHit() { + return cacheHit; + } + public void setCacheHit(Boolean cacheHit) { + this.cacheHit = cacheHit; + } + public String getDataFreshness() { + return dataFreshness; + } + public void setDataFreshness(String dataFreshness) { + this.dataFreshness = dataFreshness; + } + public List getWarnings() { + return warnings; + } + public void setWarnings(List warnings) { + this.warnings = warnings; + } + } + + // Main class getters and setters + public String getRequestId() { + return requestId; + } + public void setRequestId(String requestId) { + this.requestId = requestId; + } + public String getExecutionTime() { + return executionTime; + } + public void setExecutionTime(String executionTime) { + this.executionTime = executionTime; + } + public Summary getSummary() { + return summary; + } + public void setSummary(Summary summary) { + this.summary = summary; + } + public DataResult getData() { + return data; + } + public void setData(DataResult data) { + this.data = data; + } + public List getVisualizations() { + return visualizations; + } + public void setVisualizations(List visualizations) { + this.visualizations = visualizations; + } + public List getInsights() { + return insights; + } + public void setInsights(List insights) { + this.insights = insights; + } + public ResponseMetadata getMetadata() { + return metadata; + } + public void setMetadata(ResponseMetadata metadata) { + this.metadata = metadata; + } +} diff --git a/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/OrderRequest.java b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/OrderRequest.java new file mode 100644 index 0000000000..95141dfb4c --- /dev/null +++ b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/OrderRequest.java @@ -0,0 +1,465 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples.types; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Complex order input demonstrating nested types. + */ +public class OrderRequest { + + @JsonProperty("customerId") + private String customerId; + + @JsonProperty("customer") + private Customer customer; + + @JsonProperty("items") + private List items; + + @JsonProperty("shippingAddress") + private Address shippingAddress; + + @JsonProperty("billingAddress") + private Address billingAddress; + + @JsonProperty("paymentMethod") + private PaymentMethod paymentMethod; + + @JsonProperty("orderNotes") + private String orderNotes; + + @JsonProperty("metadata") + private Map metadata; + + // Nested types + public static class Customer { + @JsonProperty("id") + private String id; + + @JsonProperty("firstName") + private String firstName; + + @JsonProperty("lastName") + private String lastName; + + @JsonProperty("email") + private String email; + + @JsonProperty("phone") + private String phone; + + @JsonProperty("preferences") + private CustomerPreferences preferences; + + // Getters and setters + public String getId() { + return id; + } + public void setId(String id) { + this.id = id; + } + public String getFirstName() { + return firstName; + } + public void setFirstName(String firstName) { + this.firstName = firstName; + } + public String getLastName() { + return lastName; + } + public void setLastName(String lastName) { + this.lastName = lastName; + } + public String getEmail() { + return email; + } + public void setEmail(String email) { + this.email = email; + } + public String getPhone() { + return phone; + } + public void setPhone(String phone) { + this.phone = phone; + } + public CustomerPreferences getPreferences() { + return preferences; + } + public void setPreferences(CustomerPreferences preferences) { + this.preferences = preferences; + } + } + + public static class CustomerPreferences { + @JsonProperty("communicationChannel") + private String communicationChannel; // email, sms, phone + + @JsonProperty("marketingOptIn") + private Boolean marketingOptIn; + + @JsonProperty("language") + private String language; + + // Getters and setters + public String getCommunicationChannel() { + return communicationChannel; + } + public void setCommunicationChannel(String communicationChannel) { + this.communicationChannel = communicationChannel; + } + public Boolean getMarketingOptIn() { + return marketingOptIn; + } + public void setMarketingOptIn(Boolean marketingOptIn) { + this.marketingOptIn = marketingOptIn; + } + public String getLanguage() { + return language; + } + public void setLanguage(String language) { + this.language = language; + } + } + + public static class OrderItem { + @JsonProperty("productId") + private String productId; + + @JsonProperty("name") + private String name; + + @JsonProperty("quantity") + private Integer quantity; + + @JsonProperty("unitPrice") + private Double unitPrice; + + @JsonProperty("discount") + private Discount discount; + + @JsonProperty("customizations") + private List customizations; + + // Getters and setters + public String getProductId() { + return productId; + } + public void setProductId(String productId) { + this.productId = productId; + } + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } + public Integer getQuantity() { + return quantity; + } + public void setQuantity(Integer quantity) { + this.quantity = quantity; + } + public Double getUnitPrice() { + return unitPrice; + } + public void setUnitPrice(Double unitPrice) { + this.unitPrice = unitPrice; + } + public Discount getDiscount() { + return discount; + } + public void setDiscount(Discount discount) { + this.discount = discount; + } + public List getCustomizations() { + return customizations; + } + public void setCustomizations(List customizations) { + this.customizations = customizations; + } + } + + public static class Discount { + @JsonProperty("type") + private String type; // percentage, fixed + + @JsonProperty("value") + private Double value; + + @JsonProperty("code") + private String code; + + // Getters and setters + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public Double getValue() { + return value; + } + public void setValue(Double value) { + this.value = value; + } + public String getCode() { + return code; + } + public void setCode(String code) { + this.code = code; + } + } + + public static class Customization { + @JsonProperty("type") + private String type; + + @JsonProperty("value") + private String value; + + @JsonProperty("additionalCost") + private Double additionalCost; + + // Getters and setters + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public String getValue() { + return value; + } + public void setValue(String value) { + this.value = value; + } + public Double getAdditionalCost() { + return additionalCost; + } + public void setAdditionalCost(Double additionalCost) { + this.additionalCost = additionalCost; + } + } + + public static class Address { + @JsonProperty("street1") + private String street1; + + @JsonProperty("street2") + private String street2; + + @JsonProperty("city") + private String city; + + @JsonProperty("state") + private String state; + + @JsonProperty("postalCode") + private String postalCode; + + @JsonProperty("country") + private String country; + + @JsonProperty("coordinates") + private Coordinates coordinates; + + // Getters and setters + public String getStreet1() { + return street1; + } + public void setStreet1(String street1) { + this.street1 = street1; + } + public String getStreet2() { + return street2; + } + public void setStreet2(String street2) { + this.street2 = street2; + } + public String getCity() { + return city; + } + public void setCity(String city) { + this.city = city; + } + public String getState() { + return state; + } + public void setState(String state) { + this.state = state; + } + public String getPostalCode() { + return postalCode; + } + public void setPostalCode(String postalCode) { + this.postalCode = postalCode; + } + public String getCountry() { + return country; + } + public void setCountry(String country) { + this.country = country; + } + public Coordinates getCoordinates() { + return coordinates; + } + public void setCoordinates(Coordinates coordinates) { + this.coordinates = coordinates; + } + } + + public static class Coordinates { + @JsonProperty("latitude") + private Double latitude; + + @JsonProperty("longitude") + private Double longitude; + + // Getters and setters + public Double getLatitude() { + return latitude; + } + public void setLatitude(Double latitude) { + this.latitude = latitude; + } + public Double getLongitude() { + return longitude; + } + public void setLongitude(Double longitude) { + this.longitude = longitude; + } + } + + public static class PaymentMethod { + @JsonProperty("type") + private String type; // credit_card, debit_card, paypal, bank_transfer + + @JsonProperty("details") + private PaymentDetails details; + + // Getters and setters + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public PaymentDetails getDetails() { + return details; + } + public void setDetails(PaymentDetails details) { + this.details = details; + } + } + + public static class PaymentDetails { + @JsonProperty("lastFourDigits") + private String lastFourDigits; + + @JsonProperty("cardType") + private String cardType; + + @JsonProperty("expirationMonth") + private Integer expirationMonth; + + @JsonProperty("expirationYear") + private Integer expirationYear; + + // Getters and setters + public String getLastFourDigits() { + return lastFourDigits; + } + public void setLastFourDigits(String lastFourDigits) { + this.lastFourDigits = lastFourDigits; + } + public String getCardType() { + return cardType; + } + public void setCardType(String cardType) { + this.cardType = cardType; + } + public Integer getExpirationMonth() { + return expirationMonth; + } + public void setExpirationMonth(Integer expirationMonth) { + this.expirationMonth = expirationMonth; + } + public Integer getExpirationYear() { + return expirationYear; + } + public void setExpirationYear(Integer expirationYear) { + this.expirationYear = expirationYear; + } + } + + // Main class getters and setters + public String getCustomerId() { + return customerId; + } + public void setCustomerId(String customerId) { + this.customerId = customerId; + } + public Customer getCustomer() { + return customer; + } + public void setCustomer(Customer customer) { + this.customer = customer; + } + public List getItems() { + return items; + } + public void setItems(List items) { + this.items = items; + } + public Address getShippingAddress() { + return shippingAddress; + } + public void setShippingAddress(Address shippingAddress) { + this.shippingAddress = shippingAddress; + } + public Address getBillingAddress() { + return billingAddress; + } + public void setBillingAddress(Address billingAddress) { + this.billingAddress = billingAddress; + } + public PaymentMethod getPaymentMethod() { + return paymentMethod; + } + public void setPaymentMethod(PaymentMethod paymentMethod) { + this.paymentMethod = paymentMethod; + } + public String getOrderNotes() { + return orderNotes; + } + public void setOrderNotes(String orderNotes) { + this.orderNotes = orderNotes; + } + public Map getMetadata() { + return metadata; + } + public void setMetadata(Map metadata) { + this.metadata = metadata; + } +} diff --git a/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/OrderResponse.java b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/OrderResponse.java new file mode 100644 index 0000000000..70228d19d7 --- /dev/null +++ b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/OrderResponse.java @@ -0,0 +1,682 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples.types; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Complex order response demonstrating deeply nested output types. + */ +public class OrderResponse { + + @JsonProperty("orderId") + private String orderId; + + @JsonProperty("status") + private String status; + + @JsonProperty("customer") + private CustomerSummary customer; + + @JsonProperty("orderSummary") + private OrderSummary orderSummary; + + @JsonProperty("shipping") + private ShippingInfo shipping; + + @JsonProperty("payment") + private PaymentInfo payment; + + @JsonProperty("timeline") + private List timeline; + + @JsonProperty("recommendations") + private List recommendations; + + @JsonProperty("analytics") + private OrderAnalytics analytics; + + // Nested types + public static class CustomerSummary { + @JsonProperty("id") + private String id; + + @JsonProperty("fullName") + private String fullName; + + @JsonProperty("email") + private String email; + + @JsonProperty("loyaltyTier") + private String loyaltyTier; + + @JsonProperty("totalOrders") + private Integer totalOrders; + + // Getters and setters + public String getId() { + return id; + } + public void setId(String id) { + this.id = id; + } + public String getFullName() { + return fullName; + } + public void setFullName(String fullName) { + this.fullName = fullName; + } + public String getEmail() { + return email; + } + public void setEmail(String email) { + this.email = email; + } + public String getLoyaltyTier() { + return loyaltyTier; + } + public void setLoyaltyTier(String loyaltyTier) { + this.loyaltyTier = loyaltyTier; + } + public Integer getTotalOrders() { + return totalOrders; + } + public void setTotalOrders(Integer totalOrders) { + this.totalOrders = totalOrders; + } + } + + public static class OrderSummary { + @JsonProperty("itemCount") + private Integer itemCount; + + @JsonProperty("items") + private List items; + + @JsonProperty("subtotal") + private MoneyAmount subtotal; + + @JsonProperty("discounts") + private List discounts; + + @JsonProperty("tax") + private TaxInfo tax; + + @JsonProperty("shipping") + private MoneyAmount shipping; + + @JsonProperty("total") + private MoneyAmount total; + + // Getters and setters + public Integer getItemCount() { + return itemCount; + } + public void setItemCount(Integer itemCount) { + this.itemCount = itemCount; + } + public List getItems() { + return items; + } + public void setItems(List items) { + this.items = items; + } + public MoneyAmount getSubtotal() { + return subtotal; + } + public void setSubtotal(MoneyAmount subtotal) { + this.subtotal = subtotal; + } + public List getDiscounts() { + return discounts; + } + public void setDiscounts(List discounts) { + this.discounts = discounts; + } + public TaxInfo getTax() { + return tax; + } + public void setTax(TaxInfo tax) { + this.tax = tax; + } + public MoneyAmount getShipping() { + return shipping; + } + public void setShipping(MoneyAmount shipping) { + this.shipping = shipping; + } + public MoneyAmount getTotal() { + return total; + } + public void setTotal(MoneyAmount total) { + this.total = total; + } + } + + public static class ProcessedItem { + @JsonProperty("productId") + private String productId; + + @JsonProperty("name") + private String name; + + @JsonProperty("quantity") + private Integer quantity; + + @JsonProperty("unitPrice") + private MoneyAmount unitPrice; + + @JsonProperty("totalPrice") + private MoneyAmount totalPrice; + + @JsonProperty("customizations") + private List customizations; + + @JsonProperty("estimatedDelivery") + private String estimatedDelivery; + + // Getters and setters + public String getProductId() { + return productId; + } + public void setProductId(String productId) { + this.productId = productId; + } + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } + public Integer getQuantity() { + return quantity; + } + public void setQuantity(Integer quantity) { + this.quantity = quantity; + } + public MoneyAmount getUnitPrice() { + return unitPrice; + } + public void setUnitPrice(MoneyAmount unitPrice) { + this.unitPrice = unitPrice; + } + public MoneyAmount getTotalPrice() { + return totalPrice; + } + public void setTotalPrice(MoneyAmount totalPrice) { + this.totalPrice = totalPrice; + } + public List getCustomizations() { + return customizations; + } + public void setCustomizations(List customizations) { + this.customizations = customizations; + } + public String getEstimatedDelivery() { + return estimatedDelivery; + } + public void setEstimatedDelivery(String estimatedDelivery) { + this.estimatedDelivery = estimatedDelivery; + } + } + + public static class MoneyAmount { + @JsonProperty("amount") + private Double amount; + + @JsonProperty("currency") + private String currency; + + @JsonProperty("formatted") + private String formatted; + + // Getters and setters + public Double getAmount() { + return amount; + } + public void setAmount(Double amount) { + this.amount = amount; + } + public String getCurrency() { + return currency; + } + public void setCurrency(String currency) { + this.currency = currency; + } + public String getFormatted() { + return formatted; + } + public void setFormatted(String formatted) { + this.formatted = formatted; + } + } + + public static class AppliedDiscount { + @JsonProperty("code") + private String code; + + @JsonProperty("description") + private String description; + + @JsonProperty("savedAmount") + private MoneyAmount savedAmount; + + // Getters and setters + public String getCode() { + return code; + } + public void setCode(String code) { + this.code = code; + } + public String getDescription() { + return description; + } + public void setDescription(String description) { + this.description = description; + } + public MoneyAmount getSavedAmount() { + return savedAmount; + } + public void setSavedAmount(MoneyAmount savedAmount) { + this.savedAmount = savedAmount; + } + } + + public static class TaxInfo { + @JsonProperty("rate") + private Double rate; + + @JsonProperty("amount") + private MoneyAmount amount; + + @JsonProperty("breakdown") + private List breakdown; + + // Getters and setters + public Double getRate() { + return rate; + } + public void setRate(Double rate) { + this.rate = rate; + } + public MoneyAmount getAmount() { + return amount; + } + public void setAmount(MoneyAmount amount) { + this.amount = amount; + } + public List getBreakdown() { + return breakdown; + } + public void setBreakdown(List breakdown) { + this.breakdown = breakdown; + } + } + + public static class TaxBreakdown { + @JsonProperty("type") + private String type; + + @JsonProperty("rate") + private Double rate; + + @JsonProperty("amount") + private MoneyAmount amount; + + // Getters and setters + public String getType() { + return type; + } + public void setType(String type) { + this.type = type; + } + public Double getRate() { + return rate; + } + public void setRate(Double rate) { + this.rate = rate; + } + public MoneyAmount getAmount() { + return amount; + } + public void setAmount(MoneyAmount amount) { + this.amount = amount; + } + } + + public static class ShippingInfo { + @JsonProperty("method") + private String method; + + @JsonProperty("carrier") + private String carrier; + + @JsonProperty("trackingNumber") + private String trackingNumber; + + @JsonProperty("estimatedDelivery") + private DateRange estimatedDelivery; + + @JsonProperty("address") + private FormattedAddress address; + + // Getters and setters + public String getMethod() { + return method; + } + public void setMethod(String method) { + this.method = method; + } + public String getCarrier() { + return carrier; + } + public void setCarrier(String carrier) { + this.carrier = carrier; + } + public String getTrackingNumber() { + return trackingNumber; + } + public void setTrackingNumber(String trackingNumber) { + this.trackingNumber = trackingNumber; + } + public DateRange getEstimatedDelivery() { + return estimatedDelivery; + } + public void setEstimatedDelivery(DateRange estimatedDelivery) { + this.estimatedDelivery = estimatedDelivery; + } + public FormattedAddress getAddress() { + return address; + } + public void setAddress(FormattedAddress address) { + this.address = address; + } + } + + public static class DateRange { + @JsonProperty("earliest") + private String earliest; + + @JsonProperty("latest") + private String latest; + + // Getters and setters + public String getEarliest() { + return earliest; + } + public void setEarliest(String earliest) { + this.earliest = earliest; + } + public String getLatest() { + return latest; + } + public void setLatest(String latest) { + this.latest = latest; + } + } + + public static class FormattedAddress { + @JsonProperty("lines") + private List lines; + + @JsonProperty("formatted") + private String formatted; + + // Getters and setters + public List getLines() { + return lines; + } + public void setLines(List lines) { + this.lines = lines; + } + public String getFormatted() { + return formatted; + } + public void setFormatted(String formatted) { + this.formatted = formatted; + } + } + + public static class PaymentInfo { + @JsonProperty("status") + private String status; + + @JsonProperty("method") + private String method; + + @JsonProperty("transactionId") + private String transactionId; + + @JsonProperty("billingAddress") + private FormattedAddress billingAddress; + + // Getters and setters + public String getStatus() { + return status; + } + public void setStatus(String status) { + this.status = status; + } + public String getMethod() { + return method; + } + public void setMethod(String method) { + this.method = method; + } + public String getTransactionId() { + return transactionId; + } + public void setTransactionId(String transactionId) { + this.transactionId = transactionId; + } + public FormattedAddress getBillingAddress() { + return billingAddress; + } + public void setBillingAddress(FormattedAddress billingAddress) { + this.billingAddress = billingAddress; + } + } + + public static class TimelineEvent { + @JsonProperty("timestamp") + private String timestamp; + + @JsonProperty("event") + private String event; + + @JsonProperty("description") + private String description; + + @JsonProperty("actor") + private String actor; + + // Getters and setters + public String getTimestamp() { + return timestamp; + } + public void setTimestamp(String timestamp) { + this.timestamp = timestamp; + } + public String getEvent() { + return event; + } + public void setEvent(String event) { + this.event = event; + } + public String getDescription() { + return description; + } + public void setDescription(String description) { + this.description = description; + } + public String getActor() { + return actor; + } + public void setActor(String actor) { + this.actor = actor; + } + } + + public static class Recommendation { + @JsonProperty("productId") + private String productId; + + @JsonProperty("name") + private String name; + + @JsonProperty("reason") + private String reason; + + @JsonProperty("price") + private MoneyAmount price; + + @JsonProperty("score") + private Double score; + + // Getters and setters + public String getProductId() { + return productId; + } + public void setProductId(String productId) { + this.productId = productId; + } + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } + public String getReason() { + return reason; + } + public void setReason(String reason) { + this.reason = reason; + } + public MoneyAmount getPrice() { + return price; + } + public void setPrice(MoneyAmount price) { + this.price = price; + } + public Double getScore() { + return score; + } + public void setScore(Double score) { + this.score = score; + } + } + + public static class OrderAnalytics { + @JsonProperty("processingTime") + private String processingTime; + + @JsonProperty("fraudRiskScore") + private Double fraudRiskScore; + + @JsonProperty("customerLifetimeValue") + private MoneyAmount customerLifetimeValue; + + @JsonProperty("tags") + private List tags; + + // Getters and setters + public String getProcessingTime() { + return processingTime; + } + public void setProcessingTime(String processingTime) { + this.processingTime = processingTime; + } + public Double getFraudRiskScore() { + return fraudRiskScore; + } + public void setFraudRiskScore(Double fraudRiskScore) { + this.fraudRiskScore = fraudRiskScore; + } + public MoneyAmount getCustomerLifetimeValue() { + return customerLifetimeValue; + } + public void setCustomerLifetimeValue(MoneyAmount customerLifetimeValue) { + this.customerLifetimeValue = customerLifetimeValue; + } + public List getTags() { + return tags; + } + public void setTags(List tags) { + this.tags = tags; + } + } + + // Main class getters and setters + public String getOrderId() { + return orderId; + } + public void setOrderId(String orderId) { + this.orderId = orderId; + } + public String getStatus() { + return status; + } + public void setStatus(String status) { + this.status = status; + } + public CustomerSummary getCustomer() { + return customer; + } + public void setCustomer(CustomerSummary customer) { + this.customer = customer; + } + public OrderSummary getOrderSummary() { + return orderSummary; + } + public void setOrderSummary(OrderSummary orderSummary) { + this.orderSummary = orderSummary; + } + public ShippingInfo getShipping() { + return shipping; + } + public void setShipping(ShippingInfo shipping) { + this.shipping = shipping; + } + public PaymentInfo getPayment() { + return payment; + } + public void setPayment(PaymentInfo payment) { + this.payment = payment; + } + public List getTimeline() { + return timeline; + } + public void setTimeline(List timeline) { + this.timeline = timeline; + } + public List getRecommendations() { + return recommendations; + } + public void setRecommendations(List recommendations) { + this.recommendations = recommendations; + } + public OrderAnalytics getAnalytics() { + return analytics; + } + public void setAnalytics(OrderAnalytics analytics) { + this.analytics = analytics; + } +} diff --git a/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/ValidationResult.java b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/ValidationResult.java new file mode 100644 index 0000000000..b07da51263 --- /dev/null +++ b/java/samples/complex-io/src/main/java/com/google/genkit/samples/types/ValidationResult.java @@ -0,0 +1,139 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples.types; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Validation result with errors and warnings. + */ +public class ValidationResult { + + @JsonProperty("valid") + private Boolean valid; + + @JsonProperty("errors") + private List errors; + + @JsonProperty("warnings") + private List warnings; + + @JsonProperty("errorCount") + private Integer errorCount; + + @JsonProperty("warningCount") + private Integer warningCount; + + // Nested types + public static class ValidationError { + @JsonProperty("field") + private String field; + + @JsonProperty("message") + private String message; + + @JsonProperty("severity") + private String severity; + + @JsonProperty("code") + private String code; + + // Getters and setters + public String getField() { + return field; + } + public void setField(String field) { + this.field = field; + } + public String getMessage() { + return message; + } + public void setMessage(String message) { + this.message = message; + } + public String getSeverity() { + return severity; + } + public void setSeverity(String severity) { + this.severity = severity; + } + public String getCode() { + return code; + } + public void setCode(String code) { + this.code = code; + } + } + + public static class ValidationWarning { + @JsonProperty("field") + private String field; + + @JsonProperty("message") + private String message; + + // Getters and setters + public String getField() { + return field; + } + public void setField(String field) { + this.field = field; + } + public String getMessage() { + return message; + } + public void setMessage(String message) { + this.message = message; + } + } + + // Main class getters and setters + public Boolean getValid() { + return valid; + } + public void setValid(Boolean valid) { + this.valid = valid; + } + public List getErrors() { + return errors; + } + public void setErrors(List errors) { + this.errors = errors; + } + public List getWarnings() { + return warnings; + } + public void setWarnings(List warnings) { + this.warnings = warnings; + } + public Integer getErrorCount() { + return errorCount; + } + public void setErrorCount(Integer errorCount) { + this.errorCount = errorCount; + } + public Integer getWarningCount() { + return warningCount; + } + public void setWarningCount(Integer warningCount) { + this.warningCount = warningCount; + } +} diff --git a/java/samples/complex-io/src/main/resources/logback.xml b/java/samples/complex-io/src/main/resources/logback.xml new file mode 100644 index 0000000000..56110e53c0 --- /dev/null +++ b/java/samples/complex-io/src/main/resources/logback.xml @@ -0,0 +1,25 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + diff --git a/java/samples/dotprompt/README.md b/java/samples/dotprompt/README.md new file mode 100644 index 0000000000..6f2ebd9833 --- /dev/null +++ b/java/samples/dotprompt/README.md @@ -0,0 +1,163 @@ +# Genkit DotPrompt Sample + +This sample demonstrates how to use DotPrompt files with Genkit Java for building AI applications with complex inputs and outputs. + +## Features Demonstrated + +- **DotPrompt Files**: Load and use `.prompt` files with Handlebars templating +- **Complex Input Schemas**: Handle nested objects, arrays, and optional fields +- **Complex Output Schemas**: Parse structured JSON responses into Java objects +- **Prompt Variants**: Use different prompt variations (e.g., `recipe.robot.prompt`) +- **Partials**: Include reusable template fragments (e.g., `_style.prompt`) + +## Prompt Files + +Located in `src/main/resources/prompts/`: + +- `recipe.prompt` - Generate recipes with structured output +- `recipe.robot.prompt` - Robot-themed recipe generation +- `story.prompt` - Story generation with personality options +- `travel-planner.prompt` - Complex travel itinerary generation +- `code-review.prompt` - Code analysis with detailed output +- `_style.prompt` - Partial template for personality styling + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/dotprompt + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/dotprompt + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Example API Calls + +### Generate a Recipe +```bash +curl -X POST http://localhost:8080/chefFlow \ + -H 'Content-Type: application/json' \ + -d '{"food":"pasta carbonara","ingredients":["bacon","eggs","parmesan"]}' +``` + +### Robot Chef Recipe +```bash +curl -X POST http://localhost:8080/robotChefFlow \ + -H 'Content-Type: application/json' \ + -d '{"food":"pizza"}' +``` + +### Tell a Story +```bash +curl -X POST http://localhost:8080/tellStory \ + -H 'Content-Type: application/json' \ + -d '{"subject":"a brave knight","personality":"dramatic","length":"short"}' +``` + +### Plan a Trip +```bash +curl -X POST http://localhost:8080/planTrip \ + -H 'Content-Type: application/json' \ + -d '{ + "destination": "Tokyo", + "duration": 5, + "budget": "$3000", + "interests": ["food", "culture", "technology"], + "travelStyle": "adventure" + }' +``` + +### Code Review +```bash +curl -X POST http://localhost:8080/reviewCode \ + -H 'Content-Type: application/json' \ + -d '{ + "code": "function add(a, b) { return a + b; }", + "language": "javascript", + "analysisType": "best practices" + }' +``` + +## Output Schemas + +### Recipe +```json +{ + "title": "Pasta Carbonara", + "ingredients": [ + {"name": "spaghetti", "quantity": "400g"}, + {"name": "bacon", "quantity": "200g"} + ], + "steps": ["Step 1...", "Step 2..."], + "prepTime": "10 minutes", + "cookTime": "20 minutes", + "servings": 4 +} +``` + +### Travel Itinerary +```json +{ + "tripName": "Tokyo Adventure", + "destination": "Tokyo", + "duration": 5, + "dailyItinerary": [ + { + "day": 1, + "title": "Arrival & Exploration", + "activities": [ + { + "time": "9:00 AM", + "activity": "Visit Senso-ji Temple", + "location": "Asakusa", + "estimatedCost": "$10", + "tips": "Go early to avoid crowds" + } + ] + } + ], + "estimatedBudget": { + "accommodation": "$800", + "food": "$500", + "activities": "$400", + "transportation": "$300", + "total": "$2000" + }, + "packingList": ["Comfortable shoes", "Power adapter"], + "travelTips": ["Get a JR Pass", "Learn basic Japanese phrases"] +} +``` + +## Development UI + +Access the Genkit Development UI at http://localhost:3100 to: +- Browse available flows and prompts +- Test flows interactively +- View execution traces diff --git a/java/samples/dotprompt/pom.xml b/java/samples/dotprompt/pom.xml new file mode 100644 index 0000000000..f88c7b5b20 --- /dev/null +++ b/java/samples/dotprompt/pom.xml @@ -0,0 +1,88 @@ + + + + 4.0.0 + + com.google.genkit.samples + genkit-sample-dotprompt + 1.0.0-SNAPSHOT + jar + Genkit DotPrompt Sample + Sample application demonstrating Genkit DotPrompt files with complex inputs and outputs + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + src/main/resources + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.DotPromptSample + + + + + diff --git a/java/samples/dotprompt/run.sh b/java/samples/dotprompt/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/dotprompt/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/CodeReview.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/CodeReview.java new file mode 100644 index 0000000000..c78441579b --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/CodeReview.java @@ -0,0 +1,275 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Code review analysis output schema. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class CodeReview { + + @JsonProperty("language") + private String language; + + @JsonProperty("summary") + private String summary; + + @JsonProperty("complexity") + private Complexity complexity; + + @JsonProperty("issues") + private List issues; + + @JsonProperty("improvements") + private List improvements; + + @JsonProperty("metrics") + private Metrics metrics; + + // Getters and setters + public String getLanguage() { + return language; + } + + public void setLanguage(String language) { + this.language = language; + } + + public String getSummary() { + return summary; + } + + public void setSummary(String summary) { + this.summary = summary; + } + + public Complexity getComplexity() { + return complexity; + } + + public void setComplexity(Complexity complexity) { + this.complexity = complexity; + } + + public List getIssues() { + return issues; + } + + public void setIssues(List issues) { + this.issues = issues; + } + + public List getImprovements() { + return improvements; + } + + public void setImprovements(List improvements) { + this.improvements = improvements; + } + + public Metrics getMetrics() { + return metrics; + } + + public void setMetrics(Metrics metrics) { + this.metrics = metrics; + } + + /** + * Complexity assessment. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Complexity { + @JsonProperty("level") + private String level; + + @JsonProperty("score") + private Integer score; + + @JsonProperty("explanation") + private String explanation; + + public String getLevel() { + return level; + } + + public void setLevel(String level) { + this.level = level; + } + + public Integer getScore() { + return score; + } + + public void setScore(Integer score) { + this.score = score; + } + + public String getExplanation() { + return explanation; + } + + public void setExplanation(String explanation) { + this.explanation = explanation; + } + } + + /** + * Code issue. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Issue { + @JsonProperty("severity") + private String severity; + + @JsonProperty("line") + private Integer line; + + @JsonProperty("description") + private String description; + + @JsonProperty("suggestion") + private String suggestion; + + public String getSeverity() { + return severity; + } + + public void setSeverity(String severity) { + this.severity = severity; + } + + public Integer getLine() { + return line; + } + + public void setLine(Integer line) { + this.line = line; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getSuggestion() { + return suggestion; + } + + public void setSuggestion(String suggestion) { + this.suggestion = suggestion; + } + } + + /** + * Code improvement suggestion. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Improvement { + @JsonProperty("category") + private String category; + + @JsonProperty("description") + private String description; + + @JsonProperty("example") + private String example; + + public String getCategory() { + return category; + } + + public void setCategory(String category) { + this.category = category; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getExample() { + return example; + } + + public void setExample(String example) { + this.example = example; + } + } + + /** + * Code metrics. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Metrics { + @JsonProperty("linesOfCode") + private Integer linesOfCode; + + @JsonProperty("functions") + private Integer functions; + + @JsonProperty("classes") + private Integer classes; + + @JsonProperty("comments") + private Integer comments; + + public Integer getLinesOfCode() { + return linesOfCode; + } + + public void setLinesOfCode(Integer linesOfCode) { + this.linesOfCode = linesOfCode; + } + + public Integer getFunctions() { + return functions; + } + + public void setFunctions(Integer functions) { + this.functions = functions; + } + + public Integer getClasses() { + return classes; + } + + public void setClasses(Integer classes) { + this.classes = classes; + } + + public Integer getComments() { + return comments; + } + + public void setComments(Integer comments) { + this.comments = comments; + } + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/CodeReviewInput.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/CodeReviewInput.java new file mode 100644 index 0000000000..871706d587 --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/CodeReviewInput.java @@ -0,0 +1,75 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Input schema for code review. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class CodeReviewInput { + + @JsonProperty("code") + private String code; + + @JsonProperty("language") + private String language; + + @JsonProperty("analysisType") + private String analysisType; + + public CodeReviewInput() {} + + public CodeReviewInput(String code, String language) { + this.code = code; + this.language = language; + } + + public String getCode() { + return code; + } + + public void setCode(String code) { + this.code = code; + } + + public String getLanguage() { + return language; + } + + public void setLanguage(String language) { + this.language = language; + } + + public String getAnalysisType() { + return analysisType; + } + + public void setAnalysisType(String analysisType) { + this.analysisType = analysisType; + } + + @Override + public String toString() { + return "CodeReviewInput{language='" + language + "', analysisType='" + analysisType + + "', code='" + (code != null ? code.substring(0, Math.min(50, code.length())) + "..." : "null") + "'}"; + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java new file mode 100644 index 0000000000..84aaf86ca8 --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java @@ -0,0 +1,303 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.*; +import com.google.genkit.core.Flow; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import com.google.genkit.prompt.DotPrompt; +import com.google.genkit.prompt.ExecutablePrompt; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +/** + * Sample application demonstrating Genkit DotPrompt files with complex inputs and outputs. + * + *

This example shows how to: + *

    + *
  • Load and use .prompt files with Handlebars templates
  • + *
  • Work with complex input schemas (nested objects, arrays)
  • + *
  • Work with complex output schemas (JSON structures)
  • + *
  • Use prompt variants (e.g., recipe.robot.prompt)
  • + *
  • Use partial templates (e.g., _style.prompt)
  • + *
+ * + *

To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java
  4. + *
+ */ +public class DotPromptSample { + + private static final Logger logger = LoggerFactory.getLogger(DotPromptSample.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Extracts JSON from a string that may be wrapped in markdown code blocks. + * Handles formats like: ```json\n{...}\n``` or ```\n{...}\n``` or just {...} + * Also handles nested wrapper objects like {"recipe": {...}} or {"result": {...}} + * If no JSON is found, looks for JSON embedded in the text. + */ + private static String extractJson(String text) { + if (text == null) return null; + String trimmed = text.trim(); + + // Check for ```json or ``` markers + if (trimmed.startsWith("```")) { + // Find the end of the first line (after ```json or ```) + int firstNewline = trimmed.indexOf('\n'); + if (firstNewline == -1) return trimmed; + + // Find the closing ``` + int lastBackticks = trimmed.lastIndexOf("```"); + if (lastBackticks > firstNewline) { + trimmed = trimmed.substring(firstNewline + 1, lastBackticks).trim(); + } + } + + // If the text doesn't start with { or [, try to find JSON embedded in it + if (!trimmed.startsWith("{") && !trimmed.startsWith("[")) { + // Look for JSON object in the text + int jsonStart = trimmed.indexOf('{'); + int jsonEnd = trimmed.lastIndexOf('}'); + if (jsonStart >= 0 && jsonEnd > jsonStart) { + trimmed = trimmed.substring(jsonStart, jsonEnd + 1); + } else { + // Look for JSON array + jsonStart = trimmed.indexOf('['); + jsonEnd = trimmed.lastIndexOf(']'); + if (jsonStart >= 0 && jsonEnd > jsonStart) { + trimmed = trimmed.substring(jsonStart, jsonEnd + 1); + } + } + } + + // Try to unwrap common wrapper keys like "recipe", "result", "data" + try { + Map wrapped = objectMapper.readValue(trimmed, Map.class); + if (wrapped.size() == 1) { + String key = wrapped.keySet().iterator().next(); + if (key.equalsIgnoreCase("recipe") || key.equalsIgnoreCase("result") || + key.equalsIgnoreCase("data") || key.equalsIgnoreCase("response") || + key.equalsIgnoreCase("itinerary") || key.equalsIgnoreCase("trip")) { + Object inner = wrapped.get(key); + if (inner instanceof Map) { + return objectMapper.writeValueAsString(inner); + } + } + } + } catch (Exception e) { + // Not valid JSON or not a wrapper object, return as-is + } + + return trimmed; + } + + public static void main(String[] args) throws Exception { + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder() + .port(8080) + .build()); + + // Create Genkit with plugins + Genkit genkit = Genkit.builder() + .options(GenkitOptions.builder() + .devMode(true) + .reflectionPort(3100) + .promptDir("/prompts") // Configure prompt directory (default is /prompts) + .build()) + .plugin(OpenAIPlugin.create()) + .plugin(jetty) + .build(); + + // ============================================================ + // Method 1: Load prompts using genkit.prompt() - Recommended! + // Similar to JavaScript: const helloPrompt = ai.prompt('hello'); + // ============================================================ + + // Load and auto-register prompts using genkit.prompt() + // This automatically loads from /prompts directory and registers as actions + ExecutablePrompt storyPrompt = genkit.prompt("story", StoryInput.class); + ExecutablePrompt travelPrompt = genkit.prompt("travel-planner", TravelInput.class); + ExecutablePrompt codeReviewPrompt = genkit.prompt("code-review", CodeReviewInput.class); + + // Load prompt with variant (e.g., recipe.robot.prompt) + ExecutablePrompt robotRecipePrompt = genkit.prompt("recipe", RecipeInput.class, "robot"); + + // ============================================================ + // Method 2: Load prompts manually using DotPrompt.loadFromResource() + // Useful when you need more control over the loading process + // ============================================================ + DotPrompt recipePrompt = DotPrompt.loadFromResource("/prompts/recipe.prompt"); + + // ============================================================ + // Flow Examples: Different ways to use prompts + // ============================================================ + + // Flow using DotPrompt.render() + manual generate + Flow chefFlow = genkit.defineFlow( + "chefFlow", + RecipeInput.class, + Recipe.class, + (ctx, input) -> { + // Validate input + if (input == null || input.getFood() == null || input.getFood().isEmpty()) { + throw new IllegalArgumentException("Input 'food' is required. Example: {\"food\": \"pasta\", \"ingredients\": [\"tomatoes\", \"basil\"]}"); + } + + // Render the prompt + String prompt = recipePrompt.render(input); + logger.info("Generated prompt: {}", prompt); + + // Generate response + ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-5.2") + .prompt(prompt) + .config(GenerationConfig.builder() + .temperature(0.7) + .build()) + .build()); + + // Parse JSON response to Recipe object (extract from markdown if needed) + String jsonResponse = extractJson(response.getText()); + logger.debug("Extracted JSON: {}", jsonResponse); + try { + return objectMapper.readValue(jsonResponse, Recipe.class); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse recipe response: " + jsonResponse, e); + } + }); + + // Flow using ExecutablePrompt.generate() - Direct generation! + // This is the recommended approach - similar to JS: const { text } = await helloPrompt({ name: 'John' }); + Flow robotChefFlow = genkit.defineFlow( + "robotChefFlow", + RecipeInput.class, + Recipe.class, + (ctx, input) -> { + // Generate directly from the prompt - no need to manually call genkit.generate()! + ModelResponse response = robotRecipePrompt.generate(input); + + try { + return objectMapper.readValue(extractJson(response.getText()), Recipe.class); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse recipe response", e); + } + }); + + // Flow for story telling using ExecutablePrompt.generate() with custom options + // Demonstrates overriding generation config at call time + Flow tellStoryFlow = genkit.defineFlow( + "tellStory", + StoryInput.class, + String.class, + (ctx, input) -> { + // Generate with custom temperature override + ModelResponse response = storyPrompt.generate(input, + GenerateOptions.builder() + .config(GenerationConfig.builder() + .temperature(0.9) + .build()) + .build()); + return response.getText(); + }); + + // Flow for travel planning using ExecutablePrompt.generate() + Flow planTripFlow = genkit.defineFlow( + "planTrip", + TravelInput.class, + TravelItinerary.class, + (ctx, input) -> { + // Direct generation from ExecutablePrompt + ModelResponse response = travelPrompt.generate(input, + GenerateOptions.builder() + .config(GenerationConfig.builder() + .temperature(0.7) + .build()) + .build()); + try { + return objectMapper.readValue(extractJson(response.getText()), TravelItinerary.class); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse travel itinerary response", e); + } + }); + + // Flow for code review using ExecutablePrompt.generate() + Flow reviewCodeFlow = genkit.defineFlow( + "reviewCode", + CodeReviewInput.class, + CodeReview.class, + (ctx, input) -> { + // Direct generation with lower temperature for more focused analysis + ModelResponse response = codeReviewPrompt.generate(input, + GenerateOptions.builder() + .config(GenerationConfig.builder() + .temperature(0.3) + .build()) + .build()); + try { + return objectMapper.readValue(extractJson(response.getText()), CodeReview.class); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse code review response", e); + } + }); + + logger.info("=".repeat(60)); + logger.info("Genkit DotPrompt Sample Started"); + logger.info("=".repeat(60)); + logger.info(""); + logger.info("Available flows:"); + logger.info(" - chefFlow: Generate recipes from food and ingredients"); + logger.info(" - robotChefFlow: Generate recipes with robot personality"); + logger.info(" - tellStory: Generate stories with optional personality"); + logger.info(" - planTrip: Generate detailed travel itineraries"); + logger.info(" - reviewCode: Analyze and review code"); + logger.info(""); + logger.info("Example calls:"); + logger.info(" curl -X POST http://localhost:8080/chefFlow \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '{\"food\":\"pasta\",\"ingredients\":[\"tomatoes\",\"basil\"]}'"); + logger.info(""); + logger.info(" curl -X POST http://localhost:8080/tellStory \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '{\"subject\":\"a brave knight\",\"personality\":\"dramatic\"}'"); + logger.info(""); + logger.info(" curl -X POST http://localhost:8080/planTrip \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '{\"destination\":\"Tokyo\",\"duration\":5,\"budget\":\"$3000\",\"interests\":[\"food\",\"culture\"]}'"); + logger.info(""); + logger.info("Reflection API: http://localhost:3100"); + logger.info("HTTP API: http://localhost:8080"); + logger.info("=".repeat(60)); + + // Start the server and block - keeps the application running + jetty.start(); + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/Recipe.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/Recipe.java new file mode 100644 index 0000000000..0e7af7810a --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/Recipe.java @@ -0,0 +1,226 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonAlias; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Recipe output schema for dotprompt. + * Uses aliases and custom deserializers to handle various formats LLMs might return. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class Recipe { + + @JsonProperty("title") + @JsonAlias({"name", "recipeName", "recipe_name"}) + private String title; + + @JsonProperty("ingredients") + private List ingredients; + + @JsonProperty("steps") + @JsonAlias({"instructions", "directions", "procedure"}) + @JsonDeserialize(using = StepsDeserializer.class) + private List steps; + + @JsonProperty("prepTime") + private String prepTime; + + @JsonProperty("cookTime") + private String cookTime; + + @JsonProperty("servings") + private Integer servings; + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public List getIngredients() { + return ingredients; + } + + public void setIngredients(List ingredients) { + this.ingredients = ingredients; + } + + public List getSteps() { + return steps; + } + + public void setSteps(List steps) { + this.steps = steps; + } + + public String getPrepTime() { + return prepTime; + } + + public void setPrepTime(String prepTime) { + this.prepTime = prepTime; + } + + public String getCookTime() { + return cookTime; + } + + public void setCookTime(String cookTime) { + this.cookTime = cookTime; + } + + public Integer getServings() { + return servings; + } + + public void setServings(Integer servings) { + this.servings = servings; + } + + @Override + public String toString() { + return "Recipe{" + + "title='" + title + '\'' + + ", ingredients=" + ingredients + + ", steps=" + steps + + ", prepTime='" + prepTime + '\'' + + ", cookTime='" + cookTime + '\'' + + ", servings=" + servings + + '}'; + } + + /** + * Ingredient with name and quantity. + * Handles various field names and types LLMs might return. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Ingredient { + @JsonProperty("name") + @JsonAlias({"ingredient", "item"}) + private String name; + + @JsonProperty("quantity") + @JsonDeserialize(using = QuantityDeserializer.class) + private String quantity; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getQuantity() { + return quantity; + } + + public void setQuantity(String quantity) { + this.quantity = quantity; + } + + @Override + public String toString() { + return "Ingredient{name='" + name + "', quantity='" + quantity + "'}"; + } + } + + /** + * Custom deserializer for steps/instructions that can be either: + * - Array of strings: ["step1", "step2"] + * - Array of objects: [{"step": 1, "title": "...", "details": "..."}] + */ + public static class StepsDeserializer extends JsonDeserializer> { + @Override + public List deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + List steps = new ArrayList<>(); + JsonNode node = p.getCodec().readTree(p); + + if (node.isArray()) { + for (JsonNode item : node) { + if (item.isTextual()) { + steps.add(item.asText()); + } else if (item.isObject()) { + // Handle object format: {step, title, details} + StringBuilder sb = new StringBuilder(); + if (item.has("title")) { + sb.append(item.get("title").asText()); + } + if (item.has("details")) { + if (sb.length() > 0) sb.append(": "); + sb.append(item.get("details").asText()); + } + if (sb.length() == 0 && item.has("description")) { + sb.append(item.get("description").asText()); + } + if (sb.length() == 0 && item.has("instruction")) { + sb.append(item.get("instruction").asText()); + } + steps.add(sb.length() > 0 ? sb.toString() : item.toString()); + } + } + } + return steps; + } + } + + /** + * Custom deserializer for quantity that handles various formats: + * - String: "2 cups" + * - Number: 2.5 + * - Object with amount/unit: {"amount": 2, "unit": "cups"} + */ + public static class QuantityDeserializer extends JsonDeserializer { + @Override + public String deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + JsonNode node = p.getCodec().readTree(p); + + if (node.isTextual()) { + return node.asText(); + } else if (node.isNumber()) { + return node.asText(); + } else if (node.isObject()) { + StringBuilder sb = new StringBuilder(); + if (node.has("amount")) { + sb.append(node.get("amount").asText()); + } + if (node.has("unit")) { + if (sb.length() > 0) sb.append(" "); + sb.append(node.get("unit").asText()); + } + return sb.toString(); + } + return node.toString(); + } + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/RecipeInput.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/RecipeInput.java new file mode 100644 index 0000000000..5baf0accae --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/RecipeInput.java @@ -0,0 +1,62 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Input schema for recipe prompt. + */ +public class RecipeInput { + + @JsonProperty("food") + private String food; + + @JsonProperty("ingredients") + private List ingredients; + + public RecipeInput() { + } + + public RecipeInput(String food) { + this.food = food; + } + + public RecipeInput(String food, List ingredients) { + this.food = food; + this.ingredients = ingredients; + } + + public String getFood() { + return food; + } + + public void setFood(String food) { + this.food = food; + } + + public List getIngredients() { + return ingredients; + } + + public void setIngredients(List ingredients) { + this.ingredients = ingredients; + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/StoryInput.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/StoryInput.java new file mode 100644 index 0000000000..1b813a3c2f --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/StoryInput.java @@ -0,0 +1,63 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Input schema for story generation. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class StoryInput { + + @JsonProperty("subject") + private String subject; + + @JsonProperty("personality") + private String personality; + + public StoryInput() {} + + public StoryInput(String subject, String personality) { + this.subject = subject; + this.personality = personality; + } + + public String getSubject() { + return subject; + } + + public void setSubject(String subject) { + this.subject = subject; + } + + public String getPersonality() { + return personality; + } + + public void setPersonality(String personality) { + this.personality = personality; + } + + @Override + public String toString() { + return "StoryInput{subject='" + subject + "', personality='" + personality + "'}"; + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/TravelInput.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/TravelInput.java new file mode 100644 index 0000000000..9d13fab197 --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/TravelInput.java @@ -0,0 +1,100 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Input schema for travel planning. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class TravelInput { + + @JsonProperty("destination") + private String destination; + + @JsonProperty("duration") + private Integer duration; + + @JsonProperty("budget") + private String budget; + + @JsonProperty("interests") + private List interests; + + @JsonProperty("travelStyle") + private String travelStyle; + + public TravelInput() {} + + public TravelInput(String destination, Integer duration, String budget) { + this.destination = destination; + this.duration = duration; + this.budget = budget; + } + + public String getDestination() { + return destination; + } + + public void setDestination(String destination) { + this.destination = destination; + } + + public Integer getDuration() { + return duration; + } + + public void setDuration(Integer duration) { + this.duration = duration; + } + + public String getBudget() { + return budget; + } + + public void setBudget(String budget) { + this.budget = budget; + } + + public List getInterests() { + return interests; + } + + public void setInterests(List interests) { + this.interests = interests; + } + + public String getTravelStyle() { + return travelStyle; + } + + public void setTravelStyle(String travelStyle) { + this.travelStyle = travelStyle; + } + + @Override + public String toString() { + return "TravelInput{destination='" + destination + "', duration=" + duration + + ", budget='" + budget + "', interests=" + interests + + ", travelStyle='" + travelStyle + "'}"; + } +} diff --git a/java/samples/dotprompt/src/main/java/com/google/genkit/samples/TravelItinerary.java b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/TravelItinerary.java new file mode 100644 index 0000000000..f77cea62bc --- /dev/null +++ b/java/samples/dotprompt/src/main/java/com/google/genkit/samples/TravelItinerary.java @@ -0,0 +1,269 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Travel itinerary output schema. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class TravelItinerary { + + @JsonProperty("tripName") + private String tripName; + + @JsonProperty("destination") + private String destination; + + @JsonProperty("duration") + private Integer duration; + + @JsonProperty("dailyItinerary") + private List dailyItinerary; + + @JsonProperty("estimatedBudget") + private Budget estimatedBudget; + + @JsonProperty("packingList") + private List packingList; + + @JsonProperty("travelTips") + private List travelTips; + + // Getters and setters + public String getTripName() { + return tripName; + } + + public void setTripName(String tripName) { + this.tripName = tripName; + } + + public String getDestination() { + return destination; + } + + public void setDestination(String destination) { + this.destination = destination; + } + + public Integer getDuration() { + return duration; + } + + public void setDuration(Integer duration) { + this.duration = duration; + } + + public List getDailyItinerary() { + return dailyItinerary; + } + + public void setDailyItinerary(List dailyItinerary) { + this.dailyItinerary = dailyItinerary; + } + + public Budget getEstimatedBudget() { + return estimatedBudget; + } + + public void setEstimatedBudget(Budget estimatedBudget) { + this.estimatedBudget = estimatedBudget; + } + + public List getPackingList() { + return packingList; + } + + public void setPackingList(List packingList) { + this.packingList = packingList; + } + + public List getTravelTips() { + return travelTips; + } + + public void setTravelTips(List travelTips) { + this.travelTips = travelTips; + } + + /** + * Day plan with activities. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class DayPlan { + @JsonProperty("day") + private Integer day; + + @JsonProperty("title") + private String title; + + @JsonProperty("activities") + private List activities; + + public Integer getDay() { + return day; + } + + public void setDay(Integer day) { + this.day = day; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public List getActivities() { + return activities; + } + + public void setActivities(List activities) { + this.activities = activities; + } + } + + /** + * Activity within a day. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Activity { + @JsonProperty("time") + private String time; + + @JsonProperty("activity") + private String activity; + + @JsonProperty("location") + private String location; + + @JsonProperty("estimatedCost") + private String estimatedCost; + + @JsonProperty("tips") + private String tips; + + public String getTime() { + return time; + } + + public void setTime(String time) { + this.time = time; + } + + public String getActivity() { + return activity; + } + + public void setActivity(String activity) { + this.activity = activity; + } + + public String getLocation() { + return location; + } + + public void setLocation(String location) { + this.location = location; + } + + public String getEstimatedCost() { + return estimatedCost; + } + + public void setEstimatedCost(String estimatedCost) { + this.estimatedCost = estimatedCost; + } + + public String getTips() { + return tips; + } + + public void setTips(String tips) { + this.tips = tips; + } + } + + /** + * Budget breakdown. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Budget { + @JsonProperty("accommodation") + private String accommodation; + + @JsonProperty("food") + private String food; + + @JsonProperty("activities") + private String activities; + + @JsonProperty("transportation") + private String transportation; + + @JsonProperty("total") + private String total; + + public String getAccommodation() { + return accommodation; + } + + public void setAccommodation(String accommodation) { + this.accommodation = accommodation; + } + + public String getFood() { + return food; + } + + public void setFood(String food) { + this.food = food; + } + + public String getActivities() { + return activities; + } + + public void setActivities(String activities) { + this.activities = activities; + } + + public String getTransportation() { + return transportation; + } + + public void setTransportation(String transportation) { + this.transportation = transportation; + } + + public String getTotal() { + return total; + } + + public void setTotal(String total) { + this.total = total; + } + } +} diff --git a/java/samples/dotprompt/src/main/resources/logback.xml b/java/samples/dotprompt/src/main/resources/logback.xml new file mode 100644 index 0000000000..56110e53c0 --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/logback.xml @@ -0,0 +1,25 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + diff --git a/java/samples/dotprompt/src/main/resources/prompts/_style.prompt b/java/samples/dotprompt/src/main/resources/prompts/_style.prompt new file mode 100644 index 0000000000..023b6204b6 --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/prompts/_style.prompt @@ -0,0 +1,5 @@ +{{#if personality}} +Adopt the personality of someone who is {{personality}}. +{{else}} +Be friendly and approachable. +{{/if}} diff --git a/java/samples/dotprompt/src/main/resources/prompts/code-review.prompt b/java/samples/dotprompt/src/main/resources/prompts/code-review.prompt new file mode 100644 index 0000000000..45c0a1f319 --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/prompts/code-review.prompt @@ -0,0 +1,53 @@ +--- +model: openai/gpt-4o-mini +input: + schema: + code: string + language: string + analysisType?: string +output: + format: json + schema: + language: string + summary: string + complexity: + level: string + score: integer, from 1 to 10 + explanation: string + issues(array): + severity: string + line?: integer + description: string + suggestion: string + improvements(array): + category: string + description: string + example?: string + metrics: + linesOfCode: integer + functions: integer + classes: integer + comments: integer +--- + +You are a senior software engineer and code reviewer with expertise in {{language}}. + +Analyze the following {{language}} code: + +```{{language}} +{{code}} +``` + +{{#if analysisType}} +Focus on: {{analysisType}} +{{else}} +Provide a comprehensive analysis including: +- Code quality and readability +- Potential bugs or issues +- Performance considerations +- Best practices adherence +{{/if}} + +IMPORTANT: You MUST respond with ONLY a valid JSON object matching the output schema. Do not include any explanatory text before or after the JSON. + +Return your analysis in the specified JSON format. diff --git a/java/samples/dotprompt/src/main/resources/prompts/recipe.prompt b/java/samples/dotprompt/src/main/resources/prompts/recipe.prompt new file mode 100644 index 0000000000..5647ec9dc8 --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/prompts/recipe.prompt @@ -0,0 +1,31 @@ +--- +model: openai/gpt-4o-mini +input: + schema: + food: string + ingredients?(array): string +output: + format: json + schema: + title: string, recipe title + ingredients(array): + name: string + quantity: string + steps(array, the steps required to complete the recipe): string + prepTime: string + cookTime: string + servings: integer +--- + +You are a chef famous for making creative recipes that can be prepared in 45 minutes or less. + +Generate a detailed recipe for {{food}}. + +{{#if ingredients}} +Make sure to include the following ingredients: +{{#each ingredients}} +- {{this}} +{{/each}} +{{/if}} + +Provide the recipe in the exact JSON format specified in the output schema. diff --git a/java/samples/dotprompt/src/main/resources/prompts/recipe.robot.prompt b/java/samples/dotprompt/src/main/resources/prompts/recipe.robot.prompt new file mode 100644 index 0000000000..f3ca2d1df0 --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/prompts/recipe.robot.prompt @@ -0,0 +1,30 @@ +--- +model: openai/gpt-4o-mini +input: + schema: + food: string + ingredients?(array): string +output: + format: json + schema: + title: string + ingredients(array): + name: string + quantity: string + steps(array): string +--- + +{{>style personality="robot"}} + +BEEP BOOP! You are RoboChef 3000, a robot chef who speaks in mechanical terms. + +Generate a recipe for {{food}}. + +{{#if ingredients}} +REQUIRED COMPONENTS DETECTED: +{{#each ingredients}} +- COMPONENT: {{this}} +{{/each}} +{{/if}} + +OUTPUT FORMAT: JSON as specified. COMPLIANCE IS MANDATORY. diff --git a/java/samples/dotprompt/src/main/resources/prompts/story.prompt b/java/samples/dotprompt/src/main/resources/prompts/story.prompt new file mode 100644 index 0000000000..59b5a8371c --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/prompts/story.prompt @@ -0,0 +1,22 @@ +--- +model: openai/gpt-4o-mini +input: + schema: + subject: string + personality?: string + length?: string +output: + format: text +--- + +{{#if personality}} +Write in the style of someone who is {{personality}}. +{{/if}} + +Tell me an engaging story about {{subject}}. + +{{#if length}} +The story should be approximately {{length}} in length. +{{else}} +The story should be medium length (2-3 paragraphs). +{{/if}} diff --git a/java/samples/dotprompt/src/main/resources/prompts/travel-planner.prompt b/java/samples/dotprompt/src/main/resources/prompts/travel-planner.prompt new file mode 100644 index 0000000000..7e55914091 --- /dev/null +++ b/java/samples/dotprompt/src/main/resources/prompts/travel-planner.prompt @@ -0,0 +1,51 @@ +--- +model: openai/gpt-4o-mini +input: + schema: + destination: string + duration: integer, number of days + budget: string + interests?(array): string + travelStyle?: string +output: + format: json + schema: + tripName: string + destination: string + duration: integer + dailyItinerary(array): + day: integer + title: string + activities(array): + time: string + activity: string + location: string + estimatedCost?: string + tips?: string + estimatedBudget: + accommodation: string + food: string + activities: string + transportation: string + total: string + packingList(array): string + travelTips(array): string +--- + +You are an experienced travel planner with expertise in creating personalized itineraries. + +Create a detailed travel itinerary for a trip to {{destination}} for {{duration}} days with a budget of {{budget}}. + +{{#if travelStyle}} +Travel style preference: {{travelStyle}} +{{/if}} + +{{#if interests}} +Traveler interests: +{{#each interests}} +- {{this}} +{{/each}} +{{/if}} + +IMPORTANT: You MUST respond with ONLY a valid JSON object. Do not include any explanatory text before or after the JSON. +Provide a comprehensive day-by-day itinerary with specific activities, timings, and cost estimates in the exact JSON format specified above. diff --git a/java/samples/evaluations/README.md b/java/samples/evaluations/README.md new file mode 100644 index 0000000000..1bdf943f37 --- /dev/null +++ b/java/samples/evaluations/README.md @@ -0,0 +1,177 @@ +# Genkit Evaluations Sample + +This sample demonstrates how to use Genkit's evaluation framework to assess AI output quality with custom evaluators and datasets. + +## Features Demonstrated + +- **Custom Evaluators** - Define evaluators for length, keywords, sentiment +- **LLM-Based Evaluators** - Use AI to evaluate AI outputs +- **Datasets** - Create and manage evaluation datasets +- **Evaluation Runs** - Execute evaluations and view results +- **Dev UI Integration** - View evaluations in the Genkit Dev UI + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/evaluations + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/evaluations + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +> **Important**: Run `genkit start` from the same directory where the Java app is running. This ensures the Dev UI can find the datasets stored in `.genkit/datasets/`. + +## Available Flows + +| Flow | Input | Output | Description | +|------|-------|--------|-------------| +| `describeFood` | String (food) | String | Generate appetizing food descriptions | + +## Custom Evaluators + +This sample defines several custom evaluators: + +| Evaluator | Description | +|-----------|-------------| +| `custom/length` | Checks if output length is between 50-500 characters | +| `custom/keywords` | Checks for food-related descriptive keywords | +| `custom/sentiment` | Evaluates positive/appetizing sentiment | + +## Example API Calls + +### Describe Food +```bash +curl -X POST http://localhost:8080/describeFood \ + -H 'Content-Type: application/json' \ + -d '"chocolate cake"' +``` + +## Creating Custom Evaluators + +### Simple Rule-Based Evaluator + +```java +Evaluator lengthEvaluator = genkit.defineEvaluator( + "custom/length", + "Output Length", + "Evaluates whether the output has an appropriate length", + (dataPoint, options) -> { + String output = dataPoint.getOutput().toString(); + int length = output.length(); + + double score = (length >= 50 && length <= 500) ? 1.0 : 0.5; + EvalStatus status = score == 1.0 ? EvalStatus.PASS : EvalStatus.FAIL; + + return EvalResponse.builder() + .testCaseId(dataPoint.getTestCaseId()) + .evaluation(Score.builder() + .score(score) + .status(status) + .reasoning("Output length: " + length) + .build()) + .build(); + }); +``` + +### Keyword-Based Evaluator + +```java +Evaluator keywordEvaluator = genkit.defineEvaluator( + "custom/keywords", + "Food Keywords", + "Checks for food-related descriptive keywords", + (dataPoint, options) -> { + String output = dataPoint.getOutput().toString().toLowerCase(); + + List keywords = Arrays.asList( + "delicious", "tasty", "flavor", "savory", "sweet"); + + int foundCount = 0; + for (String keyword : keywords) { + if (output.contains(keyword)) foundCount++; + } + + double score = Math.min(1.0, foundCount / 3.0); + + return EvalResponse.builder() + .testCaseId(dataPoint.getTestCaseId()) + .evaluation(Score.builder() + .score(score) + .status(foundCount >= 2 ? EvalStatus.PASS : EvalStatus.FAIL) + .reasoning("Found " + foundCount + " keywords") + .build()) + .build(); + }); +``` + +## Working with Datasets + +Datasets are stored in `.genkit/datasets/` and can be managed via the Dev UI or programmatically: + +```java +// Create a dataset +List items = Arrays.asList( + new DatasetItem("test-1", "pizza", null), + new DatasetItem("test-2", "sushi", null), + new DatasetItem("test-3", "tacos", null) +); + +// Run evaluation +EvalRunKey result = genkit.evaluate( + RunEvaluationRequest.builder() + .datasetId("food-dataset") + .evaluators(List.of("custom/length", "custom/keywords")) + .actionRef("/flow/describeFood") + .build()); +``` + +## Development UI + +When running with `genkit start`, access the Dev UI at http://localhost:4000 to: + +- Create and manage datasets +- Run evaluations on flows +- View evaluation results and scores +- Compare evaluation runs +- Inspect individual test cases + +## Evaluation Results + +Evaluation results include: + +- **Score**: Numeric value (0.0 - 1.0) +- **Status**: PASS, FAIL, or UNKNOWN +- **Reasoning**: Explanation of the score + +## See Also + +- [Genkit Java README](../../README.md) +- [Genkit Evaluation Documentation](https://firebase.google.com/docs/genkit/evaluation) diff --git a/java/samples/evaluations/pom.xml b/java/samples/evaluations/pom.xml new file mode 100644 index 0000000000..028e2073ac --- /dev/null +++ b/java/samples/evaluations/pom.xml @@ -0,0 +1,89 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-evaluations + jar + Genkit Evaluations Sample + Sample application demonstrating Genkit evaluations, evaluators, and datasets + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.EvaluationsSample + + + + + diff --git a/java/samples/evaluations/run.sh b/java/samples/evaluations/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/evaluations/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/evaluations/src/main/java/com/google/genkit/samples/EvaluationsSample.java b/java/samples/evaluations/src/main/java/com/google/genkit/samples/EvaluationsSample.java new file mode 100644 index 0000000000..e17d89764f --- /dev/null +++ b/java/samples/evaluations/src/main/java/com/google/genkit/samples/EvaluationsSample.java @@ -0,0 +1,363 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.*; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.*; +import com.google.genkit.ai.evaluation.*; +import com.google.genkit.core.Flow; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating Genkit evaluations, evaluators, and + * datasets. + * + * This example shows how to: - Define custom evaluators - Create and manage + * datasets - Run evaluations - Use LLM-based evaluators + * + *

To run with Dev UI:

+ *
    + *
  1. Set the OPENAI_API_KEY environment variable (for LLM-based + * evaluators)
  2. + *
  3. Navigate to the sample directory: + * {@code cd java/samples/evaluations}
  4. + *
  5. Run the app: {@code mvn exec:java}
  6. + *
  7. In a separate terminal, from the same directory, run: + * {@code genkit start}
  8. + *
  9. Open the Dev UI at http://localhost:4000
  10. + *
+ * + *

+ * Important: Run {@code genkit start} from the same directory where the + * Java app is running. This ensures the Dev UI can find the datasets stored in + * {@code .genkit/datasets/}. + */ +public class EvaluationsSample { + + public static void main(String[] args) throws Exception { + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // Create Genkit with plugins + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).plugin(jetty).build(); + + // ===================================================================== + // Define a flow to evaluate + // ===================================================================== + + Flow describeFood = genkit.defineFlow("describeFood", String.class, String.class, + (ctx, food) -> { + // Use OpenAI to describe the food + try { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .prompt("Describe " + food + " in a delicious and appetizing way in 2-3 sentences.") + .config(GenerationConfig.builder().temperature(0.8).maxOutputTokens(200).build()) + .build()); + return response.getText(); + } catch (Exception e) { + return "A delicious " + food + " with wonderful flavors and textures."; + } + }); + + // ===================================================================== + // Define Custom Evaluators + // ===================================================================== + + // Simple length-based evaluator + Evaluator lengthEvaluator = genkit.defineEvaluator("custom/length", "Output Length", + "Evaluates whether the output has an appropriate length (50-500 characters)", (dataPoint, options) -> { + String output = dataPoint.getOutput() != null ? dataPoint.getOutput().toString() : ""; + int length = output.length(); + + double score; + EvalStatus status; + String reasoning; + + if (length >= 50 && length <= 500) { + score = 1.0; + status = EvalStatus.PASS; + reasoning = "Output length (" + length + " chars) is within acceptable range."; + } else if (length < 50) { + score = length / 50.0; + status = EvalStatus.FAIL; + reasoning = "Output too short (" + length + " chars). Expected at least 50 characters."; + } else { + score = Math.max(0, 1.0 - (length - 500) / 500.0); + status = EvalStatus.FAIL; + reasoning = "Output too long (" + length + " chars). Expected at most 500 characters."; + } + + return EvalResponse.builder().testCaseId(dataPoint.getTestCaseId()) + .evaluation(Score.builder().score(score).status(status).reasoning(reasoning).build()) + .build(); + }); + + // Keyword presence evaluator + Evaluator keywordEvaluator = genkit.defineEvaluator("custom/keywords", "Food Keywords", + "Checks if the output contains food-related descriptive keywords", (dataPoint, options) -> { + String output = dataPoint.getOutput() != null ? dataPoint.getOutput().toString().toLowerCase() : ""; + + List positiveKeywords = Arrays.asList("delicious", "tasty", "flavor", "savory", "sweet", + "crispy", "tender", "juicy", "fresh", "aromatic", "rich", "creamy", "satisfying", + "mouth-watering"); + + int foundCount = 0; + List foundKeywords = new ArrayList<>(); + for (String keyword : positiveKeywords) { + if (output.contains(keyword)) { + foundCount++; + foundKeywords.add(keyword); + } + } + + double score = Math.min(1.0, foundCount / 3.0); + EvalStatus status = foundCount >= 2 ? EvalStatus.PASS : EvalStatus.FAIL; + String reasoning = "Found " + foundCount + " descriptive keywords: " + + String.join(", ", foundKeywords); + + return EvalResponse.builder().testCaseId(dataPoint.getTestCaseId()) + .evaluation(Score.builder().score(score).status(status).reasoning(reasoning).build()) + .build(); + }); + + // Sentiment evaluator (simple) + Evaluator sentimentEvaluator = genkit.defineEvaluator("custom/sentiment", "Positive Sentiment", + "Evaluates whether the output has a positive/appetizing sentiment", (dataPoint, options) -> { + String output = dataPoint.getOutput() != null ? dataPoint.getOutput().toString().toLowerCase() : ""; + + List positiveWords = Arrays.asList("delicious", "wonderful", "amazing", "excellent", + "perfect", "lovely", "great", "beautiful", "fantastic", "divine"); + List negativeWords = Arrays.asList("bad", "awful", "terrible", "disgusting", "horrible", + "gross", "nasty", "unpleasant", "bland"); + + int positiveCount = 0; + int negativeCount = 0; + + for (String word : positiveWords) { + if (output.contains(word)) + positiveCount++; + } + for (String word : negativeWords) { + if (output.contains(word)) + negativeCount++; + } + + double sentimentScore = positiveCount - negativeCount; + double normalizedScore = Math.max(0, Math.min(1, (sentimentScore + 2) / 4.0)); + + EvalStatus status = sentimentScore > 0 ? EvalStatus.PASS : EvalStatus.FAIL; + String reasoning = "Positive words: " + positiveCount + ", Negative words: " + negativeCount; + + return EvalResponse.builder().testCaseId(dataPoint.getTestCaseId()) + .evaluation( + Score.builder().score(normalizedScore).status(status).reasoning(reasoning).build()) + .build(); + }); + + // LLM-based "Deliciousness" evaluator + Evaluator deliciousnessEvaluator = genkit.defineEvaluator("custom/deliciousness", "Deliciousness", + "Uses an LLM to evaluate how delicious and appetizing the description sounds", true, // isBilled - this + // evaluator + // makes LLM + // calls + null, (dataPoint, options) -> { + String output = dataPoint.getOutput() != null ? dataPoint.getOutput().toString() : ""; + + try { + String prompt = """ + You are evaluating how delicious and appetizing a food description sounds. + + Food description to evaluate: + \"\"\" + %s + \"\"\" + + Rate this description on a scale of 0.0 to 1.0 where: + - 0.0 = Not appetizing at all + - 0.5 = Somewhat appetizing + - 1.0 = Extremely appetizing, makes you want to eat it + + Respond with ONLY a JSON object in this format: + {"score": 0.X, "reasoning": "brief explanation"} + """.formatted(output); + + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .prompt(prompt) + .config(GenerationConfig.builder().temperature(0.0).maxOutputTokens(200).build()) + .build()); + + String responseText = response.getText().trim(); + // Parse the JSON response + // Simple parsing - in production you'd use a JSON parser + double score = 0.5; + String reasoning = "Unable to parse response"; + + if (responseText.contains("\"score\"")) { + int scoreStart = responseText.indexOf("\"score\"") + 9; + int scoreEnd = responseText.indexOf(",", scoreStart); + if (scoreEnd == -1) + scoreEnd = responseText.indexOf("}", scoreStart); + try { + score = Double.parseDouble(responseText.substring(scoreStart, scoreEnd).trim()); + } catch (NumberFormatException e) { + // Keep default + } + } + if (responseText.contains("\"reasoning\"")) { + int reasonStart = responseText.indexOf("\"reasoning\"") + 13; + int reasonEnd = responseText.lastIndexOf("\""); + if (reasonEnd > reasonStart) { + reasoning = responseText.substring(reasonStart, reasonEnd); + } + } + + return EvalResponse.builder().testCaseId(dataPoint.getTestCaseId()) + .evaluation(Score.builder().score(score) + .status(score >= 0.6 ? EvalStatus.PASS : EvalStatus.FAIL).reasoning(reasoning) + .build()) + .build(); + + } catch (Exception e) { + return EvalResponse.builder().testCaseId(dataPoint.getTestCaseId()).evaluation(Score.builder() + .error("Failed to evaluate: " + e.getMessage()).status(EvalStatus.UNKNOWN).build()) + .build(); + } + }); + + // ===================================================================== + // Create a Sample Dataset + // ===================================================================== + + DatasetStore datasetStore = genkit.getDatasetStore(); + + // Check if our sample dataset already exists + List existingDatasets = datasetStore.listDatasets(); + boolean datasetExists = existingDatasets.stream().anyMatch(d -> "food_descriptions".equals(d.getDatasetId())); + + if (!datasetExists) { + // Create the sample dataset + List samples = Arrays.asList( + DatasetSample.builder().testCaseId("food_1").input("pizza").reference( + "A delicious Italian dish with a crispy crust, tangy tomato sauce, and melted cheese.") + .build(), + DatasetSample.builder().testCaseId("food_2").input("sushi") + .reference("Fresh, delicate rolls of vinegared rice with raw fish and vegetables.").build(), + DatasetSample.builder().testCaseId("food_3").input("tacos").reference( + "Flavorful Mexican street food with seasoned meat, fresh salsa, and corn tortillas.") + .build(), + DatasetSample.builder().testCaseId("food_4").input("chocolate cake") + .reference("Rich, moist layers of chocolate with creamy frosting.").build(), + DatasetSample.builder().testCaseId("food_5").input("ramen").reference( + "A comforting bowl of noodles in savory broth with tender pork and soft-boiled egg.") + .build()); + + CreateDatasetRequest createRequest = CreateDatasetRequest.builder().datasetId("food_descriptions") + .data(samples).datasetType(DatasetType.FLOW).targetAction("/flow/describeFood").metricRefs(Arrays + .asList("custom/length", "custom/keywords", "custom/sentiment", "custom/deliciousness")) + .build(); + + DatasetMetadata metadata = datasetStore.createDataset(createRequest); + System.out.println( + "Created dataset: " + metadata.getDatasetId() + " with " + metadata.getSize() + " samples"); + } else { + System.out.println("Dataset 'food_descriptions' already exists"); + } + + // ===================================================================== + // Define a flow to run evaluations programmatically + // ===================================================================== + + Flow, Void> runEvaluationFlow = genkit.defineFlow("runEvaluation", String.class, + (Class>) (Class) Map.class, (ctx, datasetId) -> { + try { + // Create evaluation request + RunEvaluationRequest.DataSource dataSource = new RunEvaluationRequest.DataSource(); + dataSource.setDatasetId(datasetId); + + RunEvaluationRequest request = RunEvaluationRequest.builder().dataSource(dataSource) + .targetAction("/flow/describeFood") + .evaluators(Arrays.asList("custom/length", "custom/keywords", "custom/sentiment")) + .build(); + + EvalRunKey evalRunKey = genkit.evaluate(request); + + // Return the result + Map result = new HashMap<>(); + result.put("evalRunId", evalRunKey.getEvalRunId()); + result.put("createdAt", evalRunKey.getCreatedAt()); + result.put("datasetId", datasetId); + return result; + } catch (Exception e) { + Map error = new HashMap<>(); + error.put("error", e.getMessage()); + return error; + } + }); + + // ===================================================================== + // Print Information + // ===================================================================== + + System.out.println(); + System.out.println("╔══════════════════════════════════════════════════════════════════╗"); + System.out.println("║ Genkit Evaluations Sample Application ║"); + System.out.println("╠══════════════════════════════════════════════════════════════════╣"); + System.out.println("║ Dev UI: http://localhost:3100 ║"); + System.out.println("╠══════════════════════════════════════════════════════════════════╣"); + System.out.println("║ Registered Evaluators: ║"); + System.out.println("║ • custom/length - Checks output length (50-500 chars) ║"); + System.out.println("║ • custom/keywords - Checks for food-related keywords ║"); + System.out.println("║ • custom/sentiment - Evaluates positive sentiment ║"); + System.out.println("║ • custom/deliciousness - LLM-based appetizing evaluation ║"); + System.out.println("╠══════════════════════════════════════════════════════════════════╣"); + System.out.println("║ API Endpoints: ║"); + System.out.println("║ POST /api/flows/describeFood - Describe a food item ║"); + System.out.println("║ POST /api/flows/runEvaluation - Run evaluation on dataset ║"); + System.out.println("║ GET /api/datasets - List all datasets ║"); + System.out.println("║ GET /api/evalRuns - List evaluation runs ║"); + System.out.println("╠══════════════════════════════════════════════════════════════════╣"); + System.out.println("║ Example usage: ║"); + System.out.println("║ curl -X POST http://localhost:8080/api/flows/describeFood \\ ║"); + System.out.println("║ -d '\"pizza\"' -H 'Content-Type: application/json' ║"); + System.out.println("║ ║"); + System.out.println("║ curl -X POST http://localhost:8080/api/flows/runEvaluation \\ ║"); + System.out.println("║ -d '\"food_descriptions\"' -H 'Content-Type: application/json'║"); + System.out.println("╠══════════════════════════════════════════════════════════════════╣"); + System.out.println("║ Data Storage: ║"); + System.out.println("║ Datasets: ./.genkit/datasets/ ║"); + System.out.println("║ Eval Runs: ./.genkit/evals/ ║"); + System.out.println("╠══════════════════════════════════════════════════════════════════╣"); + System.out.println("║ IMPORTANT: To see datasets in Dev UI, run 'genkit start' from ║"); + System.out.println("║ the SAME directory where this app runs (current working dir). ║"); + System.out.println("╚══════════════════════════════════════════════════════════════════╝"); + System.out.println(); + System.out.println("Working directory: " + System.getProperty("user.dir")); + System.out.println(); + System.out.println("Press Ctrl+C to stop..."); + + // Start the server and block + jetty.start(); + } +} diff --git a/java/samples/evaluations/src/main/resources/logback.xml b/java/samples/evaluations/src/main/resources/logback.xml new file mode 100644 index 0000000000..56110e53c0 --- /dev/null +++ b/java/samples/evaluations/src/main/resources/logback.xml @@ -0,0 +1,25 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + diff --git a/java/samples/google-genai/README.md b/java/samples/google-genai/README.md new file mode 100644 index 0000000000..0320e597a5 --- /dev/null +++ b/java/samples/google-genai/README.md @@ -0,0 +1,198 @@ +# Genkit Google GenAI Sample + +This sample demonstrates integration with Google GenAI (Gemini) models using Genkit Java, including text generation, image generation with Imagen, and text-to-speech. + +## Features Demonstrated + +- **Google GenAI Plugin Setup** - Configure Genkit with Gemini models +- **Text Generation** - Generate text with Gemini 2.0 Flash +- **Tool Calling** - Function calling with Gemini +- **Embeddings** - Generate text embeddings +- **Image Generation** - Generate images with Imagen 3 +- **Text-to-Speech** - Generate audio with Google TTS +- **Video Generation** - Generate videos with Veo + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- Google GenAI API key (get one at https://aistudio.google.com/) + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your Google GenAI API key +export GOOGLE_GENAI_API_KEY=your-api-key-here +# Or use the alternative environment variable +export GOOGLE_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/google-genai + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your API key +export GOOGLE_GENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/google-genai + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Available Flows + +| Flow | Input | Output | Description | +|------|-------|--------|-------------| +| `textGeneration` | String (prompt) | String | Generate text with Gemini | +| `toolCalling` | String (query) | String | Demonstrate tool/function calling | +| `embeddings` | String (text) | String | Generate text embeddings | +| `imageGeneration` | String (prompt) | String | Generate images with Imagen | +| `textToSpeech` | String (text) | String | Generate audio with TTS | +| `videoGeneration` | String (prompt) | String | Generate videos with Veo | + +## Example API Calls + +Once the server is running on port 8080: + +### Text Generation +```bash +curl -X POST http://localhost:8080/textGeneration \ + -H 'Content-Type: application/json' \ + -d '"Explain quantum computing in simple terms"' +``` + +### Tool Calling +```bash +curl -X POST http://localhost:8080/toolCalling \ + -H 'Content-Type: application/json' \ + -d '"What is the weather in Tokyo?"' +``` + +### Embeddings +```bash +curl -X POST http://localhost:8080/embeddings \ + -H 'Content-Type: application/json' \ + -d '"Hello, world!"' +``` + +### Image Generation +```bash +curl -X POST http://localhost:8080/imageGeneration \ + -H 'Content-Type: application/json' \ + -d '"A serene Japanese garden with cherry blossoms"' +``` + +### Text-to-Speech +```bash +curl -X POST http://localhost:8080/textToSpeech \ + -H 'Content-Type: application/json' \ + -d '"Welcome to Genkit for Java!"' +``` + +### Video Generation +```bash +curl -X POST http://localhost:8080/videoGeneration \ + -H 'Content-Type: application/json' \ + -d '"A cat playing with a ball of yarn"' +``` + +## Generated Media Files + +Media files (images, audio, video) are saved to the `generated_media/` directory in the sample folder. + +## Available Models + +The Google GenAI plugin provides access to: + +| Model | Description | +|-------|-------------| +| `googleai/gemini-2.0-flash` | Fast, efficient Gemini model | +| `googleai/gemini-1.5-pro` | Advanced reasoning capabilities | +| `googleai/gemini-1.5-flash` | Balanced speed and capability | +| `googleai/imagen-3.0-generate-002` | Image generation | +| `googleai/text-embedding-004` | Text embeddings | + +## Code Highlights + +### Setting Up Genkit with Google GenAI + +```java +Genkit genkit = Genkit.builder() + .options(GenkitOptions.builder() + .devMode(true) + .reflectionPort(3100) + .build()) + .plugin(GoogleGenAIPlugin.create()) + .plugin(new JettyPlugin(JettyPluginOptions.builder() + .port(8080) + .build())) + .build(); +``` + +### Text Generation with Gemini + +```java +genkit.defineFlow("textGeneration", String.class, String.class, + (ctx, prompt) -> { + ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("googleai/gemini-2.0-flash") + .prompt(prompt) + .config(GenerationConfig.builder() + .temperature(0.7) + .maxOutputTokens(500) + .build()) + .build()); + return response.getText(); + }); +``` + +### Image Generation with Imagen + +```java +genkit.defineFlow("imageGeneration", String.class, String.class, + (ctx, prompt) -> { + ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("googleai/imagen-3.0-generate-002") + .prompt(prompt) + .build()); + // Save generated image to file + // ... + return "Image saved to: " + filePath; + }); +``` + +## Development UI + +When running with `genkit start`, access the Dev UI at http://localhost:4000 to: + +- Browse all registered flows and models +- Run flows with test inputs +- View execution traces and logs +- Preview generated content + +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `GOOGLE_GENAI_API_KEY` | Google GenAI API key | +| `GOOGLE_API_KEY` | Alternative API key variable | + +## See Also + +- [Genkit Java README](../../README.md) +- [Google AI Studio](https://aistudio.google.com/) +- [Gemini API Documentation](https://ai.google.dev/docs) diff --git a/java/samples/google-genai/pom.xml b/java/samples/google-genai/pom.xml new file mode 100644 index 0000000000..72205ff528 --- /dev/null +++ b/java/samples/google-genai/pom.xml @@ -0,0 +1,75 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + genkit-sample-google-genai + jar + Genkit Google GenAI Sample + Sample application demonstrating Google GenAI (Gemini) integration with Genkit + + + + com.google.genkit + genkit + ${project.version} + + + + com.google.genkit + genkit-plugin-google-genai + ${project.version} + + + + com.google.genkit + genkit-plugin-jetty + ${project.version} + + + + + ch.qos.logback + logback-classic + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.GoogleGenAIApp + + + + + diff --git a/java/samples/google-genai/run.sh b/java/samples/google-genai/run.sh new file mode 100755 index 0000000000..b31dc8b2f9 --- /dev/null +++ b/java/samples/google-genai/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run the Google GenAI sample application with Genkit Dev UI +cd "$(dirname "$0")" +mvn exec:java -Dexec.mainClass="com.google.genkit.samples.GoogleGenAIApp" -q diff --git a/java/samples/google-genai/src/main/java/com/google/genkit/samples/GoogleGenAIApp.java b/java/samples/google-genai/src/main/java/com/google/genkit/samples/GoogleGenAIApp.java new file mode 100644 index 0000000000..5bb6d38184 --- /dev/null +++ b/java/samples/google-genai/src/main/java/com/google/genkit/samples/GoogleGenAIApp.java @@ -0,0 +1,350 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.Document; +import com.google.genkit.ai.EmbedResponse; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.Tool; +import com.google.genkit.core.ActionContext; +import com.google.genkit.plugins.googlegenai.GoogleGenAIPlugin; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; + +/** + * Sample application demonstrating Google GenAI (Gemini) integration with + * Genkit. + * + *

+ * This sample uses the Jetty plugin to expose flows via HTTP endpoints, + * allowing you to use the Genkit Developer UI. + * + *

+ * To run this sample: + *

    + *
  1. Set the GOOGLE_API_KEY environment variable with your Gemini API key
  2. + *
  3. Run: genkit start -- ./run.sh
  4. + *
  5. Open the Genkit Developer UI at http://localhost:4000
  6. + *
+ */ +public class GoogleGenAIApp { + + // Output directory for generated media files + private static final String OUTPUT_DIR = "generated_media"; + + public static void main(String[] args) throws Exception { + System.out.println("=== Google GenAI (Gemini) Sample with Dev UI ===\n"); + + // Create output directory for media files + createOutputDirectory(); + + // Create the Jetty server plugin for HTTP endpoints + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // Initialize Genkit with Google GenAI plugin and Jetty + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(GoogleGenAIPlugin.create()).plugin(jetty).build(); + + // Define flows for the Dev UI + defineTextGenerationFlow(genkit); + defineToolCallingFlow(genkit); + defineEmbeddingsFlow(genkit); + defineImageGenerationFlow(genkit); + defineTextToSpeechFlow(genkit); + defineVideoGenerationFlow(genkit); + + System.out.println("Server started on http://localhost:8080"); + System.out.println("Use Genkit Developer UI at http://localhost:4000 to interact with flows"); + System.out.println("\nAvailable flows:"); + System.out.println(" - textGeneration: Generate text with Gemini"); + System.out.println(" - toolCalling: Demonstrate tool/function calling"); + System.out.println(" - embeddings: Generate text embeddings"); + System.out.println(" - imageGeneration: Generate images with Imagen (saves to " + OUTPUT_DIR + "/)"); + System.out.println(" - textToSpeech: Generate audio with TTS (saves to " + OUTPUT_DIR + "/)"); + System.out.println(" - videoGeneration: Generate videos with Veo (saves to " + OUTPUT_DIR + "/)"); + System.out.println("\nGenerated media files will be saved to: " + new File(OUTPUT_DIR).getAbsolutePath()); + System.out.println("\nPress Ctrl+C to stop the server."); + + // Keep the application running + Thread.currentThread().join(); + } + + private static void createOutputDirectory() { + File dir = new File(OUTPUT_DIR); + if (!dir.exists()) { + dir.mkdirs(); + System.out.println("Created output directory: " + dir.getAbsolutePath()); + } + } + + /** + * Saves base64-encoded data to a file. + */ + private static String saveBase64ToFile(String base64Data, String filename) throws IOException { + byte[] data = Base64.getDecoder().decode(base64Data); + Path filePath = Paths.get(OUTPUT_DIR, filename); + try (FileOutputStream fos = new FileOutputStream(filePath.toFile())) { + fos.write(data); + } + return filePath.toAbsolutePath().toString(); + } + + /** + * Extracts base64 data from a data URL. + */ + private static String extractBase64FromDataUrl(String dataUrl) { + if (dataUrl.startsWith("data:")) { + int commaIndex = dataUrl.indexOf(","); + if (commaIndex > 0) { + return dataUrl.substring(commaIndex + 1); + } + } + return dataUrl; + } + + private static void defineTextGenerationFlow(Genkit genkit) { + genkit.defineFlow("textGeneration", String.class, String.class, (ctx, prompt) -> { + GenerationConfig config = GenerationConfig.builder().temperature(0.7).maxOutputTokens(500).build(); + + ModelResponse response = genkit.generate( + GenerateOptions.builder().model("googleai/gemini-2.0-flash").prompt(prompt).config(config).build()); + + return response.getText(); + }); + } + + @SuppressWarnings("unchecked") + private static void defineToolCallingFlow(Genkit genkit) { + // Define a simple weather tool + Tool, Map> weatherTool = genkit.defineTool("getWeather", + "Get the current weather for a location", + Map.of("type", "object", "properties", Map.of("location", + Map.of("type", "string", "description", "The city and state, e.g., San Francisco, CA"), "unit", + Map.of("type", "string", "enum", Arrays.asList("celsius", "fahrenheit"), "description", + "The temperature unit")), + "required", Arrays.asList("location")), + (Class>) (Class) Map.class, (ActionContext ctx, Map input) -> { + String location = (String) input.get("location"); + String unit = input.get("unit") != null ? (String) input.get("unit") : "celsius"; + // Mock weather response + Map result = Map.of("location", location, "temperature", 22, "unit", unit, + "condition", "Sunny"); + return result; + }); + + genkit.defineFlow("toolCalling", String.class, String.class, (ctx, prompt) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("googleai/gemini-2.0-flash") + .prompt(prompt).tools(List.of(weatherTool)).build()); + + return response.getText(); + }); + } + + private static void defineEmbeddingsFlow(Genkit genkit) { + genkit.defineFlow("embeddings", String.class, String.class, (ctx, text) -> { + List documents = Arrays.asList(Document.fromText(text)); + EmbedResponse response = genkit.embed("googleai/text-embedding-004", documents); + + if (response.getEmbeddings() != null && !response.getEmbeddings().isEmpty()) { + EmbedResponse.Embedding embedding = response.getEmbeddings().get(0); + return "Generated embedding with " + embedding.getValues().length + " dimensions"; + } + return "Failed to generate embedding"; + }); + } + + private static void defineImageGenerationFlow(Genkit genkit) { + genkit.defineFlow("imageGeneration", String.class, String.class, (ctx, prompt) -> { + Map imagenOptions = Map.of("numberOfImages", 1, "aspectRatio", "1:1"); + + GenerationConfig config = GenerationConfig.builder().custom(imagenOptions).build(); + + ModelResponse response = genkit.generate(GenerateOptions.builder() + .model("googleai/imagen-4.0-fast-generate-001").prompt(prompt).config(config).build()); + + // Save the generated image + if (response.getMessage() != null && response.getMessage().getContent() != null) { + StringBuilder result = new StringBuilder(); + int imageCount = 0; + + for (Part part : response.getMessage().getContent()) { + if (part.getMedia() != null) { + imageCount++; + String url = part.getMedia().getUrl(); + String contentType = part.getMedia().getContentType(); + + if (url.startsWith("data:")) { + // Extract and save base64 data + String base64Data = extractBase64FromDataUrl(url); + String extension = contentType != null && contentType.contains("png") ? "png" : "jpg"; + String filename = "image_" + System.currentTimeMillis() + "_" + imageCount + "." + + extension; + + try { + String savedPath = saveBase64ToFile(base64Data, filename); + result.append("Image ").append(imageCount).append(" saved to: ").append(savedPath) + .append("\n"); + } catch (IOException e) { + result.append("Image ").append(imageCount).append(" failed to save: ") + .append(e.getMessage()).append("\n"); + } + } else if (url.startsWith("gs://")) { + result.append("Image ").append(imageCount).append(" available at GCS: ").append(url) + .append("\n"); + } + } + } + + return result.length() > 0 ? result.toString().trim() : "No images generated"; + } + + return "No images generated"; + }); + } + + private static void defineTextToSpeechFlow(Genkit genkit) { + genkit.defineFlow("textToSpeech", String.class, String.class, (ctx, text) -> { + Map ttsOptions = Map.of("voiceName", "Zephyr" // Available: Zephyr, Puck, Charon, Kore, + // Fenrir, etc. + ); + + GenerationConfig config = GenerationConfig.builder().custom(ttsOptions).build(); + + ModelResponse response = genkit.generate(GenerateOptions.builder() + .model("googleai/gemini-2.5-flash-preview-tts").prompt(text).config(config).build()); + + // Save the generated audio + if (response.getMessage() != null && response.getMessage().getContent() != null) { + StringBuilder result = new StringBuilder(); + int audioCount = 0; + + for (Part part : response.getMessage().getContent()) { + if (part.getMedia() != null) { + audioCount++; + String url = part.getMedia().getUrl(); + String contentType = part.getMedia().getContentType(); + + if (url.startsWith("data:")) { + // Extract and save base64 data + String base64Data = extractBase64FromDataUrl(url); + // Determine file extension from content type + String extension = "wav"; + if (contentType != null) { + if (contentType.contains("mp3") || contentType.contains("mpeg")) { + extension = "mp3"; + } else if (contentType.contains("ogg")) { + extension = "ogg"; + } else if (contentType.contains("pcm")) { + extension = "pcm"; + } + } + String filename = "audio_" + System.currentTimeMillis() + "_" + audioCount + "." + + extension; + + try { + String savedPath = saveBase64ToFile(base64Data, filename); + result.append("Audio ").append(audioCount).append(" saved to: ").append(savedPath) + .append("\n"); + } catch (IOException e) { + result.append("Audio ").append(audioCount).append(" failed to save: ") + .append(e.getMessage()).append("\n"); + } + } + } + } + + return result.length() > 0 ? result.toString().trim() : "No audio generated"; + } + + return "No audio generated"; + }); + } + + private static void defineVideoGenerationFlow(Genkit genkit) { + genkit.defineFlow("videoGeneration", String.class, String.class, (ctx, prompt) -> { + Map veoOptions = Map.of("numberOfVideos", 1, "durationSeconds", 8, // Valid range: 4-8 + // seconds + "aspectRatio", "16:9", + // Note: "generateAudio" and "enhancePrompt" are only available for Vertex AI + "timeoutMs", 600000 // 10 minutes timeout + ); + + GenerationConfig config = GenerationConfig.builder().custom(veoOptions).build(); + + ModelResponse response = genkit.generate(GenerateOptions.builder().model("googleai/veo-3.0-generate-001") + .prompt(prompt).config(config).build()); + + // Save the generated video + if (response.getMessage() != null && response.getMessage().getContent() != null) { + StringBuilder result = new StringBuilder(); + int videoCount = 0; + + for (Part part : response.getMessage().getContent()) { + if (part.getMedia() != null) { + videoCount++; + String url = part.getMedia().getUrl(); + String contentType = part.getMedia().getContentType(); + + if (url.startsWith("data:")) { + // Extract and save base64 data + String base64Data = extractBase64FromDataUrl(url); + String extension = "mp4"; + if (contentType != null && contentType.contains("webm")) { + extension = "webm"; + } + String filename = "video_" + System.currentTimeMillis() + "_" + videoCount + "." + + extension; + + try { + String savedPath = saveBase64ToFile(base64Data, filename); + result.append("Video ").append(videoCount).append(" saved to: ").append(savedPath) + .append("\n"); + } catch (IOException e) { + result.append("Video ").append(videoCount).append(" failed to save: ") + .append(e.getMessage()).append("\n"); + } + } else if (url.startsWith("gs://")) { + result.append("Video ").append(videoCount).append(" available at GCS: ").append(url) + .append("\n"); + } + } + } + + return result.length() > 0 ? result.toString().trim() : "No videos generated"; + } + + return "No videos generated"; + }); + } +} diff --git a/java/samples/google-genai/src/main/resources/logback.xml b/java/samples/google-genai/src/main/resources/logback.xml new file mode 100644 index 0000000000..31a8091ad0 --- /dev/null +++ b/java/samples/google-genai/src/main/resources/logback.xml @@ -0,0 +1,15 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + diff --git a/java/samples/interrupts/README.md b/java/samples/interrupts/README.md new file mode 100644 index 0000000000..68fc2a3dc8 --- /dev/null +++ b/java/samples/interrupts/README.md @@ -0,0 +1,217 @@ +# Genkit Interrupts Sample + +This sample demonstrates human-in-the-loop patterns using Genkit's interrupt mechanism, where AI operations can pause for user confirmation before executing sensitive actions. + +## Features Demonstrated + +- **Interrupt Pattern** - Pause execution for human confirmation +- **Sensitive Operations** - Money transfers requiring approval +- **Tool Interrupts** - Tools that request user input +- **Resume Flow** - Continue execution after user provides input +- **Session Persistence** - Maintain state across interrupts + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/interrupts + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/interrupts + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +## How Interrupts Work + +``` +┌─────────────────────────────────────────────────────────────┐ +│ User: "Transfer $500 to John" │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ AI decides to use transferMoney tool │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Tool triggers INTERRUPT (sensitive operation) │ +│ ⏸️ Execution paused │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ User sees: "Confirm transfer of $500 to John? [y/n]" │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ User confirms: "y" │ +│ ▶️ Execution resumes with user confirmation │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Transfer executed successfully │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Example Interaction + +``` +🏦 Welcome to AI Banking Assistant +Your current balance: $1000.00 + +You: Transfer $500 to Alice for rent + +⚠️ CONFIRMATION REQUIRED +───────────────────────────── +Action: TRANSFER +Recipient: Alice +Amount: $500.00 +Reason: rent + +Do you approve this action? (yes/no): yes + +✅ Transfer approved and executed! +New balance: $500.00 + +You: Transfer $2000 to Bob + +⚠️ CONFIRMATION REQUIRED +───────────────────────────── +Action: TRANSFER +Recipient: Bob +Amount: $2000.00 + +Do you approve this action? (yes/no): no + +❌ Transfer declined by user. +Your balance remains: $500.00 +``` + +## Key Concepts + +### InterruptConfig + +Configure which operations should trigger interrupts: + +```java +InterruptConfig interruptConfig = InterruptConfig.builder() + .enabled(true) + .tools(List.of("transferMoney", "deleteAccount")) + .build(); +``` + +### Creating an Interruptible Tool + +```java +Tool transferTool = genkit.defineTool( + "transferMoney", + "Transfers money to another account", + schema, + TransferRequest.class, + (ctx, request) -> { + // This tool will trigger an interrupt + ctx.interrupt(InterruptRequest.builder() + .type("CONFIRMATION") + .data(Map.of( + "action", "TRANSFER", + "recipient", request.getRecipient(), + "amount", request.getAmount() + )) + .build()); + + // Code here runs after user confirms + return executeTransfer(request); + }); +``` + +### Handling Interrupts + +```java +try { + ModelResponse response = genkit.generate(options); + // Normal response +} catch (InterruptException e) { + InterruptRequest interrupt = e.getInterruptRequest(); + + // Show confirmation to user + boolean confirmed = promptUser(interrupt); + + if (confirmed) { + // Resume with confirmation + ModelResponse response = genkit.resume( + ResumeOptions.builder() + .interruptId(interrupt.getId()) + .response(new ConfirmationOutput(true, "User approved")) + .build()); + } +} +``` + +## Account State + +The sample simulates a bank account: + +```java +public class AccountState { + private double balance = 1000.00; + private List transactions; + + public void transfer(String recipient, double amount) { + if (amount > balance) { + throw new InsufficientFundsException(); + } + balance -= amount; + transactions.add(new Transaction("TRANSFER", recipient, amount)); + } +} +``` + +## Use Cases for Interrupts + +1. **Financial Transactions** - Require approval for transfers over a threshold +2. **Data Deletion** - Confirm before deleting important data +3. **External Actions** - Approve sending emails, making API calls +4. **Access Control** - Verify identity before sensitive operations +5. **Multi-Step Workflows** - Checkpoint approval in long processes + +## Development UI + +When running with `genkit start`, access the Dev UI at http://localhost:4000 to: + +- Test interruptible flows +- View interrupt requests in traces +- Manually approve/reject interrupts +- Inspect state before and after interrupts + +## See Also + +- [Genkit Java README](../../README.md) +- [Chat Sessions Sample](../chat-session/README.md) +- [Multi-Agent Sample](../multi-agent/README.md) diff --git a/java/samples/interrupts/pom.xml b/java/samples/interrupts/pom.xml new file mode 100644 index 0000000000..1091e183d0 --- /dev/null +++ b/java/samples/interrupts/pom.xml @@ -0,0 +1,80 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-interrupts + jar + Genkit Interrupts Sample + Sample application demonstrating human-in-the-loop patterns with interrupts + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.4.14 + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.0 + + com.google.genkit.samples.InterruptsApp + + + + + diff --git a/java/samples/interrupts/run.sh b/java/samples/interrupts/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/interrupts/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/interrupts/src/main/java/com/google/genkit/samples/InterruptsApp.java b/java/samples/interrupts/src/main/java/com/google/genkit/samples/InterruptsApp.java new file mode 100644 index 0000000000..dda09ed8ac --- /dev/null +++ b/java/samples/interrupts/src/main/java/com/google/genkit/samples/InterruptsApp.java @@ -0,0 +1,584 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Scanner; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.InterruptConfig; +import com.google.genkit.ai.InterruptRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Part; +import com.google.genkit.ai.ResumeOptions; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.ai.session.Chat; +import com.google.genkit.ai.session.ChatOptions; +import com.google.genkit.ai.session.InMemorySessionStore; +import com.google.genkit.ai.session.Session; +import com.google.genkit.ai.session.SessionOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import com.google.genkit.prompt.ExecutablePrompt; + +/** + * Human-in-the-Loop Application using Interrupts. + * + *

+ * This sample demonstrates the interrupt pattern for human-in-the-loop + * scenarios: + *

    + *
  • Tools that pause execution to request user confirmation
  • + *
  • Handling interrupt requests and resuming with user input
  • + *
  • Sensitive operations that require explicit approval
  • + *
+ * + *

+ * To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java -pl samples/interrupts
  4. + *
+ */ +public class InterruptsApp { + + /** Confirmation input structure. */ + public static class ConfirmationInput { + private String action; + private String details; + private double amount; + + public ConfirmationInput() { + } + + public String getAction() { + return action; + } + public void setAction(String action) { + this.action = action; + } + public String getDetails() { + return details; + } + public void setDetails(String details) { + this.details = details; + } + public double getAmount() { + return amount; + } + public void setAmount(double amount) { + this.amount = amount; + } + } + + /** Transfer request for the interrupt tool input. */ + public static class TransferRequest { + private String recipient; + private double amount; + private String reason; + + public TransferRequest() { + } + + public String getRecipient() { + return recipient; + } + public void setRecipient(String recipient) { + this.recipient = recipient; + } + public double getAmount() { + return amount; + } + public void setAmount(double amount) { + this.amount = amount; + } + public String getReason() { + return reason; + } + public void setReason(String reason) { + this.reason = reason; + } + } + + /** Confirmation output structure. */ + public static class ConfirmationOutput { + private boolean confirmed; + private String reason; + + public ConfirmationOutput() { + } + public ConfirmationOutput(boolean confirmed, String reason) { + this.confirmed = confirmed; + this.reason = reason; + } + + public boolean isConfirmed() { + return confirmed; + } + public void setConfirmed(boolean confirmed) { + this.confirmed = confirmed; + } + public String getReason() { + return reason; + } + public void setReason(String reason) { + this.reason = reason; + } + } + + /** Bank account state. */ + public static class AccountState { + private String accountId; + private double balance; + private List transactions = new ArrayList<>(); + + public AccountState() { + this.accountId = "ACC-" + System.currentTimeMillis() % 10000; + this.balance = 5000.00; // Starting balance + } + + public String getAccountId() { + return accountId; + } + public double getBalance() { + return balance; + } + + public void addTransaction(String transaction, double amount) { + this.balance += amount; + this.transactions.add(transaction); + } + + public List getTransactions() { + return transactions; + } + + @Override + public String toString() { + return String.format("Account: %s, Balance: $%.2f, Transactions: %d", accountId, balance, + transactions.size()); + } + } + + /** Banking request input for the prompt. */ + public static class BankingInput { + private String request; + + public BankingInput() { + } + + public BankingInput(String request) { + this.request = request; + } + + public String getRequest() { + return request; + } + + public void setRequest(String request) { + this.request = request; + } + } + + private final Genkit genkit; + private final InMemorySessionStore sessionStore; + private final Scanner scanner; + + // Tools + private Tool getBalanceTool; + private Tool transferMoneyTool; + private Tool confirmTransferTool; + + public InterruptsApp() { + this.genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3102).build()) + .plugin(OpenAIPlugin.create()).build(); + + this.sessionStore = new InMemorySessionStore<>(); + this.scanner = new Scanner(System.in); + + initializeTools(); + } + + @SuppressWarnings("unchecked") + private void initializeTools() { + // Get Balance Tool - no confirmation needed + getBalanceTool = genkit.defineTool("getBalance", "Gets the current account balance", + Map.of("type", "object", "properties", Map.of()), (Class>) (Class) Map.class, + (ctx, input) -> { + // In a real app, we'd get this from session context + return Map.of("balance", 5000.00, "currency", "USD"); + }); + + // Use defineInterrupt to create an interrupt tool that pauses for confirmation. + // This is the preferred way to create interrupt tools - it automatically + // handles + // throwing ToolInterruptException with the proper metadata. + confirmTransferTool = genkit + .defineInterrupt(InterruptConfig.builder().name("confirmTransfer") + .description("Request user confirmation before executing a money transfer. " + + "ALWAYS use this tool before transferring money.") + .inputType(TransferRequest.class).outputType(ConfirmationOutput.class) + .inputSchema(Map.of("type", "object", "properties", + Map.of("recipient", Map.of("type", "string", "description", "Who to transfer to"), + "amount", Map.of("type", "number", "description", "Amount to transfer"), + "reason", Map.of("type", "string", "description", "Reason for transfer")), + "required", new String[]{"recipient", "amount"})) + // requestMetadata extracts info from input for the interrupt request + .requestMetadata(input -> Map.of("type", "transfer_confirmation", "recipient", + input.getRecipient(), "amount", input.getAmount(), "reason", + input.getReason() != null ? input.getReason() : "")) + .build()); + + // Transfer Money Tool - executes after confirmation + transferMoneyTool = genkit.defineTool("executeTransfer", + "Executes a confirmed money transfer. Only call this after confirmation.", + Map.of("type", "object", "properties", + Map.of("recipient", Map.of("type", "string", "description", "Transfer recipient"), "amount", + Map.of("type", "number", "description", "Amount to transfer"), "confirmationCode", + Map.of("type", "string", "description", "Confirmation code from user")), + "required", new String[]{"recipient", "amount", "confirmationCode"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String recipient = (String) input.get("recipient"); + double amount = ((Number) input.get("amount")).doubleValue(); + String transactionId = "TXN-" + System.currentTimeMillis() % 100000; + + return Map.of("status", "success", "transactionId", transactionId, "recipient", recipient, "amount", + amount, "message", String.format("Successfully transferred $%.2f to %s. Transaction ID: %s", + amount, recipient, transactionId)); + }); + } + + /** Creates a chat session. */ + @SuppressWarnings("unchecked") + public Chat createChat() { + Session session = genkit.createSession( + SessionOptions.builder().store(sessionStore).initialState(new AccountState()).build()); + + String systemPrompt = "You are a helpful banking assistant for SecureBank. " + + "You can help customers check their balance and transfer money. " + + "IMPORTANT: For any money transfer, you MUST first use the confirmTransfer tool " + + "to get user confirmation. Never execute a transfer without confirmation. " + + "After the user confirms, use the executeTransfer tool with their confirmation code."; + + return session.chat(ChatOptions.builder().model("openai/gpt-4o-mini").system(systemPrompt) + .tools(List.of(getBalanceTool, confirmTransferTool, transferMoneyTool)).build()); + } + + /** Handles an interrupt by prompting the user. */ + private ConfirmationOutput handleInterrupt(InterruptRequest interrupt) { + Map metadata = interrupt.getMetadata(); + + System.out.println("\n╔═══════════════════════════════════════════════════════════╗"); + System.out.println("║ ⚠️ CONFIRMATION REQUIRED ⚠️ ║"); + System.out.println("╠═══════════════════════════════════════════════════════════╣"); + System.out.printf("║ Transfer: $%.2f to %s%n", metadata.get("amount"), metadata.get("recipient")); + if (metadata.get("reason") != null) { + System.out.printf("║ Reason: %s%n", metadata.get("reason")); + } + System.out.println("╠═══════════════════════════════════════════════════════════╣"); + System.out.println("║ Type 'yes' to confirm or 'no' to cancel ║"); + System.out.println("╚═══════════════════════════════════════════════════════════╝"); + System.out.print("\nYour decision: "); + + String response = scanner.nextLine().trim().toLowerCase(); + boolean confirmed = response.equals("yes") || response.equals("y"); + + if (confirmed) { + System.out.println("✓ Transfer confirmed"); + return new ConfirmationOutput(true, "User confirmed with code: CONF-" + System.currentTimeMillis() % 10000); + } else { + System.out.println("✗ Transfer cancelled"); + return new ConfirmationOutput(false, "User declined the transfer"); + } + } + + /** Sends a message and handles any interrupts. */ + public String sendWithInterruptHandling(Chat chat, String message) { + try { + ModelResponse response = chat.send(message); + + // Check for pending interrupts + if (chat.hasPendingInterrupts()) { + List interrupts = chat.getPendingInterrupts(); + + for (InterruptRequest interrupt : interrupts) { + // Handle the interrupt (get user confirmation) + ConfirmationOutput userResponse = handleInterrupt(interrupt); + + // Create resume options with the user's response + ToolResponse toolResponse = interrupt.respond(userResponse); + ResumeOptions resume = ResumeOptions.builder().respond(List.of(toolResponse)).build(); + + // Resume the conversation with the user's response + response = chat.send( + userResponse.isConfirmed() + ? "User confirmed. Proceed with the transfer." + : "User declined. Cancel the transfer.", + Chat.SendOptions.builder().resumeOptions(resume).build()); + } + } + + return response.getText(); + } catch (Exception e) { + return "Error: " + e.getMessage(); + } + } + + /** Interactive chat loop. */ + public void runInteractive() { + System.out.println("╔════════════════════════════════════════════════════════════════╗"); + System.out.println("║ SecureBank - Human-in-the-Loop Banking Assistant ║"); + System.out.println("╚════════════════════════════════════════════════════════════════╝"); + System.out.println(); + System.out.println("This demo shows the interrupt pattern for sensitive operations."); + System.out.println("Money transfers require explicit confirmation before execution."); + System.out.println(); + System.out.println("Try saying:"); + System.out.println(" • 'What's my balance?'"); + System.out.println(" • 'Transfer $100 to John for lunch'"); + System.out.println(" • 'Send $500 to Alice'"); + System.out.println(); + System.out.println("Commands: /status, /quit\n"); + + Chat chat = createChat(); + + while (true) { + System.out.print("You: "); + String input = scanner.nextLine().trim(); + + if (input.isEmpty()) + continue; + + if (input.equals("/quit") || input.equals("/exit")) { + System.out.println("\nThank you for banking with SecureBank!"); + break; + } + + if (input.equals("/status")) { + System.out.println("\n" + chat.getSession().getState() + "\n"); + continue; + } + + String response = sendWithInterruptHandling(chat, input); + System.out.println("\nAssistant: " + response + "\n"); + } + } + + /** Demo mode. */ + public void runDemo() { + System.out.println("╔════════════════════════════════════════════════════════════════╗"); + System.out.println("║ Interrupts Demo - Human-in-the-Loop Pattern ║"); + System.out.println("╚════════════════════════════════════════════════════════════════╝"); + System.out.println(); + System.out.println("This demo shows how interrupts work for human-in-the-loop scenarios."); + System.out.println("Watch how the system pauses for confirmation on sensitive operations.\n"); + + Chat chat = createChat(); + + // Demo 1: Check balance (no interrupt) + System.out.println("=== Demo 1: Simple Query (No Interrupt) ===\n"); + System.out.println("Customer: What's my current balance?"); + String response1 = sendWithInterruptHandling(chat, "What's my current balance?"); + System.out.println("Assistant: " + response1 + "\n"); + + // Demo 2: Transfer money (triggers interrupt) + System.out.println("\n=== Demo 2: Transfer Request (Triggers Interrupt) ===\n"); + System.out.println("Customer: Transfer $250 to John Smith for the concert tickets"); + System.out.println("\n[The system will now request confirmation...]\n"); + + // For demo, we'll use a mock confirmation + String response2 = sendWithInterruptHandling(chat, "Transfer $250 to John Smith for the concert tickets"); + System.out.println("\nAssistant: " + response2); + + System.out.println("\n=== Demo Complete ==="); + System.out.println("Final state: " + chat.getSession().getState()); + } + + /** + * Demo using generate() directly with interrupts (without Chat). + * + *

+ * This shows how to use interrupts at the lower level generate() API, which is + * useful when you don't need session management. + */ + public void runGenerateDemo() { + System.out.println("╔════════════════════════════════════════════════════════════════╗"); + System.out.println("║ Interrupts with generate() - Low-Level API Demo ║"); + System.out.println("╚════════════════════════════════════════════════════════════════╝"); + System.out.println(); + System.out.println("This demo shows how to use interrupts with the generate() method."); + System.out.println("This is useful when you don't need Chat's session management.\n"); + + String model = "openai/gpt-4o-mini"; + + // Create a simple confirm transfer interrupt tool + @SuppressWarnings("unchecked") + Tool confirmTool = (Tool) confirmTransferTool; + + // Initial request - transfer money + System.out.println("=== Step 1: Initial Generate Request ===\n"); + System.out.println("Prompt: Transfer $150 to Alice for dinner\n"); + + ModelResponse response = genkit + .generate(GenerateOptions.builder().model(model).prompt("Transfer $150 to Alice for dinner") + .system("You are a banking assistant. Use the confirmTransfer tool for any transfers.") + .tools(List.of(confirmTransferTool)).build()); + + System.out.println("Response finish reason: " + response.getFinishReason()); + + // Check if we got an interrupt + if (response.isInterrupted()) { + System.out.println("✓ Generation was interrupted!"); + System.out.println(" Number of interrupts: " + response.getInterrupts().size()); + + Part interrupt = response.getInterrupts().get(0); + Map metadata = interrupt.getMetadata(); + System.out.println(" Interrupt metadata: " + metadata); + + // Get user confirmation + System.out.println("\n=== Step 2: Get User Confirmation ===\n"); + System.out.print("Confirm transfer of $150 to Alice? (yes/no): "); + String userInput = scanner.nextLine().trim().toLowerCase(); + boolean confirmed = userInput.equals("yes") || userInput.equals("y"); + + // Create the response to the interrupt + ConfirmationOutput userResponse = new ConfirmationOutput(confirmed, + confirmed ? "User approved" : "User declined"); + + // Use the tool's respond helper + Part responseData = confirmTool.respond(interrupt, userResponse); + + System.out.println("\n=== Step 3: Resume Generation ===\n"); + System.out.println("Resuming with user " + (confirmed ? "confirmation" : "rejection") + "...\n"); + + // Resume generation with the user's response + ModelResponse resumedResponse = genkit + .generate(GenerateOptions.builder().model(model).messages(response.getMessages()) // Include + // previous + // context + .tools(List.of(confirmTransferTool)) + .resume(ResumeOptions.builder().respond(responseData.getToolResponse()).build()).build()); + + System.out.println("Final response: " + resumedResponse.getText()); + System.out.println("Finish reason: " + resumedResponse.getFinishReason()); + } else { + System.out.println("Response (no interrupt): " + response.getText()); + } + + System.out.println("\n=== Generate Demo Complete ==="); + } + + /** + * Demo using ExecutablePrompt with interrupts. + * + *

+ * This shows how to use interrupts with the prompt() API, which allows you to + * load and execute .prompt files with tool and interrupt support. + */ + public void runPromptDemo() { + System.out.println("╔════════════════════════════════════════════════════════════════╗"); + System.out.println("║ Interrupts with ExecutablePrompt - Prompt API Demo ║"); + System.out.println("╚════════════════════════════════════════════════════════════════╝"); + System.out.println(); + System.out.println("This demo shows how to use interrupts with ExecutablePrompt."); + System.out.println("It loads a .prompt file and adds tools with interrupt support.\n"); + + // Load the prompt + ExecutablePrompt bankingPrompt = genkit.prompt("banking-assistant", BankingInput.class); + + // Create a simple confirm transfer interrupt tool + @SuppressWarnings("unchecked") + Tool confirmTool = (Tool) confirmTransferTool; + + // Initial request - transfer money + System.out.println("=== Step 1: Execute Prompt with Tools ===\n"); + System.out.println("Using prompt: banking-assistant.prompt"); + System.out.println("Input: Transfer $200 to Bob for concert tickets\n"); + + BankingInput input = new BankingInput("Transfer $200 to Bob for concert tickets"); + + // Generate with tools - the prompt will use Genkit.generate() internally + // which supports interrupts + ModelResponse response = bankingPrompt.generate(input, + GenerateOptions.builder().tools(List.of(confirmTransferTool)).build()); + + System.out.println("Response finish reason: " + response.getFinishReason()); + + // Check if we got an interrupt + if (response.isInterrupted()) { + System.out.println("✓ Prompt execution was interrupted!"); + System.out.println(" Number of interrupts: " + response.getInterrupts().size()); + + Part interrupt = response.getInterrupts().get(0); + Map metadata = interrupt.getMetadata(); + System.out.println(" Interrupt metadata: " + metadata); + + // Get user confirmation + System.out.println("\n=== Step 2: Get User Confirmation ===\n"); + System.out.print("Confirm transfer of $200 to Bob? (yes/no): "); + String userInput = scanner.nextLine().trim().toLowerCase(); + boolean confirmed = userInput.equals("yes") || userInput.equals("y"); + + // Create the response to the interrupt + ConfirmationOutput userResponse = new ConfirmationOutput(confirmed, + confirmed ? "User approved" : "User declined"); + + // Use the tool's respond helper + Part responseData = confirmTool.respond(interrupt, userResponse); + + System.out.println("\n=== Step 3: Resume Prompt Execution ===\n"); + System.out.println("Resuming with user " + (confirmed ? "confirmation" : "rejection") + "...\n"); + + // Resume generation with the user's response + // Note: For full resume, you would use genkit.generate() with the messages + ModelResponse resumedResponse = genkit.generate(GenerateOptions.builder().model(bankingPrompt.getModel()) + .messages(response.getMessages()).tools(List.of(confirmTransferTool)) + .resume(ResumeOptions.builder().respond(responseData.getToolResponse()).build()).build()); + + System.out.println("Final response: " + resumedResponse.getText()); + System.out.println("Finish reason: " + resumedResponse.getFinishReason()); + } else { + System.out.println("Response (no interrupt): " + response.getText()); + } + + System.out.println("\n=== Prompt Demo Complete ==="); + } + + public static void main(String[] args) { + InterruptsApp app = new InterruptsApp(); + + boolean demoMode = args.length > 0 && args[0].equals("--demo"); + boolean generateDemo = args.length > 0 && args[0].equals("--generate"); + boolean promptDemo = args.length > 0 && args[0].equals("--prompt"); + + if (promptDemo) { + app.runPromptDemo(); + } else if (generateDemo) { + app.runGenerateDemo(); + } else if (demoMode) { + app.runDemo(); + } else { + app.runInteractive(); + } + } +} diff --git a/java/samples/interrupts/src/main/resources/logback.xml b/java/samples/interrupts/src/main/resources/logback.xml new file mode 100644 index 0000000000..d63c14f8a8 --- /dev/null +++ b/java/samples/interrupts/src/main/resources/logback.xml @@ -0,0 +1,13 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + diff --git a/java/samples/interrupts/src/main/resources/prompts/banking-assistant.prompt b/java/samples/interrupts/src/main/resources/prompts/banking-assistant.prompt new file mode 100644 index 0000000000..a2ea8507b2 --- /dev/null +++ b/java/samples/interrupts/src/main/resources/prompts/banking-assistant.prompt @@ -0,0 +1,20 @@ +--- +model: openai/gpt-4o-mini +input: + schema: + type: object + properties: + request: + type: string + description: The user's banking request + required: + - request +--- +You are a helpful banking assistant for SecureBank. +You can help customers check their balance and transfer money. + +IMPORTANT: For any money transfer, you MUST first use the confirmTransfer tool +to get user confirmation. Never execute a transfer without confirmation. +After the user confirms, use the executeTransfer tool with their confirmation code. + +User request: {{request}} diff --git a/java/samples/mcp/README.md b/java/samples/mcp/README.md new file mode 100644 index 0000000000..e3d162e351 --- /dev/null +++ b/java/samples/mcp/README.md @@ -0,0 +1,258 @@ +# Genkit MCP Sample + +This sample demonstrates how to use the Genkit MCP (Model Context Protocol) plugin to integrate with MCP servers. + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- Node.js and npm (required for running MCP servers via `npx`) +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/mcp + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/mcp + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +The sample will start with: +- HTTP server on port 8080 +- Reflection server on port 3100 (for Genkit Dev UI) +- MCP connections to filesystem and everything servers + +## Available Flows + +### listMcpTools +Lists all available MCP tools from connected servers. +```bash +curl -X POST http://localhost:8080/listMcpTools \ + -H 'Content-Type: application/json' -d 'null' +``` + +### fileAssistant +AI-powered file operations assistant that can read, write, and list files. +```bash +curl -X POST http://localhost:8080/fileAssistant \ + -H 'Content-Type: application/json' \ + -d '"List all files in the temp directory"' +``` + +### readFile +Directly read a file using the MCP filesystem tool. +```bash +curl -X POST http://localhost:8080/readFile \ + -H 'Content-Type: application/json' \ + -d '"/tmp/test.txt"' +``` + +### listResources +List resources from a specific MCP server. +```bash +curl -X POST http://localhost:8080/listResources \ + -H 'Content-Type: application/json' \ + -d '"filesystem"' +``` + +### toolExplorer +AI assistant with access to all MCP tools. +```bash +curl -X POST http://localhost:8080/toolExplorer \ + -H 'Content-Type: application/json' \ + -d '"Generate a random UUID"' +``` + +### mcpStatus +Get the status of all connected MCP servers. +```bash +curl -X POST http://localhost:8080/mcpStatus \ + -H 'Content-Type: application/json' -d 'null' +``` + +### writeReadDemo +Demo that writes content to a file and reads it back. +```bash +curl -X POST http://localhost:8080/writeReadDemo \ + -H 'Content-Type: application/json' \ + -d '"Hello from Genkit MCP!"' +``` + +## Configuration + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `OPENAI_API_KEY` | OpenAI API key for model access | (required) | +| `MCP_ALLOWED_DIR` | Directory the filesystem server can access | `/tmp` | + +### MCP Servers Used + +This sample connects to two MCP servers: + +1. **filesystem** (`@modelcontextprotocol/server-filesystem`) + - Provides file read/write/list operations + - Limited to the `MCP_ALLOWED_DIR` directory for security + +2. **everything** (`@modelcontextprotocol/server-everything`) + - Demo server with various tool types (echo, random, etc.) + - Useful for testing MCP integration + +## MCP Server Sample + +This sample also includes `MCPServerSample`, which demonstrates how to **expose** Genkit tools as an MCP server. + +### Running the MCP Server + +```bash +# Run directly +mvn exec:java -Dexec.mainClass="com.google.genkit.samples.MCPServerSample" + +# Or build and run the JAR +mvn package -DskipTests +java -jar target/genkit-mcp-sample-1.0.0-SNAPSHOT-server.jar +``` + +### Using with Claude Desktop + +Add to your `~/.config/claude/claude_desktop_config.json` (macOS) or similar on other platforms: + +```json +{ + "mcpServers": { + "genkit-tools": { + "command": "java", + "args": ["-jar", "/absolute/path/to/genkit-mcp-sample-1.0.0-SNAPSHOT-server.jar"] + } + } +} +``` + +### Available Tools + +The MCP server exposes these demonstration tools: + +| Tool | Description | +|------|-------------| +| `calculator` | Basic math operations (add, subtract, multiply, divide) | +| `get_weather` | Mock weather data for any location | +| `get_datetime` | Current date/time in various formats | +| `greet` | Personalized greeting generator | +| `translate_mock` | Mock translation tool | + +## Code Examples + +### Adding Your Own MCP Server + +```java +MCPPluginOptions mcpOptions = MCPPluginOptions.builder() + .name("my-app") + .addServer("filesystem", MCPServerConfig.stdio( + "npx", "-y", "@modelcontextprotocol/server-filesystem", "/path/to/files")) + .addServer("github", MCPServerConfig.builder() + .command("npx") + .args("-y", "@modelcontextprotocol/server-github") + .env("GITHUB_TOKEN", System.getenv("GITHUB_TOKEN")) + .build()) + .addServer("remote", MCPServerConfig.http("http://mcp-server.example.com:3001/mcp")) + .build(); + +MCPPlugin mcpPlugin = MCPPlugin.create(mcpOptions); +``` + +### Creating Your Own MCP Server + +```java +// Define tools with Genkit +Genkit genkit = Genkit.builder().build(); + +genkit.defineTool("my_tool", "My custom tool", + Map.of("type", "object", "properties", Map.of( + "input", Map.of("type", "string") + )), + (Class>) (Class) Map.class, + (ctx, input) -> { + return Map.of("result", "processed: " + input.get("input")); + }); + +genkit.init(); + +// Create and start MCP server +MCPServer mcpServer = new MCPServer(genkit.getRegistry(), + MCPServerOptions.builder() + .name("my-server") + .version("1.0.0") + .build()); + +mcpServer.start(); // Uses STDIO transport + (ctx, input) -> { + List> tools = mcpPlugin.getTools(); + + ModelResponse response = genkit.generate(GenerateOptions.builder() + .model("openai/gpt-4o") + .prompt(input) + .tools(tools) + .build()); + + return response.getText(); + }); +``` + +### Direct Tool Calls + +```java +// Write a file +mcpPlugin.callTool("filesystem", "write_file", + Map.of("path", "/tmp/hello.txt", "content", "Hello World!")); + +// Read it back +Object content = mcpPlugin.callTool("filesystem", "read_file", + Map.of("path", "/tmp/hello.txt")); +``` + +## Troubleshooting + +### "Command not found: npx" +Ensure Node.js and npm are installed and in your PATH. + +### Connection timeouts +Check that the MCP server package can be downloaded via npm. You may need to configure proxy settings. + +### Permission denied errors +Make sure the `MCP_ALLOWED_DIR` directory exists and is writable. + +### Tools not appearing +Check the logs for MCP connection errors. Increase log level for more details: +```xml + +``` + +## Learn More + +- [MCP Plugin Documentation](../../plugins/mcp/README.md) +- [Model Context Protocol](https://modelcontextprotocol.io/) +- [Available MCP Servers](https://github.com/modelcontextprotocol/servers) diff --git a/java/samples/mcp/dependency-reduced-pom.xml b/java/samples/mcp/dependency-reduced-pom.xml new file mode 100644 index 0000000000..2636ee1172 --- /dev/null +++ b/java/samples/mcp/dependency-reduced-pom.xml @@ -0,0 +1,74 @@ + + + + genkit-parent + com.google.genkit + 1.0.0-SNAPSHOT + ../../pom.xml + + 4.0.0 + com.google.genkit.samples + genkit-sample-mcp + Genkit MCP Sample + Sample application demonstrating Genkit with MCP (Model Context Protocol) + + + + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + + mcp-client + + com.google.genkit.samples.MCPSample + + + + mcp-server + + com.google.genkit.samples.MCPServerSample + + + + + com.google.genkit.samples.MCPSample + + + + maven-shade-plugin + 3.5.1 + + + package + + shade + + + + + com.google.genkit.samples.MCPServerSample + + + genkit-mcp-server + + + + + + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + diff --git a/java/samples/mcp/pom.xml b/java/samples/mcp/pom.xml new file mode 100644 index 0000000000..c12fffd50e --- /dev/null +++ b/java/samples/mcp/pom.xml @@ -0,0 +1,141 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-mcp + jar + Genkit MCP Sample + Sample application demonstrating Genkit with MCP (Model Context Protocol) + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + + com.google.genkit + genkit + ${genkit.version} + + + + + com.google.genkit + genkit-plugin-mcp + ${genkit.version} + + + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.MCPSample + + + + + mcp-client + + com.google.genkit.samples.MCPSample + + + + + mcp-server + + com.google.genkit.samples.MCPServerSample + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.1 + + + package + + shade + + + + + com.google.genkit.samples.MCPServerSample + + + genkit-mcp-server + + + + + + + diff --git a/java/samples/mcp/run.sh b/java/samples/mcp/run.sh new file mode 100755 index 0000000000..17ffb09f35 --- /dev/null +++ b/java/samples/mcp/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Run script for Genkit MCP Sample + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +cd "$(dirname "$0")" + +# Check for OPENAI_API_KEY +if [ -z "$OPENAI_API_KEY" ]; then + echo "Warning: OPENAI_API_KEY environment variable is not set." + echo "Some features may not work without it." +fi + +# Set MCP_ALLOWED_DIR if not already set +if [ -z "$MCP_ALLOWED_DIR" ]; then + export MCP_ALLOWED_DIR="/tmp" + echo "MCP_ALLOWED_DIR not set, using default: $MCP_ALLOWED_DIR" +fi + +mvn exec:java diff --git a/java/samples/mcp/src/main/java/com/google/genkit/samples/MCPSample.java b/java/samples/mcp/src/main/java/com/google/genkit/samples/MCPSample.java new file mode 100644 index 0000000000..f82b137b8d --- /dev/null +++ b/java/samples/mcp/src/main/java/com/google/genkit/samples/MCPSample.java @@ -0,0 +1,329 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Tool; +import com.google.genkit.core.Flow; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.mcp.MCPClient; +import com.google.genkit.plugins.mcp.MCPPlugin; +import com.google.genkit.plugins.mcp.MCPPluginOptions; +import com.google.genkit.plugins.mcp.MCPResource; +import com.google.genkit.plugins.mcp.MCPServerConfig; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating Genkit with MCP (Model Context Protocol). + * + *

+ * This example shows how to: + *

    + *
  • Configure the MCP plugin with different server types
  • + *
  • Connect to MCP servers via STDIO (local processes) and HTTP
  • + *
  • Use MCP tools in Genkit flows
  • + *
  • Access MCP resources
  • + *
  • Combine MCP tools with AI models for powerful workflows
  • + *
+ * + *

+ * Prerequisites: + *

    + *
  • Node.js and npm installed (for MCP server packages)
  • + *
  • OPENAI_API_KEY environment variable set
  • + *
+ * + *

+ * To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java
  4. + *
+ * + *

+ * Available MCP servers in this sample: + *

    + *
  • filesystem: Access files + * using @modelcontextprotocol/server-filesystem
  • + *
  • everything: Demo server with various tool types
  • + *
+ */ +public class MCPSample { + + private static final Logger logger = LoggerFactory.getLogger(MCPSample.class); + + public static void main(String[] args) throws Exception { + logger.info("Starting Genkit MCP Sample..."); + + // ======================================================= + // Configure MCP Plugin with multiple servers + // ======================================================= + + // Get the allowed directory for the filesystem server + // Default to /tmp or use MCP_ALLOWED_DIR environment variable + String tempAllowedDir = System.getenv("MCP_ALLOWED_DIR"); + if (tempAllowedDir == null || tempAllowedDir.isEmpty()) { + tempAllowedDir = System.getProperty("java.io.tmpdir"); + } + final String allowedDir = tempAllowedDir; + logger.info("Filesystem server will have access to: {}", allowedDir); + + MCPPluginOptions mcpOptions = MCPPluginOptions.builder().name("genkit-mcp-sample") + // Filesystem server - allows file operations in allowed directory + .addServer("filesystem", + MCPServerConfig.stdio("npx", "-y", "@modelcontextprotocol/server-filesystem", allowedDir)) + // Everything server - demo server with various tools + .addServer("everything", MCPServerConfig.stdio("npx", "-y", "@modelcontextprotocol/server-everything")) + .build(); + + MCPPlugin mcpPlugin = MCPPlugin.create(mcpOptions); + + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // ======================================================= + // Create Genkit with plugins + // ======================================================= + + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).plugin(mcpPlugin).plugin(jetty).build(); + + // ======================================================= + // Example 1: List available MCP tools + // ======================================================= + + Flow listToolsFlow = genkit.defineFlow("listMcpTools", Void.class, String.class, + (ctx, input) -> { + StringBuilder sb = new StringBuilder(); + sb.append("=== Available MCP Tools ===\n\n"); + + List> tools = mcpPlugin.getTools(); + for (Tool tool : tools) { + sb.append("- ").append(tool.getName()).append("\n"); + sb.append(" Description: ").append(tool.getDescription()).append("\n\n"); + } + + return sb.toString(); + }); + + // ======================================================= + // Example 2: Use MCP filesystem tools with AI + // ======================================================= + + Flow fileAssistantFlow = genkit.defineFlow("fileAssistant", String.class, String.class, + (ctx, userRequest) -> { + logger.info("File assistant processing request: {}", userRequest); + + // Get all MCP tools + List> mcpTools = mcpPlugin.getTools(); + + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful file assistant. You can read, write, and list files " + + "using the available filesystem tools. The filesystem tools use the server name " + + "'filesystem' as prefix (e.g., 'filesystem/read_file'). " + + "Always explain what you're doing and show file contents when relevant.") + .prompt(userRequest).tools(mcpTools) + .config(GenerationConfig.builder().temperature(0.7).maxOutputTokens(1000).build()).build()); + + return response.getText(); + }); + + // ======================================================= + // Example 3: Direct MCP tool usage (without AI) + // ======================================================= + + Flow readFileFlow = genkit.defineFlow("readFile", String.class, String.class, + (ctx, filePath) -> { + logger.info("Reading file via MCP: {}", filePath); + + try { + Object result = mcpPlugin.callTool("filesystem", "read_file", Map.of("path", filePath)); + return result != null ? result.toString() : "File is empty"; + } catch (Exception e) { + return "Error reading file: " + e.getMessage(); + } + }); + + // ======================================================= + // Example 4: List MCP resources + // ======================================================= + + Flow listResourcesFlow = genkit.defineFlow("listResources", String.class, String.class, + (ctx, serverName) -> { + StringBuilder sb = new StringBuilder(); + sb.append("=== Resources from ").append(serverName).append(" ===\n\n"); + + try { + List resources = mcpPlugin.getResources(serverName); + if (resources.isEmpty()) { + sb.append("No resources available.\n"); + } else { + for (MCPResource resource : resources) { + sb.append("- URI: ").append(resource.getUri()).append("\n"); + sb.append(" Name: ").append(resource.getName()).append("\n"); + if (resource.getDescription() != null && !resource.getDescription().isEmpty()) { + sb.append(" Description: ").append(resource.getDescription()).append("\n"); + } + sb.append("\n"); + } + } + } catch (Exception e) { + sb.append("Error listing resources: ").append(e.getMessage()).append("\n"); + } + + return sb.toString(); + }); + + // ======================================================= + // Example 5: AI-powered tool exploration with 'everything' server + // ======================================================= + + Flow toolExplorerFlow = genkit.defineFlow("toolExplorer", String.class, String.class, + (ctx, query) -> { + logger.info("Tool explorer processing: {}", query); + + List> mcpTools = mcpPlugin.getTools(); + + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful assistant that can use various tools. " + + "You have access to tools from multiple MCP servers including 'filesystem' and 'everything'. " + + "Use the appropriate tools to help the user with their request. " + + "Explain what tools you're using and why.") + .prompt(query).tools(mcpTools) + .config(GenerationConfig.builder().temperature(0.7).maxOutputTokens(1000).build()).build()); + + return response.getText(); + }); + + // ======================================================= + // Example 6: Get MCP server status + // ======================================================= + + Flow serverStatusFlow = genkit.defineFlow("mcpStatus", Void.class, String.class, + (ctx, input) -> { + StringBuilder sb = new StringBuilder(); + sb.append("=== MCP Server Status ===\n\n"); + + Map clients = mcpPlugin.getClients(); + for (Map.Entry entry : clients.entrySet()) { + String serverName = entry.getKey(); + MCPClient client = entry.getValue(); + + sb.append("Server: ").append(serverName).append("\n"); + sb.append(" Connected: ").append(client.isConnected()).append("\n"); + + if (client.isConnected()) { + try { + List> tools = mcpPlugin.getTools(serverName); + sb.append(" Tools: ").append(tools.size()).append("\n"); + } catch (Exception e) { + sb.append(" Tools: Error - ").append(e.getMessage()).append("\n"); + } + + try { + List resources = mcpPlugin.getResources(serverName); + sb.append(" Resources: ").append(resources.size()).append("\n"); + } catch (Exception e) { + sb.append(" Resources: Error - ").append(e.getMessage()).append("\n"); + } + } + sb.append("\n"); + } + + return sb.toString(); + }); + + // ======================================================= + // Example 7: Write and read file demo + // ======================================================= + + Flow writeReadDemoFlow = genkit.defineFlow("writeReadDemo", String.class, String.class, + (ctx, content) -> { + String testFile = allowedDir + "/genkit-mcp-test.txt"; + StringBuilder sb = new StringBuilder(); + + try { + // Write content + sb.append("Writing to file: ").append(testFile).append("\n"); + mcpPlugin.callTool("filesystem", "write_file", Map.of("path", testFile, "content", content)); + sb.append("Write successful!\n\n"); + + // Read it back + sb.append("Reading file back:\n"); + Object readResult = mcpPlugin.callTool("filesystem", "read_file", Map.of("path", testFile)); + sb.append(readResult != null ? readResult.toString() : "(empty)"); + + } catch (Exception e) { + sb.append("Error: ").append(e.getMessage()); + } + + return sb.toString(); + }); + + // ======================================================= + // Print usage information + // ======================================================= + + logger.info("\n========================================"); + logger.info("Genkit MCP Sample Started!"); + logger.info("========================================\n"); + + logger.info("Available flows:"); + logger.info(" - listMcpTools: List all available MCP tools"); + logger.info(" - fileAssistant: AI-powered file operations assistant"); + logger.info(" - readFile: Read a file using MCP filesystem tool"); + logger.info(" - listResources: List resources from an MCP server"); + logger.info(" - toolExplorer: AI assistant with all MCP tools"); + logger.info(" - mcpStatus: Get status of all MCP servers"); + logger.info(" - writeReadDemo: Demo writing and reading a file\n"); + + logger.info("Server running on http://localhost:8080"); + logger.info("Reflection server running on http://localhost:3100"); + logger.info("\nExample requests:"); + logger.info(" curl -X POST http://localhost:8080/listMcpTools -H 'Content-Type: application/json' -d 'null'"); + logger.info( + " curl -X POST http://localhost:8080/fileAssistant -H 'Content-Type: application/json' -d '\"List files in the temp directory\"'"); + logger.info( + " curl -X POST http://localhost:8080/readFile -H 'Content-Type: application/json' -d '\"/tmp/test.txt\"'"); + logger.info( + " curl -X POST http://localhost:8080/listResources -H 'Content-Type: application/json' -d '\"filesystem\"'"); + logger.info( + " curl -X POST http://localhost:8080/toolExplorer -H 'Content-Type: application/json' -d '\"Generate a random number\"'"); + logger.info(" curl -X POST http://localhost:8080/mcpStatus -H 'Content-Type: application/json' -d 'null'"); + logger.info( + " curl -X POST http://localhost:8080/writeReadDemo -H 'Content-Type: application/json' -d '\"Hello from Genkit MCP!\"'"); + + // Add shutdown hook to cleanup MCP connections + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + logger.info("Shutting down..."); + mcpPlugin.disconnect(); + })); + } +} diff --git a/java/samples/mcp/src/main/java/com/google/genkit/samples/MCPServerSample.java b/java/samples/mcp/src/main/java/com/google/genkit/samples/MCPServerSample.java new file mode 100644 index 0000000000..707cb5f733 --- /dev/null +++ b/java/samples/mcp/src/main/java/com/google/genkit/samples/MCPServerSample.java @@ -0,0 +1,287 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.HashMap; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.plugins.mcp.MCPServer; +import com.google.genkit.plugins.mcp.MCPServerOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating an MCP Server built with Genkit. + * + *

+ * This example shows how to: + *

    + *
  • Create a Genkit application with custom tools
  • + *
  • Expose those tools as an MCP server
  • + *
  • Use STDIO transport for integration with Claude Desktop and other MCP + * clients
  • + *
+ * + *

+ * This server exposes several demonstration tools: + *

    + *
  • calculator: Performs basic math operations
  • + *
  • weather: Gets mock weather information
  • + *
  • datetime: Gets current date and time
  • + *
  • greet: Creates personalized greetings
  • + *
  • translate_mock: Mock translation tool
  • + *
+ * + *

+ * To use with Claude Desktop, add to your claude_desktop_config.json: + * + *

{@code
+ * {
+ *   "mcpServers": {
+ *     "genkit-tools": {
+ *       "command": "java",
+ *       "args": ["-jar", "/path/to/genkit-mcp-server-sample.jar"]
+ *     }
+ *   }
+ * }
+ * }
+ * + *

+ * Or run directly for testing: + * + *

+ * mvn exec:java -Dexec.mainClass="com.google.genkit.samples.MCPServerSample"
+ * 
+ */ +public class MCPServerSample { + + private static final Logger logger = LoggerFactory.getLogger(MCPServerSample.class); + + public static void main(String[] args) throws Exception { + // Note: For MCP server mode, we minimize console output since + // STDIO is used for communication with the MCP client. + // Logging goes to stderr which is separate from the protocol. + + logger.info("Initializing Genkit MCP Server..."); + + // ======================================================= + // Create Genkit instance + // ======================================================= + + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(false) // Disable dev mode for server + .build()).plugin(OpenAIPlugin.create()) // Optional: Include if tools need AI + .build(); + + // ======================================================= + // Define tools to expose via MCP + // ======================================================= + + // Tool 1: Calculator + genkit.defineTool("calculator", "Performs basic math operations (add, subtract, multiply, divide)", + Map.of("type", "object", "properties", + Map.of("operation", + Map.of("type", "string", "description", "The operation to perform", "enum", + new String[]{"add", "subtract", "multiply", "divide"}), + "a", Map.of("type", "number", "description", "First operand"), "b", + Map.of("type", "number", "description", "Second operand")), + "required", new String[]{"operation", "a", "b"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String operation = (String) input.get("operation"); + double a = ((Number) input.get("a")).doubleValue(); + double b = ((Number) input.get("b")).doubleValue(); + + double result; + switch (operation) { + case "add" : + result = a + b; + break; + case "subtract" : + result = a - b; + break; + case "multiply" : + result = a * b; + break; + case "divide" : + if (b == 0) { + throw new IllegalArgumentException("Cannot divide by zero"); + } + result = a / b; + break; + default : + throw new IllegalArgumentException("Unknown operation: " + operation); + } + + Map response = new HashMap<>(); + response.put("operation", operation); + response.put("a", a); + response.put("b", b); + response.put("result", result); + return response; + }); + + // Tool 2: Weather (mock) + genkit.defineTool("get_weather", "Gets the current weather for a location (mock data)", + Map.of("type", "object", "properties", + Map.of("location", Map.of("type", "string", "description", "The city name"), "unit", + Map.of("type", "string", "description", "Temperature unit (celsius or fahrenheit)", + "enum", new String[]{"celsius", "fahrenheit"})), + "required", new String[]{"location"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String location = (String) input.get("location"); + String unit = input.get("unit") != null ? (String) input.get("unit") : "celsius"; + + // Mock weather data + int tempC = (int) (Math.random() * 30) + 5; + int tempF = (int) (tempC * 9.0 / 5.0 + 32); + + String[] conditions = {"Sunny", "Cloudy", "Partly Cloudy", "Rainy", "Windy"}; + String condition = conditions[(int) (Math.random() * conditions.length)]; + + Map weather = new HashMap<>(); + weather.put("location", location); + weather.put("temperature", unit.equals("celsius") ? tempC : tempF); + weather.put("unit", unit); + weather.put("condition", condition); + weather.put("humidity", (int) (Math.random() * 60) + 30 + "%"); + weather.put("note", "This is mock weather data for demonstration purposes"); + return weather; + }); + + // Tool 3: Date/Time + genkit.defineTool("get_datetime", "Gets the current date and time in various formats", + Map.of("type", "object", "properties", Map.of("timezone", + Map.of("type", "string", "description", "Timezone (e.g., UTC, America/New_York)"), "format", + Map.of("type", "string", "description", "Output format (iso, readable, date_only, time_only)")), + "required", new String[]{}), + (Class>) (Class) Map.class, (ctx, input) -> { + String format = input.get("format") != null ? (String) input.get("format") : "readable"; + + LocalDateTime now = LocalDateTime.now(); + String formatted; + + switch (format) { + case "iso" : + formatted = now.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME); + break; + case "date_only" : + formatted = now.format(DateTimeFormatter.ISO_LOCAL_DATE); + break; + case "time_only" : + formatted = now.format(DateTimeFormatter.ofPattern("HH:mm:ss")); + break; + case "readable" : + default : + formatted = now.format(DateTimeFormatter.ofPattern("EEEE, MMMM d, yyyy 'at' h:mm a")); + break; + } + + Map result = new HashMap<>(); + result.put("datetime", formatted); + result.put("format", format); + result.put("timestamp", System.currentTimeMillis()); + return result; + }); + + // Tool 4: Greeting generator + genkit.defineTool("greet", "Creates a personalized greeting message", Map.of("type", "object", "properties", + Map.of("name", Map.of("type", "string", "description", "The name of the person to greet"), "style", + Map.of("type", "string", "description", "Greeting style", "enum", + new String[]{"formal", "casual", "enthusiastic"})), + "required", new String[]{"name"}), (Class>) (Class) Map.class, (ctx, input) -> { + String name = (String) input.get("name"); + String style = input.get("style") != null ? (String) input.get("style") : "casual"; + + String greeting; + switch (style) { + case "formal" : + greeting = "Dear " + name + ", it is a pleasure to make your acquaintance."; + break; + case "enthusiastic" : + greeting = "Hey " + name + "! 🎉 So awesome to meet you! Let's do something amazing!"; + break; + case "casual" : + default : + greeting = "Hi " + name + "! Nice to meet you."; + break; + } + + Map result = new HashMap<>(); + result.put("greeting", greeting); + result.put("name", name); + result.put("style", style); + return result; + }); + + // Tool 5: Mock translator + genkit.defineTool("translate_mock", "Mock translation tool - demonstrates how a translation tool might work", + Map.of("type", "object", "properties", + Map.of("text", Map.of("type", "string", "description", "The text to translate"), + "targetLanguage", + Map.of("type", "string", "description", "Target language code (es, fr, de, ja, etc.)")), + "required", new String[]{"text", "targetLanguage"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String text = (String) input.get("text"); + String targetLang = (String) input.get("targetLanguage"); + + // Mock translations - just add a prefix and note + Map langNames = Map.of("es", "Spanish", "fr", "French", "de", "German", "ja", + "Japanese", "zh", "Chinese", "ko", "Korean", "pt", "Portuguese", "it", "Italian"); + + String langName = langNames.getOrDefault(targetLang, targetLang); + String mockTranslation = "[" + langName + "] " + text + " (mock translation)"; + + Map result = new HashMap<>(); + result.put("originalText", text); + result.put("translatedText", mockTranslation); + result.put("targetLanguage", targetLang); + result.put("targetLanguageName", langName); + result.put("note", "This is a mock translation. In a real implementation, use a translation API."); + return result; + }); + + // ======================================================= + // Create and start MCP Server + // ======================================================= + // Note: genkit.init() is already called by the builder, so we don't need to + // call it again + + MCPServerOptions serverOptions = MCPServerOptions.builder().name("genkit-tools-server").version("1.0.0") + .build(); + + MCPServer mcpServer = new MCPServer(genkit.getRegistry(), serverOptions); + + logger.info("Starting MCP server with STDIO transport..."); + logger.info("Available tools: calculator, get_weather, get_datetime, greet, translate_mock"); + + // Add shutdown hook for cleanup + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + logger.info("Shutting down MCP server..."); + mcpServer.stop(); + })); + + // Start the server (blocks until client disconnects) + mcpServer.start(); + } +} diff --git a/java/samples/mcp/src/main/resources/logback.xml b/java/samples/mcp/src/main/resources/logback.xml new file mode 100644 index 0000000000..9ad4be443b --- /dev/null +++ b/java/samples/mcp/src/main/resources/logback.xml @@ -0,0 +1,55 @@ + + + + + + + + System.err + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/samples/middleware/README.md b/java/samples/middleware/README.md new file mode 100644 index 0000000000..7274c5cc09 --- /dev/null +++ b/java/samples/middleware/README.md @@ -0,0 +1,168 @@ +# Genkit Java Middleware Sample + +This sample demonstrates how to use middleware in Genkit Java to implement cross-cutting concerns like logging, metrics, caching, rate limiting, and error handling. + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/middleware + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/middleware + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Features Demonstrated + +### 1. Logging Middleware +Simple request/response logging using `CommonMiddleware.logging()`. + +### 2. Custom Metrics Middleware +Custom middleware that tracks request counts and response times. + +### 3. Request/Response Transformation +Using `CommonMiddleware.transformRequest()` and `CommonMiddleware.transformResponse()` to sanitize input and format output. + +### 4. Validation Middleware +Using `CommonMiddleware.validate()` to validate input before processing. + +### 5. Retry Middleware +Using `CommonMiddleware.retry()` for automatic retry with exponential backoff. + +### 6. Caching Middleware +Using `CommonMiddleware.cache()` with `SimpleCache` for caching expensive operations. + +### 7. Rate Limiting Middleware +Using `CommonMiddleware.rateLimit()` to limit request frequency. + +### 8. Conditional Middleware +Using `CommonMiddleware.conditional()` to apply middleware only when a condition is met. + +### 9. Before/After Hooks +Using `CommonMiddleware.beforeAfter()` for setup and cleanup operations. + +### 10. Error Handling Middleware +Using `CommonMiddleware.errorHandler()` to gracefully handle errors. + +## Available Endpoints + +| Endpoint | Description | +|----------|-------------| +| `/greeting` | Simple greeting with logging middleware | +| `/chat` | AI chat with multiple middleware | +| `/fact` | AI facts with caching | +| `/joke` | AI jokes with rate limiting | +| `/echo` | Echo with conditional logging | +| `/analyze` | Analysis with timing hooks | +| `/safe` | Demonstrates error handling | +| `/metrics` | View collected metrics | + +## Example Requests + +```bash +# Greeting flow +curl -X POST http://localhost:8080/greeting \ + -H 'Content-Type: application/json' \ + -d '"World"' + +# Chat flow +curl -X POST http://localhost:8080/chat \ + -H 'Content-Type: application/json' \ + -d '"What is the capital of France?"' + +# Fact flow (try twice to see caching) +curl -X POST http://localhost:8080/fact \ + -H 'Content-Type: application/json' \ + -d '"penguins"' + +# Joke flow +curl -X POST http://localhost:8080/joke \ + -H 'Content-Type: application/json' \ + -d '"programming"' + +# Echo flow (with debug logging) +curl -X POST http://localhost:8080/echo \ + -H 'Content-Type: application/json' \ + -d '"debug: test message"' + +# Safe flow (test error handling) +curl -X POST http://localhost:8080/safe \ + -H 'Content-Type: application/json' \ + -d '"error"' + +# View metrics +curl -X POST http://localhost:8080/metrics \ + -H 'Content-Type: application/json' \ + -d 'null' +``` + +## Creating Custom Middleware + +You can create custom middleware by implementing the `Middleware` interface: + +```java +import com.google.genkit.core.middleware.Middleware; + +// Custom middleware that adds a prefix to all requests +Middleware prefixMiddleware = (request, context, next) -> { + String modifiedRequest = "PREFIX: " + request; + return next.apply(modifiedRequest, context); +}; + +// Use it in a flow +List> middleware = List.of(prefixMiddleware); +Flow myFlow = genkit.defineFlow( + "myFlow", String.class, String.class, + (ctx, input) -> "Result: " + input, + middleware +); +``` + +## Architecture + +The middleware system follows the chain of responsibility pattern: + +1. Middleware are executed in order (first added, first executed) +2. Each middleware can: + - Modify the request before passing it to the next middleware + - Modify the response after receiving it from the next middleware + - Short-circuit the chain by not calling `next` + - Handle errors from downstream middleware + +``` +Request -> [MW1] -> [MW2] -> [MW3] -> Action + | +Response <- [MW1] <- [MW2] <- [MW3] <---- +``` + +## See Also + +- [Genkit Documentation](https://github.com/firebase/genkit) +- [JavaScript Middleware Documentation](../../../js/ai/src/model/middleware.ts) diff --git a/java/samples/middleware/pom.xml b/java/samples/middleware/pom.xml new file mode 100644 index 0000000000..0727d33e4b --- /dev/null +++ b/java/samples/middleware/pom.xml @@ -0,0 +1,89 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-middleware + jar + Genkit Middleware Sample + Sample application demonstrating Genkit middleware support + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.MiddlewareSample + + + + + diff --git a/java/samples/middleware/run.sh b/java/samples/middleware/run.sh new file mode 100755 index 0000000000..defed56975 --- /dev/null +++ b/java/samples/middleware/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# Run the Genkit Middleware Sample + +set -e + +# Navigate to the sample directory +cd "$(dirname "$0")" + +# Check for OPENAI_API_KEY +if [ -z "$OPENAI_API_KEY" ]; then + echo "Warning: OPENAI_API_KEY is not set. The sample may not work correctly." + echo "Set it with: export OPENAI_API_KEY=your-api-key" +fi + +# Build and run +echo "Building and running Genkit Middleware Sample..." +mvn compile exec:java -q diff --git a/java/samples/middleware/src/main/java/com/google/genkit/samples/MiddlewareSample.java b/java/samples/middleware/src/main/java/com/google/genkit/samples/MiddlewareSample.java new file mode 100644 index 0000000000..8dbb0d9a7f --- /dev/null +++ b/java/samples/middleware/src/main/java/com/google/genkit/samples/MiddlewareSample.java @@ -0,0 +1,298 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.*; +import com.google.genkit.core.Flow; +import com.google.genkit.core.GenkitException; +import com.google.genkit.core.middleware.*; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating Genkit middleware support. + * + *

+ * This example shows how to: + *

    + *
  • Create custom middleware for logging, timing, and metrics
  • + *
  • Use built-in middleware from CommonMiddleware
  • + *
  • Chain multiple middleware together
  • + *
  • Apply middleware to flows
  • + *
  • Create reusable middleware for cross-cutting concerns
  • + *
+ * + *

+ * To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java
  4. + *
+ */ +public class MiddlewareSample { + + private static final Logger logger = LoggerFactory.getLogger(MiddlewareSample.class); + + // Metrics storage for demonstration + private static final Map requestCounts = new ConcurrentHashMap<>(); + private static final Map> responseTimes = new ConcurrentHashMap<>(); + + public static void main(String[] args) throws Exception { + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // Create Genkit with plugins + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).plugin(jetty).build(); + + // ======================================================= + // Example 1: Simple logging middleware + // ======================================================= + + // Create a list of middleware for the flow + List> loggingMiddleware = List.of(CommonMiddleware.logging("greeting")); + + // Define flow with logging middleware + Flow greetingFlow = genkit.defineFlow("greeting", String.class, String.class, + (ctx, name) -> "Hello, " + name + "!", loggingMiddleware); + + // ======================================================= + // Example 2: Custom metrics middleware + // ======================================================= + + // Custom middleware that collects metrics + Middleware metricsMiddleware = (request, context, next) -> { + String flowName = "chat"; + requestCounts.computeIfAbsent(flowName, k -> new AtomicLong(0)).incrementAndGet(); + + long start = System.currentTimeMillis(); + try { + String result = next.apply(request, context); + long duration = System.currentTimeMillis() - start; + responseTimes.computeIfAbsent(flowName, k -> new ArrayList<>()).add(duration); + return result; + } catch (GenkitException e) { + logger.error("Flow {} failed: {}", flowName, e.getMessage()); + throw e; + } + }; + + // ======================================================= + // Example 3: Request/Response transformation middleware + // ======================================================= + + // Middleware that sanitizes input (removes extra whitespace, trims) + Middleware sanitizeMiddleware = CommonMiddleware.transformRequest(input -> { + if (input == null) { + return ""; + } + return input.trim().replaceAll("\\s+", " "); + }); + + // Middleware that formats output + Middleware formatMiddleware = CommonMiddleware.transformResponse(output -> { + return "[" + Instant.now() + "] " + output; + }); + + // ======================================================= + // Example 4: Validation middleware + // ======================================================= + + Middleware validationMiddleware = CommonMiddleware.validate(input -> { + if (input == null || input.trim().isEmpty()) { + throw new GenkitException("Input cannot be empty"); + } + if (input.length() > 1000) { + throw new GenkitException("Input too long (max 1000 characters)"); + } + }); + + // ======================================================= + // Example 5: Chat flow with multiple middleware + // ======================================================= + + // Combine multiple middleware + List> chatMiddleware = List.of(CommonMiddleware.logging("chat"), // Log + // requests/responses + metricsMiddleware, // Collect metrics + sanitizeMiddleware, // Sanitize input + validationMiddleware, // Validate input + CommonMiddleware.retry(2, 100) // Retry on failure + ); + + // Define chat flow with middleware chain + Flow chatFlow = genkit.defineFlow("chat", String.class, String.class, + (ctx, userMessage) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful assistant. Be concise.").prompt(userMessage) + .config(GenerationConfig.builder().temperature(0.7).maxOutputTokens(200).build()).build()); + return response.getText(); + }, chatMiddleware); + + // ======================================================= + // Example 6: Caching middleware for expensive operations + // ======================================================= + + // Create a cache with 5 minute TTL + SimpleCache factCache = new SimpleCache<>(5 * 60 * 1000); + + List> factMiddleware = List.of(CommonMiddleware.logging("fact"), + CommonMiddleware.cache(factCache, request -> request.toLowerCase()) // Cache by lowercase input + ); + + Flow factFlow = genkit.defineFlow("fact", String.class, String.class, (ctx, topic) -> { + logger.info("Generating fact for: {} (not cached)", topic); + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .prompt("Give me an interesting fact about: " + topic + ". Keep it to one sentence.").build()); + return response.getText(); + }, factMiddleware); + + // ======================================================= + // Example 7: Rate limiting middleware + // ======================================================= + + List> rateLimitedMiddleware = List.of(CommonMiddleware.logging("joke"), + CommonMiddleware.rateLimit(10, 60000) // Max 10 requests per minute + ); + + Flow jokeFlow = genkit.defineFlow("joke", String.class, String.class, (ctx, topic) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .prompt("Tell me a short, funny joke about: " + topic) + .config(GenerationConfig.builder().temperature(0.9).build()).build()); + return response.getText(); + }, rateLimitedMiddleware); + + // ======================================================= + // Example 8: Conditional middleware + // ======================================================= + + // Only log if the request contains "debug" + Middleware conditionalLogging = CommonMiddleware.conditional( + (request, ctx) -> request.toLowerCase().contains("debug"), CommonMiddleware.logging("debug-echo")); + + List> echoMiddleware = List.of(conditionalLogging); + + Flow echoFlow = genkit.defineFlow("echo", String.class, String.class, + (ctx, input) -> "Echo: " + input, echoMiddleware); + + // ======================================================= + // Example 9: Before/After hooks + // ======================================================= + + List> hookMiddleware = List.of(CommonMiddleware.beforeAfter( + (request, ctx) -> logger.info("🚀 Starting analysis of: {}", request), + (response, ctx) -> logger.info("✅ Analysis complete, response length: {} chars", response.length())), + CommonMiddleware.timing(duration -> { + logger.info("⏱️ Analysis took {}ms", duration); + })); + + Flow analyzeFlow = genkit.defineFlow("analyze", String.class, String.class, + (ctx, topic) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .prompt("Provide a brief analysis of the topic: " + topic) + .config(GenerationConfig.builder().maxOutputTokens(300).build()).build()); + return response.getText(); + }, hookMiddleware); + + // ======================================================= + // Example 10: Error handling middleware + // ======================================================= + + Middleware errorHandling = CommonMiddleware.errorHandler(e -> { + logger.error("Flow failed with error: {}", e.getMessage()); + return "Sorry, I encountered an error: " + e.getMessage(); + }); + + List> safeMiddleware = List.of(errorHandling, // This goes first to catch errors from + // other middleware + CommonMiddleware.logging("safe")); + + Flow safeFlow = genkit.defineFlow("safe", String.class, String.class, (ctx, input) -> { + if (input.equals("error")) { + throw new GenkitException("Intentional error for demonstration"); + } + return "Safe result: " + input; + }, safeMiddleware); + + // ======================================================= + // Example 11: Metrics endpoint flow + // ======================================================= + + Flow metricsFlow = genkit.defineFlow("metrics", Void.class, String.class, (ctx, input) -> { + StringBuilder sb = new StringBuilder(); + sb.append("=== Middleware Sample Metrics ===\n\n"); + + sb.append("Request Counts:\n"); + requestCounts + .forEach((flow, count) -> sb.append(" ").append(flow).append(": ").append(count).append("\n")); + + sb.append("\nAverage Response Times:\n"); + responseTimes.forEach((flow, times) -> { + if (!times.isEmpty()) { + double avg = times.stream().mapToLong(Long::longValue).average().orElse(0); + sb.append(" ").append(flow).append(": ").append(String.format("%.2f", avg)).append("ms\n"); + } + }); + + return sb.toString(); + }); + + // Initialize Genkit + genkit.init(); + + logger.info("\n========================================"); + logger.info("Genkit Middleware Sample Started!"); + logger.info("========================================\n"); + + logger.info("Available flows:"); + logger.info(" - greeting: Simple flow with logging middleware"); + logger.info(" - chat: AI chat with multiple middleware (logging, metrics, validation, retry)"); + logger.info(" - fact: AI facts with caching middleware"); + logger.info(" - joke: AI jokes with rate limiting middleware"); + logger.info(" - echo: Echo with conditional logging"); + logger.info(" - analyze: Analysis with before/after hooks and timing"); + logger.info(" - safe: Demonstrates error handling middleware"); + logger.info(" - metrics: View collected metrics\n"); + + logger.info("Server running on http://localhost:8080"); + logger.info("Reflection server running on http://localhost:3100"); + logger.info("\nExample requests:"); + logger.info(" curl -X POST http://localhost:8080/greeting -H 'Content-Type: application/json' -d '\"World\"'"); + logger.info( + " curl -X POST http://localhost:8080/chat -H 'Content-Type: application/json' -d '\"What is the capital of France?\"'"); + logger.info(" curl -X POST http://localhost:8080/fact -H 'Content-Type: application/json' -d '\"penguins\"'"); + logger.info( + " curl -X POST http://localhost:8080/joke -H 'Content-Type: application/json' -d '\"programming\"'"); + logger.info(" curl -X POST http://localhost:8080/safe -H 'Content-Type: application/json' -d '\"error\"'"); + logger.info(" curl -X POST http://localhost:8080/metrics -H 'Content-Type: application/json' -d 'null'"); + } +} diff --git a/java/samples/middleware/src/main/resources/logback.xml b/java/samples/middleware/src/main/resources/logback.xml new file mode 100644 index 0000000000..fe98c37a85 --- /dev/null +++ b/java/samples/middleware/src/main/resources/logback.xml @@ -0,0 +1,26 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + + diff --git a/java/samples/multi-agent/README.md b/java/samples/multi-agent/README.md new file mode 100644 index 0000000000..fc1af0b05b --- /dev/null +++ b/java/samples/multi-agent/README.md @@ -0,0 +1,204 @@ +# Genkit Multi-Agent Sample + +This sample demonstrates multi-agent orchestration patterns using Genkit Java, where specialized agents handle different domains and a triage agent routes requests. + +## Features Demonstrated + +- **Multi-Agent Architecture** - Triage agent routing to specialized agents +- **Specialized Agents** - Reservation, menu, and order agents +- **Agent-as-Tool Pattern** - Agents can be used as tools for delegation +- **Session Management** - Track customer state across interactions +- **Tool Integration** - Agents with domain-specific tools + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/multi-agent + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/multi-agent + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Customer Request │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Triage Agent │ +│ Routes requests to specialized agents based on intent │ +└─────────────────────────────────────────────────────────────┘ + │ │ │ + ▼ ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ Reservation │ │ Menu │ │ Order │ +│ Agent │ │ Agent │ │ Agent │ +│ │ │ │ │ │ +│ • makeRes │ │ • getMenu │ │ • placeOrder │ +│ • cancelRes │ │ • getDietInfo │ │ • getOrderStatus│ +└─────────────────┘ └─────────────────┘ └─────────────────┘ +``` + +## Agents + +### Triage Agent +The main entry point that analyzes customer requests and routes them to the appropriate specialized agent. + +### Reservation Agent +Handles table reservations: +- Make new reservations +- Cancel existing reservations +- Check availability + +### Menu Agent +Handles menu-related queries: +- Get menu items +- Dietary information +- Recommendations + +### Order Agent +Handles food orders: +- Place orders +- Check order status +- Modify orders + +## Available Tools + +| Tool | Agent | Description | +|------|-------|-------------| +| `makeReservation` | Reservation | Makes a new reservation | +| `cancelReservation` | Reservation | Cancels an existing reservation | +| `getMenu` | Menu | Returns menu items | +| `placeOrder` | Order | Places a food order | + +## Example Interactions + +The sample runs as an interactive CLI application: + +``` +🍽️ Welcome to the Restaurant! +Type 'quit' to exit. + +You: I'd like to make a reservation for 4 people tomorrow at 7pm + +Agent: I'd be happy to help you with your reservation. Let me set that up for you. + +[Reservation Agent handles the request] + +Reservation confirmed! Your confirmation number is RES-1234. +- Date: 2024-01-16 +- Time: 19:00 +- Party size: 4 + +You: What's on the menu? + +Agent: [Menu Agent handles the request] + +Here's our current menu: +- Appetizers: ... +- Main Courses: ... +- Desserts: ... +``` + +## Session State + +The sample tracks customer state across interactions: + +```java +public class CustomerState { + private String customerId; + private String currentAgent; + private List reservations; + private List orders; +} +``` + +## Code Highlights + +### Defining an Agent + +```java +Agent reservationAgent = genkit.defineAgent( + AgentConfig.builder() + .name("reservationAgent") + .model("openai/gpt-4o") + .system("You are a helpful reservation agent for a restaurant...") + .tools(List.of(makeReservationTool, cancelReservationTool)) + .build()); +``` + +### Agent-as-Tool Pattern + +```java +// Agents can be used as tools for delegation +Tool reservationAgentTool = reservationAgent.asTool(); + +Agent triageAgent = genkit.defineAgent( + AgentConfig.builder() + .name("triageAgent") + .model("openai/gpt-4o") + .system("Route requests to the appropriate agent...") + .tools(List.of(reservationAgentTool, menuAgentTool, orderAgentTool)) + .build()); +``` + +### Session-Based Chat + +```java +Session session = genkit.createSession( + SessionOptions.builder() + .sessionStore(sessionStore) + .initialState(new CustomerState()) + .build()); + +Chat chat = session.chat(ChatOptions.builder() + .model("openai/gpt-4o") + .agent(triageAgent) + .build()); + +String response = chat.send("I'd like to make a reservation"); +``` + +## Development UI + +When running with `genkit start`, access the Dev UI at http://localhost:4000 to: + +- View registered agents and tools +- Test individual agents +- Inspect traces showing agent routing +- View tool calls and responses + +## See Also + +- [Genkit Java README](../../README.md) +- [Chat Sessions Sample](../chat-session/README.md) +- [Interrupts Sample](../interrupts/README.md) diff --git a/java/samples/multi-agent/pom.xml b/java/samples/multi-agent/pom.xml new file mode 100644 index 0000000000..00a9fcddc6 --- /dev/null +++ b/java/samples/multi-agent/pom.xml @@ -0,0 +1,80 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-multi-agent + jar + Genkit Multi-Agent Sample + Sample application demonstrating multi-agent patterns with agent delegation + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.4.14 + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.0 + + com.google.genkit.samples.MultiAgentApp + + + + + diff --git a/java/samples/multi-agent/run.sh b/java/samples/multi-agent/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/multi-agent/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/multi-agent/src/main/java/com/google/genkit/samples/MultiAgentApp.java b/java/samples/multi-agent/src/main/java/com/google/genkit/samples/MultiAgentApp.java new file mode 100644 index 0000000000..99a62d56a6 --- /dev/null +++ b/java/samples/multi-agent/src/main/java/com/google/genkit/samples/MultiAgentApp.java @@ -0,0 +1,413 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Scanner; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.Agent; +import com.google.genkit.ai.AgentConfig; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.session.Chat; +import com.google.genkit.ai.session.ChatOptions; +import com.google.genkit.ai.session.InMemorySessionStore; +import com.google.genkit.ai.session.Session; +import com.google.genkit.ai.session.SessionOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Multi-Agent Customer Service Application. + * + *

+ * This sample demonstrates the multi-agent pattern where: + *

    + *
  • A triage agent routes requests to specialized agents
  • + *
  • Specialized agents handle specific domains (reservations, menu, + * etc.)
  • + *
  • Agents can be used as tools for delegation
  • + *
+ * + *

+ * To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java -pl samples/multi-agent
  4. + *
+ */ +public class MultiAgentApp { + + /** Customer state for tracking context. */ + public static class CustomerState { + private String customerId; + private String currentAgent; + private List reservations = new ArrayList<>(); + private List orders = new ArrayList<>(); + + public CustomerState() { + this.customerId = "customer-" + System.currentTimeMillis(); + this.currentAgent = "triage"; + } + + public String getCustomerId() { + return customerId; + } + + public String getCurrentAgent() { + return currentAgent; + } + + public void setCurrentAgent(String agent) { + this.currentAgent = agent; + } + + public List getReservations() { + return reservations; + } + + public void addReservation(String reservation) { + this.reservations.add(reservation); + } + + public List getOrders() { + return orders; + } + + public void addOrder(String order) { + this.orders.add(order); + } + + @Override + public String toString() { + return String.format("Customer: %s, Agent: %s, Reservations: %d, Orders: %d", customerId, currentAgent, + reservations.size(), orders.size()); + } + } + + private final Genkit genkit; + private final InMemorySessionStore sessionStore; + + // Agents + private Agent triageAgent; + private Agent reservationAgent; + private Agent menuAgent; + private Agent orderAgent; + + // Tools + private Tool makeReservationTool; + private Tool cancelReservationTool; + private Tool getMenuTool; + private Tool placeOrderTool; + + public MultiAgentApp() { + // Initialize Genkit + this.genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3101).build()) + .plugin(OpenAIPlugin.create()).build(); + + this.sessionStore = new InMemorySessionStore<>(); + + // Initialize tools and agents + initializeTools(); + initializeAgents(); + } + + @SuppressWarnings("unchecked") + private void initializeTools() { + // Reservation Tool + makeReservationTool = genkit.defineTool("makeReservation", "Makes a restaurant reservation for the customer", + Map.of("type", "object", "properties", + Map.of("date", Map.of("type", "string", "description", "Date in YYYY-MM-DD format"), "time", + Map.of("type", "string", "description", "Time in HH:MM format"), "partySize", + Map.of("type", "integer", "description", "Number of guests")), + "required", new String[]{"date", "time", "partySize"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String date = (String) input.get("date"); + String time = (String) input.get("time"); + Integer partySize = (Integer) input.get("partySize"); + String confirmationId = "RES-" + System.currentTimeMillis() % 10000; + + Map result = new HashMap<>(); + result.put("status", "confirmed"); + result.put("confirmationId", confirmationId); + result.put("date", date); + result.put("time", time); + result.put("partySize", partySize); + result.put("message", + String.format("Reservation confirmed for %d guests on %s at %s. Confirmation: %s", + partySize, date, time, confirmationId)); + return result; + }); + + // Cancel Reservation Tool + cancelReservationTool = genkit.defineTool("cancelReservation", "Cancels an existing reservation", + Map.of("type", "object", "properties", + Map.of("confirmationId", + Map.of("type", "string", "description", "The reservation confirmation ID")), + "required", new String[]{"confirmationId"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String confirmationId = (String) input.get("confirmationId"); + Map result = new HashMap<>(); + result.put("status", "cancelled"); + result.put("confirmationId", confirmationId); + result.put("message", "Reservation " + confirmationId + " has been cancelled."); + return result; + }); + + // Menu Tool + getMenuTool = genkit.defineTool("getMenu", "Gets the current restaurant menu", + Map.of("type", "object", "properties", + Map.of("category", + Map.of("type", "string", "description", + "Menu category: appetizers, mains, desserts, drinks, or all", "enum", + new String[]{"appetizers", "mains", "desserts", "drinks", "all"}))), + (Class>) (Class) Map.class, (ctx, input) -> { + String category = input.get("category") != null ? (String) input.get("category") : "all"; + Map menu = new HashMap<>(); + + if (category.equals("all") || category.equals("appetizers")) { + menu.put("appetizers", List.of( + Map.of("name", "Bruschetta", "price", 8.99, "description", + "Toasted bread with tomatoes"), + Map.of("name", "Calamari", "price", 12.99, "description", "Fried squid rings"))); + } + if (category.equals("all") || category.equals("mains")) { + menu.put("mains", List.of( + Map.of("name", "Grilled Salmon", "price", 24.99, "description", + "Atlantic salmon with herbs"), + Map.of("name", "Ribeye Steak", "price", 32.99, "description", "12oz prime ribeye"), + Map.of("name", "Pasta Primavera", "price", 18.99, "description", + "Seasonal vegetables"))); + } + if (category.equals("all") || category.equals("desserts")) { + menu.put("desserts", + List.of(Map.of("name", "Tiramisu", "price", 9.99, "description", + "Classic Italian dessert"), + Map.of("name", "Cheesecake", "price", 8.99, "description", "NY style"))); + } + if (category.equals("all") || category.equals("drinks")) { + menu.put("drinks", + List.of(Map.of("name", "House Wine", "price", 8.99, "description", "Red or white"), + Map.of("name", "Craft Beer", "price", 6.99, "description", "Local selection"))); + } + + return menu; + }); + + // Order Tool + placeOrderTool = genkit.defineTool("placeOrder", "Places a food order for pickup or delivery", + Map.of("type", "object", "properties", + Map.of("items", + Map.of("type", "array", "items", Map.of("type", "string"), "description", + "List of menu item names to order"), + "orderType", + Map.of("type", "string", "description", "pickup or delivery", "enum", + new String[]{"pickup", "delivery"})), + "required", new String[]{"items", "orderType"}), + (Class>) (Class) Map.class, (ctx, input) -> { + @SuppressWarnings("unchecked") + List items = (List) input.get("items"); + String orderType = (String) input.get("orderType"); + String orderId = "ORD-" + System.currentTimeMillis() % 10000; + + Map result = new HashMap<>(); + result.put("status", "confirmed"); + result.put("orderId", orderId); + result.put("items", items); + result.put("orderType", orderType); + result.put("estimatedTime", orderType.equals("pickup") ? "20 minutes" : "45 minutes"); + result.put("message", + String.format("Order %s placed for %s. Items: %s. Ready in %s.", orderId, orderType, + String.join(", ", items), + orderType.equals("pickup") ? "20 minutes" : "45 minutes")); + return result; + }); + } + + @SuppressWarnings("unchecked") + private void initializeAgents() { + // Reservation Agent - handles booking and cancellation + // Note: genkit.defineAgent automatically registers the agent + reservationAgent = genkit.defineAgent(AgentConfig.builder().name("reservationAgent") + .description("Handles restaurant reservations. Transfer to this agent when the customer " + + "wants to make, modify, or cancel a reservation.") + .system("You are a reservation specialist for an upscale restaurant. " + + "Help customers make, modify, or cancel reservations. " + + "Always confirm the date, time, and party size before making a reservation. " + + "Be professional and courteous.") + .model("openai/gpt-4o-mini").tools(List.of(makeReservationTool, cancelReservationTool)) + .config(GenerationConfig.builder().temperature(0.3).build()).build()); + + // Menu Agent - provides menu information + menuAgent = genkit.defineAgent(AgentConfig.builder().name("menuAgent") + .description("Provides menu information. Transfer to this agent when the customer " + + "wants to know about menu items, prices, or recommendations.") + .system("You are a menu expert at an upscale restaurant. " + + "Help customers explore the menu, understand dishes, and get recommendations. " + + "Use the getMenu tool to retrieve current menu items. " + + "Be knowledgeable about ingredients and preparation methods.") + .model("openai/gpt-4o-mini").tools(List.of(getMenuTool)) + .config(GenerationConfig.builder().temperature(0.5).build()).build()); + + // Order Agent - handles food orders + orderAgent = genkit.defineAgent(AgentConfig.builder().name("orderAgent") + .description("Handles food orders for pickup or delivery. Transfer to this agent when " + + "the customer wants to place an order.") + .system("You are an order specialist for a restaurant. " + + "Help customers place orders for pickup or delivery. " + + "Confirm all items before placing the order. " + "Provide accurate time estimates.") + .model("openai/gpt-4o-mini").tools(List.of(placeOrderTool, getMenuTool)) + .config(GenerationConfig.builder().temperature(0.3).build()).build()); + + // Triage Agent - routes to specialized agents + triageAgent = genkit.defineAgent(AgentConfig.builder().name("triageAgent") + .description("Main customer service agent that routes requests to specialists") + .system("You are the main customer service agent for The Golden Fork restaurant. " + + "Your job is to understand what the customer needs and transfer them to the right specialist.\n\n" + + "IMPORTANT: To transfer to another agent, you MUST call the appropriate agent tool. " + + "Do NOT just say you are transferring - you must actually invoke the tool:\n" + + "- reservationAgent: for reservations (booking, canceling, modifying)\n" + + "- menuAgent: for menu questions, recommendations, or dietary info\n" + + "- orderAgent: for placing orders (pickup or delivery)\n\n" + + "When a customer needs help with a specific task, call the corresponding agent tool immediately. " + + "You can handle general greetings and questions, but for specific tasks, always use the tools.") + .model("openai/gpt-4o-mini") + .agents(List.of(reservationAgent.getConfig(), menuAgent.getConfig(), orderAgent.getConfig())) + .config(GenerationConfig.builder().temperature(0.7).build()).build()); + } + + /** Creates a chat session with the triage agent. */ + @SuppressWarnings("unchecked") + public Chat createChat() { + Session session = genkit.createSession( + SessionOptions.builder().store(sessionStore).initialState(new CustomerState()).build()); + + // Get all tools including sub-agents as tools - Genkit handles the registry + List> allTools = genkit.getAllToolsForAgent(triageAgent); + + // Agent registry is automatically available from the session - no need to pass + // explicitly + return session.chat(ChatOptions.builder().model("openai/gpt-4o-mini") + .system(triageAgent.getSystem()).tools(allTools).build()); + } + + /** Interactive chat loop. */ + public void runInteractive() { + Scanner scanner = new Scanner(System.in); + + System.out.println("╔════════════════════════════════════════════════════════════════╗"); + System.out.println("║ The Golden Fork Restaurant - Multi-Agent Customer Service ║"); + System.out.println("╚════════════════════════════════════════════════════════════════╝"); + System.out.println(); + System.out.println("Available agents:"); + System.out.println(" • Triage Agent - Routes your requests"); + System.out.println(" • Reservation Agent - Handles bookings"); + System.out.println(" • Menu Agent - Menu information and recommendations"); + System.out.println(" • Order Agent - Pickup and delivery orders"); + System.out.println(); + System.out.println("Commands:"); + System.out.println(" /status - Show current state"); + System.out.println(" /quit - Exit"); + System.out.println(); + System.out.println("How can we help you today?\n"); + + Chat chat = createChat(); + + while (true) { + System.out.print("You: "); + String input = scanner.nextLine().trim(); + + if (input.isEmpty()) + continue; + + if (input.equals("/quit") || input.equals("/exit")) { + System.out.println("\nThank you for visiting The Golden Fork!"); + break; + } + + if (input.equals("/status")) { + String currentAgent = chat.getCurrentAgentName(); + System.out.println("\nState: " + chat.getSession().getState()); + System.out + .println("Current Agent: " + (currentAgent != null ? currentAgent : "triage (default)") + "\n"); + continue; + } + + try { + String response = chat.send(input).getText(); + System.out.println("\nAssistant: " + response + "\n"); + } catch (Exception e) { + System.out.println("\nError: " + e.getMessage() + "\n"); + } + } + + scanner.close(); + } + + /** Demo mode. */ + public void runDemo() { + System.out.println("╔════════════════════════════════════════════════════════════════╗"); + System.out.println("║ Multi-Agent Demo - Restaurant Customer Service ║"); + System.out.println("╚════════════════════════════════════════════════════════════════╝"); + System.out.println(); + + Chat chat = createChat(); + + // Demo conversation + String[] messages = {"Hi, I'd like to make a reservation for this weekend", "Saturday at 7pm for 4 people", + "Thanks! Also, what's on your dessert menu?", + "I'd like to place a pickup order for the Tiramisu and Cheesecake"}; + + for (String message : messages) { + System.out.println("Customer: " + message); + try { + String response = chat.send(message).getText(); + System.out.println("\nAssistant: " + truncate(response, 300) + "\n"); + Thread.sleep(1000); // Pause for readability + } catch (Exception e) { + System.out.println("Error: " + e.getMessage()); + } + } + + System.out.println("\n=== Demo Complete ==="); + System.out.println("Final state: " + chat.getSession().getState()); + } + + private String truncate(String text, int maxLength) { + if (text == null || text.length() <= maxLength) + return text; + return text.substring(0, maxLength) + "..."; + } + + public static void main(String[] args) { + MultiAgentApp app = new MultiAgentApp(); + + boolean demoMode = args.length > 0 && args[0].equals("--demo"); + + if (demoMode) { + app.runDemo(); + } else { + app.runInteractive(); + } + } +} diff --git a/java/samples/multi-agent/src/main/resources/logback.xml b/java/samples/multi-agent/src/main/resources/logback.xml new file mode 100644 index 0000000000..d63c14f8a8 --- /dev/null +++ b/java/samples/multi-agent/src/main/resources/logback.xml @@ -0,0 +1,13 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + diff --git a/java/samples/openai/README.md b/java/samples/openai/README.md new file mode 100644 index 0000000000..8fcec3175a --- /dev/null +++ b/java/samples/openai/README.md @@ -0,0 +1,179 @@ +# Genkit OpenAI Sample + +This sample demonstrates basic integration with OpenAI models using Genkit Java. + +## Features Demonstrated + +- **OpenAI Plugin Setup** - Configure Genkit with OpenAI models +- **Flow Definitions** - Create observable, traceable AI workflows +- **Tool Usage** - Define and use tools with automatic execution +- **Text Generation** - Generate text with GPT-4o and GPT-4o-mini +- **Streaming** - Real-time response streaming +- **Vision Models** - Process images with vision capabilities +- **Image Generation** - Generate images with DALL-E + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/openai + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/openai + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Available Flows + +| Flow | Input | Output | Description | +|------|-------|--------|-------------| +| `greeting` | String (name) | String | Simple greeting flow | +| `tellJoke` | String (topic) | String | Generate a joke about a topic | +| `chat` | String (message) | String | Chat with GPT-4o | +| `weatherAssistant` | String (query) | String | Weather assistant using tools | + +## Example API Calls + +Once the server is running on port 8080: + +### Simple Greeting +```bash +curl -X POST http://localhost:8080/greeting \ + -H 'Content-Type: application/json' \ + -d '"World"' +``` + +### Generate a Joke +```bash +curl -X POST http://localhost:8080/tellJoke \ + -H 'Content-Type: application/json' \ + -d '"programming"' +``` + +### Chat +```bash +curl -X POST http://localhost:8080/chat \ + -H 'Content-Type: application/json' \ + -d '"What is the capital of France?"' +``` + +### Weather Assistant (with Tool) +```bash +curl -X POST http://localhost:8080/weatherAssistant \ + -H 'Content-Type: application/json' \ + -d '"What is the weather in Paris?"' +``` + +## Available Models + +The OpenAI plugin provides access to: + +| Model | Description | +|-------|-------------| +| `openai/gpt-4o` | Most capable model, best for complex tasks | +| `openai/gpt-4o-mini` | Faster and more cost-effective | +| `openai/gpt-4-turbo` | Previous generation GPT-4 | +| `openai/gpt-3.5-turbo` | Fast and economical | +| `openai/dall-e-3` | Image generation | +| `openai/text-embedding-3-small` | Text embeddings | +| `openai/text-embedding-3-large` | High-dimension text embeddings | + +## Code Highlights + +### Setting Up Genkit with OpenAI + +```java +Genkit genkit = Genkit.builder() + .options(GenkitOptions.builder() + .devMode(true) + .reflectionPort(3100) + .build()) + .plugin(OpenAIPlugin.create()) + .plugin(new JettyPlugin(JettyPluginOptions.builder() + .port(8080) + .build())) + .build(); +``` + +### Defining a Flow + +```java +Flow jokeFlow = genkit.defineFlow( + "tellJoke", String.class, String.class, + (ctx, topic) -> { + ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Tell me a short, funny joke about: " + topic) + .config(GenerationConfig.builder() + .temperature(0.9) + .maxOutputTokens(200) + .build()) + .build()); + return response.getText(); + }); +``` + +### Defining and Using Tools + +```java +Tool, Map> weatherTool = genkit.defineTool( + "getWeather", + "Gets the current weather for a location", + Map.of("type", "object", "properties", + Map.of("location", Map.of("type", "string")), + "required", new String[]{"location"}), + (Class>) (Class) Map.class, + (ctx, input) -> { + String location = (String) input.get("location"); + return Map.of("location", location, "temperature", "72°F"); + }); + +// Use tool in generation +ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o") + .prompt("What's the weather in Paris?") + .tools(List.of(weatherTool)) + .build()); +``` + +## Development UI + +When running with `genkit start`, access the Dev UI at http://localhost:4000 to: + +- Browse all registered flows, tools, and models +- Run flows with test inputs +- View execution traces and logs +- Inspect tool calls and responses + +## See Also + +- [Genkit Java README](../../README.md) +- [OpenAI API Documentation](https://platform.openai.com/docs) diff --git a/java/samples/openai/pom.xml b/java/samples/openai/pom.xml new file mode 100644 index 0000000000..e71990216c --- /dev/null +++ b/java/samples/openai/pom.xml @@ -0,0 +1,89 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-openai + jar + Genkit OpenAI Sample + Sample application demonstrating Genkit with OpenAI + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.OpenAISample + + + + + diff --git a/java/samples/openai/run.sh b/java/samples/openai/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/openai/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/openai/src/main/java/com/google/genkit/samples/OpenAISample.java b/java/samples/openai/src/main/java/com/google/genkit/samples/OpenAISample.java new file mode 100644 index 0000000000..e599bc1afb --- /dev/null +++ b/java/samples/openai/src/main/java/com/google/genkit/samples/OpenAISample.java @@ -0,0 +1,270 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.*; +import com.google.genkit.core.Flow; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating Genkit with OpenAI. + * + * This example shows how to: - Configure Genkit with the OpenAI plugin - Define + * flows - Use tools - Generate text with OpenAI models - Expose flows via HTTP + * endpoints - Process images with vision models - Generate images with DALL-E + * + * To run: 1. Set the OPENAI_API_KEY environment variable 2. Run: mvn exec:java + */ +public class OpenAISample { + + public static void main(String[] args) throws Exception { + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // Create Genkit with plugins + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).plugin(jetty).build(); + + // Define a simple greeting flow + Flow greetingFlow = genkit.defineFlow("greeting", String.class, String.class, + (name) -> "Hello, " + name + "!"); + + // Define a joke generator flow using OpenAI + Flow jokeFlow = genkit.defineFlow("tellJoke", String.class, String.class, + (ctx, topic) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .prompt("Tell me a short, funny joke about: " + topic) + .config(GenerationConfig.builder().temperature(0.9).maxOutputTokens(200).build()).build()); + + return response.getText(); + }); + + // Define a tool for getting current weather (mock implementation) + @SuppressWarnings("unchecked") + Tool, Map> weatherTool = genkit.defineTool("getWeather", + "Gets the current weather for a location", + Map.of("type", "object", "properties", + Map.of("location", Map.of("type", "string", "description", "The city name")), "required", + new String[]{"location"}), + (Class>) (Class) Map.class, (ctx, input) -> { + String location = (String) input.get("location"); + Map weather = new HashMap<>(); + weather.put("location", location); + weather.put("temperature", "72°F"); + weather.put("conditions", "Sunny"); + return weather; + }); + + // Define a chat flow + Flow chatFlow = genkit.defineFlow("chat", String.class, String.class, + (ctx, userMessage) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o") + .system("You are a helpful assistant.").prompt(userMessage).build()); + + return response.getText(); + }); + + // Define a flow that uses the weather tool + Flow weatherAssistantFlow = genkit.defineFlow("weatherAssistant", String.class, + String.class, (ctx, userMessage) -> { + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o").system( + "You are a helpful weather assistant. Use the getWeather tool to provide weather information when asked about the weather in a specific location.") + .prompt(userMessage).tools(List.of(weatherTool)).build()); + + return response.getText(); + }); + + // Define a streaming chat flow + Flow streamingChatFlow = genkit.defineFlow("streamingChat", String.class, String.class, + (ctx, userMessage) -> { + StringBuilder result = new StringBuilder(); + + ModelResponse response = genkit.generateStream(GenerateOptions.builder().model("openai/gpt-4o") + .system("You are a helpful assistant that provides detailed, comprehensive responses.") + .prompt(userMessage).config(GenerationConfig.builder().maxOutputTokens(1000).build()) + .build(), (chunk) -> { + // Process each chunk as it arrives + String text = chunk.getText(); + if (text != null) { + result.append(text); + System.out.print(text); // Print chunks in real-time + } + }); + + System.out.println(); // New line after streaming completes + return response.getText(); + }); + + // Define a streaming flow that uses tools - combines both features! + Flow streamingWeatherFlow = genkit.defineFlow("streamingWeather", String.class, + String.class, (ctx, userMessage) -> { + StringBuilder result = new StringBuilder(); + + System.out.println("\n--- Streaming Weather Assistant ---"); + System.out.println("Query: " + userMessage); + System.out.println("Response: "); + + ModelResponse response = genkit.generateStream( + GenerateOptions.builder().model("openai/gpt-4o") + .system("You are a helpful weather assistant. When asked about weather, " + + "use the getWeather tool to get current conditions, then provide " + + "a friendly, detailed response about the weather.") + .prompt(userMessage).tools(List.of(weatherTool)) + .config(GenerationConfig.builder().maxOutputTokens(500).build()).build(), + (chunk) -> { + // Stream chunks as they arrive + String text = chunk.getText(); + if (text != null) { + result.append(text); + System.out.print(text); + } + }); + + System.out.println("\n--- End of Response ---\n"); + return response.getText(); + }); + + // ==================== + // IMAGE EXAMPLES + // ==================== + + // Define a flow that analyzes an image using GPT-4 Vision + // This flow accepts an image URL and returns a description + Flow describeImageFlow = genkit.defineFlow("describeImage", String.class, String.class, + (ctx, imageUrl) -> { + System.out.println("\n--- Image Description Flow ---"); + System.out.println("Analyzing image: " + imageUrl); + + // Create a message with both text and image + Message userMessage = new Message(); + userMessage.setRole(Role.USER); + userMessage.setContent(List.of(Part.text( + "Describe this image in detail. What do you see? Include colors, objects, people, and any text visible."), + Part.media("image/jpeg", imageUrl) // Can also be image/png, image/gif, image/webp + )); + + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o") // GPT-4o + // supports + // vision + .messages(List.of(userMessage)) + .config(GenerationConfig.builder().maxOutputTokens(500).temperature(0.7).build()).build()); + + System.out.println("Description: " + response.getText()); + System.out.println("--- End of Image Description ---\n"); + + return response.getText(); + }); + + // Define a flow that generates an image using DALL-E 3 + // This flow accepts a prompt and returns the generated image URL (base64 data + // URI) + Flow generateImageFlow = genkit.defineFlow("generateImage", String.class, String.class, + (ctx, prompt) -> { + System.out.println("\n--- Image Generation Flow ---"); + System.out.println("Generating image for prompt: " + prompt); + + // Create image-specific config options using the custom field + Map imageOptions = new HashMap<>(); + imageOptions.put("size", "1024x1024"); // Image size + imageOptions.put("quality", "standard"); // "standard" or "hd" + imageOptions.put("style", "vivid"); // "vivid" or "natural" + imageOptions.put("n", 1); // Number of images + + ModelResponse response = genkit.generate(GenerateOptions.builder().model("openai/dall-e-3") // DALL-E + // 3 for + // image + // generation + .prompt(prompt).config(GenerationConfig.builder().custom(imageOptions).build()).build()); + + // Get the generated image media + Message message = response.getCandidates().get(0).getMessage(); + List parts = message.getContent(); + + // The response contains media parts with the generated images + for (Part part : parts) { + if (part.getMedia() != null) { + String imageUrl = part.getMedia().getUrl(); + String contentType = part.getMedia().getContentType(); + System.out.println("Generated image (" + contentType + ")"); + + // The URL will be a data URI like: data:image/png;base64, + // You can save this or display it in your application + if (imageUrl.startsWith("data:")) { + System.out.println( + "Image returned as base64 data URI (length: " + imageUrl.length() + " chars)"); + } else { + System.out.println("Image URL: " + imageUrl); + } + + return imageUrl; + } + } + + System.out.println("--- End of Image Generation ---\n"); + return "No image generated"; + }); + + System.out.println("Genkit Sample Application Started!"); + System.out.println("====================================="); + System.out.println("Dev UI: http://localhost:3100"); + System.out.println("API Endpoints:"); + System.out.println(" POST http://localhost:8080/api/flows/greeting"); + System.out.println(" POST http://localhost:8080/api/flows/tellJoke"); + System.out.println(" POST http://localhost:8080/api/flows/chat"); + System.out.println(" POST http://localhost:8080/api/flows/weatherAssistant (uses tools)"); + System.out.println(" POST http://localhost:8080/api/flows/streamingChat (uses streaming)"); + System.out.println(" POST http://localhost:8080/api/flows/streamingWeather (uses streaming + tools)"); + System.out.println(" POST http://localhost:8080/api/flows/describeImage (vision - analyze images)"); + System.out.println(" POST http://localhost:8080/api/flows/generateImage (DALL-E - generate images)"); + System.out.println(""); + System.out.println("Example usage:"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/greeting -d '\"World\"' -H 'Content-Type: application/json'"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/tellJoke -d '\"programming\"' -H 'Content-Type: application/json'"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/weatherAssistant -d '\"What is the weather in San Francisco?\"' -H 'Content-Type: application/json'"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/streamingChat -d '\"Explain quantum computing\"' -H 'Content-Type: application/json'"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/streamingWeather -d '\"How is the weather in Tokyo today?\"' -H 'Content-Type: application/json'"); + System.out.println(""); + System.out.println("Image Examples:"); + System.out.println(" # Analyze an image with GPT-4 Vision (use a direct image URL, not a webpage):"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/describeImage -d '\"https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg\"' -H 'Content-Type: application/json'"); + System.out.println(""); + System.out.println(" # Generate an image with DALL-E 3:"); + System.out.println( + " curl -X POST http://localhost:8080/api/flows/generateImage -d '\"A serene Japanese garden with a koi pond at sunset, digital art\"' -H 'Content-Type: application/json'"); + System.out.println(""); + System.out.println("Press Ctrl+C to stop..."); + + // Start the server and block - keeps the application running + jetty.start(); + } +} diff --git a/java/samples/openai/src/main/java/com/google/genkit/samples/SessionSample.java b/java/samples/openai/src/main/java/com/google/genkit/samples/SessionSample.java new file mode 100644 index 0000000000..f2136dc0e6 --- /dev/null +++ b/java/samples/openai/src/main/java/com/google/genkit/samples/SessionSample.java @@ -0,0 +1,288 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.*; +import com.google.genkit.ai.session.*; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating session-based multi-turn conversations. + * + * This example shows how to: - Create sessions with persistent state - Conduct + * multi-turn conversations with automatic history management - Use multiple + * conversation threads within a session - Implement custom session stores for + * persistence - Use tools within session-based chats + * + * To run: 1. Set the OPENAI_API_KEY environment variable 2. Run: mvn exec:java + * -Dexec.mainClass=com.google.genkit.samples.SessionSample + */ +public class SessionSample { + + /** + * Custom session state to track user preferences and conversation context. + */ + public static class UserState { + private String userName; + private String preferredLanguage; + private int messageCount; + + public UserState() { + } + + public UserState(String userName) { + this.userName = userName; + this.preferredLanguage = "English"; + this.messageCount = 0; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } + + public String getPreferredLanguage() { + return preferredLanguage; + } + + public void setPreferredLanguage(String preferredLanguage) { + this.preferredLanguage = preferredLanguage; + } + + public int getMessageCount() { + return messageCount; + } + + public void incrementMessageCount() { + this.messageCount++; + } + } + + public static void main(String[] args) throws Exception { + // Create Genkit with OpenAI plugin + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).build(); + + // Define a tool for the conversation + @SuppressWarnings("unchecked") + Tool, Map> reminderTool = genkit.defineTool("setReminder", + "Sets a reminder for the user", + Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string", "description", "The reminder message"), "time", + Map.of("type", "string", "description", + "When to remind (e.g., '5 minutes', 'tomorrow')")), + "required", new String[]{"message", "time"}), + (Class>) (Class) Map.class, (ctx, input) -> { + Map result = new HashMap<>(); + result.put("status", "success"); + result.put("message", "Reminder set: " + input.get("message") + " at " + input.get("time")); + return result; + }); + + System.out.println("=== Session-Based Chat Demo ===\n"); + + // Example 1: Basic session with multi-turn conversation + basicSessionExample(genkit); + + // Example 2: Session with custom state + sessionWithStateExample(genkit); + + // Example 3: Multiple conversation threads + multiThreadExample(genkit); + + // Example 4: Session with tools + sessionWithToolsExample(genkit, reminderTool); + + // Example 5: Loading existing sessions + sessionPersistenceExample(genkit); + + System.out.println("\n=== Demo Complete ==="); + } + + /** + * Demonstrates basic session creation and multi-turn conversation. + */ + private static void basicSessionExample(Genkit genkit) throws Exception { + System.out.println("--- Example 1: Basic Multi-Turn Conversation ---\n"); + + // Create a session + Session session = genkit.createSession(); + System.out.println("Created session: " + session.getId()); + + // Create a chat with system prompt + Chat chat = session.chat(ChatOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful assistant. Keep your responses brief and friendly.").build()); + + // Multi-turn conversation - history is automatically managed + System.out.println("\nUser: What is the capital of France?"); + ModelResponse response1 = chat.send("What is the capital of France?"); + System.out.println("Assistant: " + response1.getText()); + + System.out.println("\nUser: What's the population?"); + ModelResponse response2 = chat.send("What's the population?"); + System.out.println("Assistant: " + response2.getText()); + + System.out.println("\nUser: What language do they speak there?"); + ModelResponse response3 = chat.send("What language do they speak there?"); + System.out.println("Assistant: " + response3.getText()); + + // Show conversation history + System.out.println("\n--- Conversation History ---"); + for (Message msg : chat.getHistory()) { + System.out.println( + msg.getRole() + ": " + msg.getText().substring(0, Math.min(50, msg.getText().length())) + "..."); + } + System.out.println(); + } + + /** + * Demonstrates session with custom state management. + */ + private static void sessionWithStateExample(Genkit genkit) throws Exception { + System.out.println("--- Example 2: Session with Custom State ---\n"); + + // Create session with initial state + Session session = genkit + .createSession(SessionOptions.builder().initialState(new UserState("Alice")).build()); + + System.out.println("Created session for user: " + session.getState().getUserName()); + + // Create chat + Chat chat = session.chat(ChatOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful assistant. The user's name is " + session.getState().getUserName() + ".") + .build()); + + // Send message and update state + ModelResponse response = chat.send("Hello! Can you remember my name?"); + System.out.println("Assistant: " + response.getText()); + + // Update session state + UserState state = session.getState(); + state.incrementMessageCount(); + session.updateState(state).join(); + + System.out.println("Message count: " + session.getState().getMessageCount()); + System.out.println(); + } + + /** + * Demonstrates multiple conversation threads within a session. + */ + private static void multiThreadExample(Genkit genkit) throws Exception { + System.out.println("--- Example 3: Multiple Conversation Threads ---\n"); + + Session session = genkit.createSession(); + + // Create chat for general conversation + Chat generalChat = session.chat("general", ChatOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful general assistant.").build()); + + // Create chat for coding help + Chat codingChat = session.chat("coding", ChatOptions.builder().model("openai/gpt-4o-mini") + .system("You are an expert programmer. Provide concise code examples.").build()); + + // Use different threads for different topics + System.out.println("General thread:"); + ModelResponse generalResponse = generalChat.send("What's a good recipe for pasta?"); + System.out.println("Response: " + + generalResponse.getText().substring(0, Math.min(100, generalResponse.getText().length())) + "...\n"); + + System.out.println("Coding thread:"); + ModelResponse codingResponse = codingChat.send("How do I reverse a string in Java?"); + System.out.println("Response: " + + codingResponse.getText().substring(0, Math.min(100, codingResponse.getText().length())) + "...\n"); + + // Continue in general thread - context is preserved per thread + System.out.println("Back to general thread:"); + ModelResponse followUp = generalChat.send("What ingredients do I need for that?"); + System.out.println( + "Response: " + followUp.getText().substring(0, Math.min(100, followUp.getText().length())) + "...\n"); + } + + /** + * Demonstrates using tools within session-based chats. + */ + private static void sessionWithToolsExample(Genkit genkit, Tool reminderTool) throws Exception { + System.out.println("--- Example 4: Session with Tools ---\n"); + + Session session = genkit.createSession(); + + @SuppressWarnings("unchecked") + Chat chat = session.chat(ChatOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful assistant that can set reminders for users.") + .tools(List.of((Tool) reminderTool)).build()); + + System.out.println("User: Remind me to buy groceries in 1 hour"); + ModelResponse response = chat.send("Remind me to buy groceries in 1 hour"); + System.out.println("Assistant: " + response.getText()); + System.out.println(); + } + + /** + * Demonstrates session persistence - saving and loading sessions. + */ + private static void sessionPersistenceExample(Genkit genkit) throws Exception { + System.out.println("--- Example 5: Session Persistence ---\n"); + + // Create a custom session store (using in-memory for this example) + InMemorySessionStore store = new InMemorySessionStore<>(); + + // Create session with the store + Session session = genkit.createSession(SessionOptions.builder().store(store) + .sessionId("persistent-session-001").initialState(new UserState("Bob")).build()); + + // Have a conversation + Chat chat = session.chat(ChatOptions.builder().model("openai/gpt-4o-mini") + .system("You are a helpful assistant.").build()); + + chat.send("Hello, I'm learning about AI"); + chat.send("What's machine learning?"); + + System.out.println("Original session ID: " + session.getId()); + System.out.println("Messages in session: " + chat.getHistory().size()); + + // Load the session later (simulating app restart) + Session loadedSession = genkit + .loadSession("persistent-session-001", SessionOptions.builder().store(store).build()).get(); + + if (loadedSession != null) { + System.out.println("\nLoaded session ID: " + loadedSession.getId()); + System.out.println("User name from state: " + loadedSession.getState().getUserName()); + System.out.println("Messages preserved: " + loadedSession.getMessages().size()); + + // Continue the conversation + Chat continuedChat = loadedSession.chat(ChatOptions.builder() + .model("openai/gpt-4o-mini").system("You are a helpful assistant.").build()); + + System.out.println("\nContinuing conversation..."); + ModelResponse response = continuedChat.send("Can you summarize what we discussed?"); + System.out.println("Assistant: " + response.getText()); + } + System.out.println(); + } +} diff --git a/java/samples/openai/src/main/resources/logback.xml b/java/samples/openai/src/main/resources/logback.xml new file mode 100644 index 0000000000..56110e53c0 --- /dev/null +++ b/java/samples/openai/src/main/resources/logback.xml @@ -0,0 +1,25 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + diff --git a/java/samples/rag/README.md b/java/samples/rag/README.md new file mode 100644 index 0000000000..c10d03d669 --- /dev/null +++ b/java/samples/rag/README.md @@ -0,0 +1,198 @@ +# Genkit RAG Sample + +This sample demonstrates how to build RAG (Retrieval Augmented Generation) applications with Genkit Java using a local vector store for development. + +## Features Demonstrated + +- **Local Vector Store Plugin**: File-based vector storage for development and testing +- **Document Indexing**: Index documents from various sources +- **Semantic Retrieval**: Find relevant documents using embeddings +- **RAG Flows**: Combine retrieval with LLM generation +- **Multiple Knowledge Bases**: Separate vector stores for different domains + +## Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ Index Flow │────▶│ Local Vec Store │◀────│ Retrieve Flow │ +│ (documents) │ │ (embeddings) │ │ (query) │ +└─────────────────┘ └──────────────────┘ └────────┬────────┘ + │ + ▼ + ┌──────────────────┐ ┌─────────────────┐ + │ OpenAI LLM │◀────│ RAG Flow │ + │ (generation) │ │ (answer) │ + └──────────────────┘ └─────────────────┘ +``` + +## Knowledge Bases + +The sample includes three pre-configured knowledge bases: + +1. **world-capitals**: Information about capital cities around the world +2. **dog-breeds**: Facts about popular dog breeds +3. **coffee-facts**: Information about coffee and brewing methods + +## Prerequisites + +- Java 17+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/rag + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/rag + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Usage + +### Step 1: Index the Data + +Before querying, you need to index the documents: + +```bash +# Index world capitals +curl -X POST http://localhost:8080/indexWorldCapitals + +# Index dog breeds +curl -X POST http://localhost:8080/indexDogBreeds + +# Index coffee facts +curl -X POST http://localhost:8080/indexCoffeeFacts +``` + +### Step 2: Query the Knowledge Bases + +```bash +# Ask about world capitals +curl -X POST http://localhost:8080/askAboutCapitals \ + -H 'Content-Type: application/json' \ + -d '"What is the capital of France and what is it known for?"' + +# Ask about dogs +curl -X POST http://localhost:8080/askAboutDogs \ + -H 'Content-Type: application/json' \ + -d '"What are good dog breeds for families with children?"' + +# Ask about coffee +curl -X POST http://localhost:8080/askAboutCoffee \ + -H 'Content-Type: application/json' \ + -d '"How do you make espresso and what is a cappuccino?"' +``` + +### Step 3: Retrieve Documents Without Generation + +```bash +# Just retrieve relevant documents +curl -X POST http://localhost:8080/retrieveDocuments \ + -H 'Content-Type: application/json' \ + -d '{ + "query": "France capital", + "store": "world-capitals", + "k": 2 + }' +``` + +### Step 4: Index Custom Documents + +```bash +curl -X POST http://localhost:8080/indexDocuments \ + -H 'Content-Type: application/json' \ + -d '[ + "The first fact about my topic.", + "The second fact about my topic.", + "The third fact about my topic." + ]' +``` + +## How It Works + +### Indexing + +1. Documents are loaded from text files (one paragraph = one document) +2. Each document is converted to an embedding using OpenAI's embedding model +3. Documents and embeddings are stored in a JSON file on disk + +### Retrieval + +1. The query is converted to an embedding +2. Cosine similarity is computed between the query and all stored documents +3. The top-k most similar documents are returned + +### Generation + +1. Retrieved documents are formatted as context +2. The context and question are combined into a prompt +3. The LLM generates an answer based on the context + +## Local Vector Store + +The local vector store is designed for development and testing only. For production, use a proper vector database like: + +- Pinecone +- Chroma +- Weaviate +- pgvector (PostgreSQL) +- Vertex AI Vector Search + +### Storage Location + +Documents are stored in JSON files at: +``` +{java.io.tmpdir}/genkit-rag-sample/__db_{index-name}.json +``` + +## Adding Your Own Data + +1. Create a text file with your content (paragraphs separated by blank lines) +2. Place it in `src/main/resources/data/` +3. Create a new `LocalVecConfig` for your data +4. Define indexing and query flows + +## Development UI + +Access the Genkit Development UI at http://localhost:3100 to: +- Browse available flows, indexers, and retrievers +- Test flows interactively +- View execution traces +- Inspect indexed documents + +## Troubleshooting + +### Empty Results +- Make sure you've indexed the documents first +- Check that the embedding model is working correctly + +### Slow Indexing +- The first indexing takes longer due to embedding computation +- Subsequent runs use cached embeddings + +### Out of Memory +- For large datasets, consider batch indexing +- Use a proper vector database for production diff --git a/java/samples/rag/pom.xml b/java/samples/rag/pom.xml new file mode 100644 index 0000000000..2f375ae735 --- /dev/null +++ b/java/samples/rag/pom.xml @@ -0,0 +1,99 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-rag + jar + Genkit RAG Sample + Sample application demonstrating Genkit RAG with local vector store + + + UTF-8 + 17 + 17 + 1.0.0-SNAPSHOT + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-localvec + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.3 + + + + + + + src/main/resources + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + org.codehaus.mojo + exec-maven-plugin + 3.2.0 + + com.google.genkit.samples.RagSample + + + + + diff --git a/java/samples/rag/run.sh b/java/samples/rag/run.sh new file mode 100755 index 0000000000..7a055a49ca --- /dev/null +++ b/java/samples/rag/run.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Run script for Genkit DotPrompt Sample +cd "$(dirname "$0")" +mvn exec:java diff --git a/java/samples/rag/src/main/java/com/google/genkit/samples/RagSample.java b/java/samples/rag/src/main/java/com/google/genkit/samples/RagSample.java new file mode 100644 index 0000000000..d6ca60769f --- /dev/null +++ b/java/samples/rag/src/main/java/com/google/genkit/samples/RagSample.java @@ -0,0 +1,280 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.*; +import com.google.genkit.core.Flow; +import com.google.genkit.plugins.jetty.JettyPlugin; +import com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.localvec.LocalVecConfig; +import com.google.genkit.plugins.localvec.LocalVecPlugin; +import com.google.genkit.plugins.openai.OpenAIPlugin; + +/** + * Sample application demonstrating RAG (Retrieval Augmented Generation) with + * Genkit Java. + * + *

+ * This example shows how to: + *

    + *
  • Use the local vector store plugin for development
  • + *
  • Index documents from text files
  • + *
  • Create retriever flows to fetch relevant documents
  • + *
  • Build RAG flows that combine retrieval with generation
  • + *
+ * + *

+ * To run: + *

    + *
  1. Set the OPENAI_API_KEY environment variable
  2. + *
  3. Run: mvn exec:java
  4. + *
+ */ +public class RagSample { + + private static final Logger logger = LoggerFactory.getLogger(RagSample.class); + + /** + * System prompt for RAG queries. Documents are automatically injected via the + * .docs() option. + */ + private static final String RAG_SYSTEM_PROMPT = """ + You are a helpful assistant that answers questions based on the provided context documents. + + Please provide a helpful answer based only on the context provided. If the context doesn't contain + enough information to answer the question, say so. + """; + + public static void main(String[] args) throws Exception { + // Configure local vector stores with embedder name (will be resolved during + // init) + Path storageDir = Paths.get(System.getProperty("java.io.tmpdir"), "genkit-rag-sample"); + + LocalVecConfig worldCapitalsConfig = LocalVecConfig.builder().indexName("world-capitals") + .embedderName("openai/text-embedding-3-small").directory(storageDir).build(); + + LocalVecConfig dogBreedsConfig = LocalVecConfig.builder().indexName("dog-breeds") + .embedderName("openai/text-embedding-3-small").directory(storageDir).build(); + + LocalVecConfig coffeeFactsConfig = LocalVecConfig.builder().indexName("coffee-facts") + .embedderName("openai/text-embedding-3-small").directory(storageDir).build(); + + // Create the Jetty server plugin + JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build()); + + // Create Genkit with all plugins - LocalVec embedders are resolved + // automatically + Genkit genkit = Genkit.builder().options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build()) + .plugin(OpenAIPlugin.create()).plugin(LocalVecPlugin.builder().addStore(worldCapitalsConfig) + .addStore(dogBreedsConfig).addStore(coffeeFactsConfig).build()) + .plugin(jetty).build(); + + // Define flow to index world capitals data + Flow indexWorldCapitalsFlow = genkit.defineFlow("indexWorldCapitals", Void.class, + String.class, (ctx, input) -> { + List documents = loadDocumentsFromResource("/data/world-capitals.txt"); + genkit.index("devLocalVectorStore/world-capitals", documents); + return "Indexed " + documents.size() + " world capitals documents"; + }); + + // Define flow to index dog breeds data + Flow indexDogBreedsFlow = genkit.defineFlow("indexDogBreeds", Void.class, String.class, + (ctx, input) -> { + List documents = loadDocumentsFromResource("/data/dog-breeds.txt"); + genkit.index("devLocalVectorStore/dog-breeds", documents); + return "Indexed " + documents.size() + " dog breeds documents"; + }); + + // Define flow to index coffee facts data + Flow indexCoffeeFactsFlow = genkit.defineFlow("indexCoffeeFacts", Void.class, String.class, + (ctx, input) -> { + List documents = loadDocumentsFromResource("/data/coffee-facts.txt"); + genkit.index("devLocalVectorStore/coffee-facts", documents); + return "Indexed " + documents.size() + " coffee facts documents"; + }); + + // Define RAG flow for world capitals + Flow askAboutCapitalsFlow = genkit.defineFlow("askAboutCapitals", String.class, + String.class, (ctx, question) -> { + // Retrieve relevant documents + List docs = genkit.retrieve("devLocalVectorStore/world-capitals", question); + + // Generate answer with retrieved documents as context + ModelResponse modelResponse = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .system(RAG_SYSTEM_PROMPT).prompt(question).docs(docs) + .config(GenerationConfig.builder().temperature(0.3).build()).build()); + + return modelResponse.getText(); + }); + + // Define RAG flow for dog breeds + Flow askAboutDogsFlow = genkit.defineFlow("askAboutDogs", String.class, String.class, + (ctx, question) -> { + List docs = genkit.retrieve("devLocalVectorStore/dog-breeds", question); + + ModelResponse modelResponse = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .system(RAG_SYSTEM_PROMPT).prompt(question).docs(docs) + .config(GenerationConfig.builder().temperature(0.3).build()).build()); + + return modelResponse.getText(); + }); + + // Define RAG flow for coffee facts + Flow askAboutCoffeeFlow = genkit.defineFlow("askAboutCoffee", String.class, String.class, + (ctx, question) -> { + List docs = genkit.retrieve("devLocalVectorStore/coffee-facts", question); + + ModelResponse modelResponse = genkit.generate(GenerateOptions.builder().model("openai/gpt-4o-mini") + .system(RAG_SYSTEM_PROMPT).prompt(question).docs(docs) + .config(GenerationConfig.builder().temperature(0.3).build()).build()); + + return modelResponse.getText(); + }); + + // Define a generic indexing flow that accepts documents + Flow, String, Void> indexDocumentsFlow = genkit.defineFlow("indexDocuments", + (Class>) (Class) List.class, String.class, (ctx, texts) -> { + List documents = texts.stream().map(Document::fromText).collect(Collectors.toList()); + + genkit.index("devLocalVectorStore/world-capitals", documents); + return "Indexed " + documents.size() + " documents"; + }); + + // Define a simple retrieval-only flow + Flow, Void> retrieveDocumentsFlow = genkit.defineFlow("retrieveDocuments", Map.class, + (Class>) (Class) List.class, (ctx, input) -> { + String query = (String) input.get("query"); + String store = (String) input.getOrDefault("store", "world-capitals"); + + List docs = genkit.retrieve("devLocalVectorStore/" + store, query); + + return docs.stream().map(Document::text).collect(Collectors.toList()); + }); + + logger.info("=".repeat(60)); + logger.info("Genkit RAG Sample Started"); + logger.info("=".repeat(60)); + logger.info(""); + logger.info("Available flows:"); + logger.info(""); + logger.info("Indexing flows (run these first to populate the vector stores):"); + logger.info(" - indexWorldCapitals: Index world capitals data"); + logger.info(" - indexDogBreeds: Index dog breeds data"); + logger.info(" - indexCoffeeFacts: Index coffee facts data"); + logger.info(" - indexDocuments: Index custom documents"); + logger.info(""); + logger.info("RAG Query flows:"); + logger.info(" - askAboutCapitals: Ask questions about world capitals"); + logger.info(" - askAboutDogs: Ask questions about dog breeds"); + logger.info(" - askAboutCoffee: Ask questions about coffee"); + logger.info(""); + logger.info("Retrieval flow:"); + logger.info(" - retrieveDocuments: Retrieve documents without generation"); + logger.info(""); + logger.info("Example calls:"); + logger.info(""); + logger.info("1. First, index the data:"); + logger.info(" curl -X POST http://localhost:8080/indexWorldCapitals"); + logger.info(" curl -X POST http://localhost:8080/indexDogBreeds"); + logger.info(" curl -X POST http://localhost:8080/indexCoffeeFacts"); + logger.info(""); + logger.info("2. Then query:"); + logger.info(" curl -X POST http://localhost:8080/askAboutCapitals \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '\"What is the capital of France?\"'"); + logger.info(""); + logger.info(" curl -X POST http://localhost:8080/askAboutDogs \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '\"What are good family dogs?\"'"); + logger.info(""); + logger.info(" curl -X POST http://localhost:8080/askAboutCoffee \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '\"How is espresso made?\"'"); + logger.info(""); + logger.info("3. Retrieve without generation:"); + logger.info(" curl -X POST http://localhost:8080/retrieveDocuments \\"); + logger.info(" -H 'Content-Type: application/json' \\"); + logger.info(" -d '{\"query\":\"France\",\"store\":\"world-capitals\",\"k\":2}'"); + logger.info(""); + logger.info("Reflection API: http://localhost:3100"); + logger.info("HTTP API: http://localhost:8080"); + logger.info("=".repeat(60)); + + // Start the server and block - keeps the application running + jetty.start(); + } + + /** + * Loads documents from a text resource file. Each paragraph (separated by blank + * lines) becomes a separate document. + */ + private static List loadDocumentsFromResource(String resourcePath) { + List documents = new ArrayList<>(); + + try (InputStream is = RagSample.class.getResourceAsStream(resourcePath)) { + if (is == null) { + throw new RuntimeException("Resource not found: " + resourcePath); + } + + BufferedReader reader = new BufferedReader(new InputStreamReader(is)); + StringBuilder paragraph = new StringBuilder(); + String line; + + while ((line = reader.readLine()) != null) { + if (line.trim().isEmpty()) { + if (paragraph.length() > 0) { + documents.add(Document.fromText(paragraph.toString().trim())); + paragraph = new StringBuilder(); + } + } else { + if (paragraph.length() > 0) { + paragraph.append(" "); + } + paragraph.append(line.trim()); + } + } + + // Don't forget the last paragraph + if (paragraph.length() > 0) { + documents.add(Document.fromText(paragraph.toString().trim())); + } + + } catch (Exception e) { + throw new RuntimeException("Failed to load documents from " + resourcePath, e); + } + + logger.info("Loaded {} documents from {}", documents.size(), resourcePath); + return documents; + } +} diff --git a/java/samples/rag/src/main/resources/data/coffee-facts.txt b/java/samples/rag/src/main/resources/data/coffee-facts.txt new file mode 100644 index 0000000000..801e138e5a --- /dev/null +++ b/java/samples/rag/src/main/resources/data/coffee-facts.txt @@ -0,0 +1,19 @@ +Coffee originated in Ethiopia, where legend says a goat herder named Kaldi discovered the energizing effects of coffee beans after his goats ate them. + +Espresso is a concentrated form of coffee made by forcing hot water through finely-ground coffee beans. A single shot is typically 30ml and contains about 63mg of caffeine. + +The cappuccino is an Italian coffee drink made with equal parts espresso, steamed milk, and milk foam. It's traditionally consumed only in the morning in Italy. + +Arabica and Robusta are the two main species of coffee beans. Arabica is considered higher quality with more complex flavors, while Robusta has more caffeine and a stronger, more bitter taste. + +Cold brew coffee is made by steeping coarsely ground coffee in cold water for 12-24 hours. It results in a smooth, less acidic coffee concentrate. + +The latte (caffè latte) consists of espresso with steamed milk and a small amount of foam. Latte art, created by pouring steamed milk into espresso, has become a specialty skill. + +Turkish coffee is prepared by boiling finely ground coffee in a special pot called a cezve. It's served unfiltered, with the grounds settling at the bottom of the cup. + +The French press (or press pot) is a manual brewing method that involves steeping coffee grounds in hot water and then pressing them with a plunger. + +Decaffeinated coffee has had at least 97% of its caffeine removed. The decaffeination process typically uses water, organic solvents, or carbon dioxide. + +The global coffee industry is worth over $450 billion annually. Brazil is the world's largest coffee producer, followed by Vietnam, Colombia, and Indonesia. diff --git a/java/samples/rag/src/main/resources/data/dog-breeds.txt b/java/samples/rag/src/main/resources/data/dog-breeds.txt new file mode 100644 index 0000000000..f12a23dbd2 --- /dev/null +++ b/java/samples/rag/src/main/resources/data/dog-breeds.txt @@ -0,0 +1,19 @@ +The golden retriever is a friendly and intelligent dog breed. They are known for their golden coat and gentle temperament. Golden retrievers are excellent family dogs and are often used as therapy and service dogs. + +The German shepherd is a large, intelligent working dog. They are commonly used as police, military, and search and rescue dogs. German shepherds are loyal and protective of their families. + +The Labrador retriever is one of the most popular dog breeds. They come in three colors: black, yellow, and chocolate. Labs are friendly, outgoing, and great with children. + +The French bulldog is a small, muscular dog with a flat face. They are known for their bat-like ears and playful personality. French bulldogs are excellent apartment dogs due to their small size. + +The poodle is an intelligent and elegant dog breed. They come in three sizes: standard, miniature, and toy. Poodles are hypoallergenic and known for their curly coat. + +The beagle is a small to medium-sized hound dog. They have an excellent sense of smell and are often used as detection dogs. Beagles are friendly and good with children. + +The bulldog (English bulldog) is known for its wrinkled face and stocky build. Despite their tough appearance, bulldogs are gentle and affectionate. They are good with families but can be stubborn. + +The Yorkshire terrier (Yorkie) is a small, feisty terrier. They have a long, silky coat and a big personality. Yorkies make excellent companions but can be yappy. + +The dachshund is known for its long body and short legs. Originally bred to hunt badgers, they are courageous and loyal. Dachshunds come in standard and miniature sizes. + +The boxer is a medium to large dog with a muscular build. They are energetic, playful, and great with children. Boxers are protective of their families and make good guard dogs. diff --git a/java/samples/rag/src/main/resources/data/world-capitals.txt b/java/samples/rag/src/main/resources/data/world-capitals.txt new file mode 100644 index 0000000000..48fcd681de --- /dev/null +++ b/java/samples/rag/src/main/resources/data/world-capitals.txt @@ -0,0 +1,19 @@ +Paris is the capital of France. It is known for the Eiffel Tower, the Louvre Museum, and its rich cultural heritage. The city is often called the "City of Light" (La Ville Lumière). + +London is the capital of the United Kingdom. It is famous for Big Ben, the Tower of London, and Buckingham Palace. The River Thames flows through the heart of the city. + +Tokyo is the capital of Japan. It is one of the most populous metropolitan areas in the world. Famous landmarks include the Tokyo Tower, Senso-ji Temple, and the Imperial Palace. + +Berlin is the capital of Germany. It is known for its history, art scene, and modern architecture. The Brandenburg Gate and the Berlin Wall Memorial are major attractions. + +Rome is the capital of Italy. It was the center of the Roman Empire and is home to the Colosseum, Vatican City, and the Trevi Fountain. + +Madrid is the capital of Spain. It is famous for its art museums, including the Prado Museum and Reina Sofia. The city is also known for its vibrant nightlife and tapas culture. + +Washington D.C. is the capital of the United States. It houses the White House, the Capitol Building, and numerous Smithsonian museums. + +Canberra is the capital of Australia. Unlike Sydney and Melbourne, it was purpose-built to serve as the national capital. It is home to many national monuments and institutions. + +Ottawa is the capital of Canada. Located in Ontario, it is known for Parliament Hill, the Rideau Canal, and the National Gallery of Canada. + +Brasília is the capital of Brazil. It was planned and built in the 1950s to serve as the new capital, replacing Rio de Janeiro. The city is known for its modernist architecture. diff --git a/java/samples/rag/src/main/resources/logback.xml b/java/samples/rag/src/main/resources/logback.xml new file mode 100644 index 0000000000..fe98c37a85 --- /dev/null +++ b/java/samples/rag/src/main/resources/logback.xml @@ -0,0 +1,26 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + +