diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java index dd4ae062bf..5427e35f8f 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java @@ -27,7 +27,9 @@ import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.ImportAutoConfiguration; @@ -79,12 +81,15 @@ public AnthropicApi anthropicApi(AnthropicConnectionProperties connectionPropert public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, AnthropicChatProperties chatProperties, RetryTemplate retryTemplate, ToolCallingManager toolCallingManager, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, + ObjectProvider anthropicToolExecutionEligibilityPredicate) { var chatModel = AnthropicChatModel.builder() .anthropicApi(anthropicApi) .defaultOptions(chatProperties.getOptions()) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(anthropicToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .retryTemplate(retryTemplate) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .build(); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiChatAutoConfiguration.java index ce067c58e3..b7c82f057e 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiChatAutoConfiguration.java @@ -25,7 +25,9 @@ import org.springframework.ai.model.SpringAIModels; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import 
org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -59,12 +61,15 @@ public class AzureOpenAiChatAutoConfiguration { public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatProperties chatProperties, ToolCallingManager toolCallingManager, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, + ObjectProvider azureOpenAiToolExecutionEligibilityPredicate) { var chatModel = AzureOpenAiChatModel.builder() .openAIClientBuilder(openAIClientBuilder) .defaultOptions(chatProperties.getOptions()) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(azureOpenAiToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .build(); observationConvention.ifAvailable(chatModel::setObservationConvention); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatAutoConfiguration.java index 9aed1da08b..f5aaff363c 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatAutoConfiguration.java @@ -31,7 +31,9 @@ import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.ImportAutoConfiguration; @@ -70,7 +72,8 @@ public BedrockProxyChatModel bedrockProxyChatModel(AwsCredentialsProvider creden ObjectProvider observationRegistry, ObjectProvider observationConvention, ObjectProvider bedrockRuntimeClient, - ObjectProvider bedrockRuntimeAsyncClient) { + ObjectProvider bedrockRuntimeAsyncClient, + ObjectProvider bedrockToolExecutionEligibilityPredicate) { var chatModel = BedrockProxyChatModel.builder() .credentialsProvider(credentialsProvider) @@ -79,6 +82,8 @@ public BedrockProxyChatModel bedrockProxyChatModel(AwsCredentialsProvider creden .defaultOptions(chatProperties.getOptions()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(bedrockToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .bedrockRuntimeClient(bedrockRuntimeClient.getIfAvailable()) 
.bedrockRuntimeAsyncClient(bedrockRuntimeAsyncClient.getIfAvailable()) .build(); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java index c67a438560..cc2dc41d55 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java @@ -25,7 +25,9 @@ import org.springframework.ai.model.SpringAIModels; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; import org.springframework.beans.factory.ObjectProvider; @@ -69,7 +71,8 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro MistralAiChatProperties chatProperties, ObjectProvider restClientBuilderProvider, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, + ObjectProvider mistralAiToolExecutionEligibilityPredicate) { var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(), chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), @@ -79,6 +82,8 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro .mistralAiApi(mistralAiApi) .defaultOptions(chatProperties.getOptions()) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(mistralAiToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .retryTemplate(retryTemplate) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .build(); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java index 6df931d119..ac097a6b24 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java @@ -23,7 +23,9 @@ import org.springframework.ai.model.SpringAIModels; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallbackResolver; +import 
org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaApi; @@ -64,7 +66,8 @@ public class OllamaChatAutoConfiguration { public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties, OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, + ObjectProvider ollamaToolExecutionEligibilityPredicate) { var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER; @@ -72,6 +75,8 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties .ollamaApi(ollamaApi) .defaultOptions(properties.getOptions()) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(ollamaToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .modelManagementOptions( new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(), diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java index a14f5e1c60..0bdb150195 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java @@ -24,7 +24,9 @@ import org.springframework.ai.model.SpringAIModels; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.api.OpenAiApi; @@ -72,7 +74,8 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti ObjectProvider webClientBuilderProvider, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, + ObjectProvider openAiToolExecutionEligibilityPredicate) { var openAiApi = openAiApi(chatProperties, commonProperties, restClientBuilderProvider.getIfAvailable(RestClient::builder), @@ -82,6 +85,8 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti .openAiApi(openAiApi) 
.defaultOptions(chatProperties.getOptions()) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(openAiToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .retryTemplate(retryTemplate) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .build(); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java index a7e2e9590e..09f91c6a94 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java @@ -25,7 +25,9 @@ import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.SpringAIModelProperties; import org.springframework.ai.model.SpringAIModels; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; @@ -93,12 +95,15 @@ public VertexAI vertexAi(VertexAiGeminiConnectionProperties connectionProperties public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGeminiChatProperties chatProperties, ToolCallingManager toolCallingManager, ApplicationContext context, RetryTemplate retryTemplate, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, + ObjectProvider vertexAiGeminiToolExecutionEligibilityPredicate) { VertexAiGeminiChatModel chatModel = VertexAiGeminiChatModel.builder() .vertexAI(vertexAi) .defaultOptions(chatProperties.getOptions()) .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(vertexAiGeminiToolExecutionEligibilityPredicate + .getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate())) .retryTemplate(retryTemplate) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .build(); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/test/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/tool/FunctionCallWithPromptFunctionIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/test/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/tool/FunctionCallWithPromptFunctionIT.java index 62870fdc48..0725b52437 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/test/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/tool/FunctionCallWithPromptFunctionIT.java +++ 
b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/test/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/tool/FunctionCallWithPromptFunctionIT.java @@ -50,7 +50,7 @@ public class FunctionCallWithPromptFunctionIT { void functionCallTest() { this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" - + VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH_LIGHT.getValue()) + + VertexAiGeminiChatModel.ChatModel.GEMINI_2_5_PRO.getValue()) .run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index a513c25aba..a8a78c9d56 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -64,8 +64,10 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.definition.ToolDefinition; @@ -124,6 +126,12 @@ public class AnthropicChatModel implements ChatModel { private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + /** * Conventions to use for generating observations. 
*/ @@ -132,18 +140,27 @@ public class AnthropicChatModel implements ChatModel { public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + this(anthropicApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(anthropicApi, "anthropicApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); Assert.notNull(retryTemplate, "retryTemplate cannot be null"); Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); this.anthropicApi = anthropicApi; this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } @Override @@ -184,8 +201,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null - && response.hasToolCalls()) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. 
@@ -243,7 +259,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse.hasToolCalls() && chatResponse.hasFinishReasons(Set.of("tool_use"))) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -540,6 +556,8 @@ public static final class Builder { private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private Builder() { } @@ -563,6 +581,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + public Builder observationRegistry(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; return this; @@ -574,7 +598,7 @@ public AnthropicChatModel build() { observationRegistry); } return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 2405b9c859..3c3a6cf505 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -85,8 +85,10 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.tool.definition.ToolDefinition; @@ -156,17 +158,32 @@ public class AzureOpenAiChatModel implements ChatModel { */ private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. 
+ */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry) { + this(openAIClientBuilder, defaultOptions, toolCallingManager, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions, + ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); this.openAIClient = openAIClientBuilder.buildClient(); this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient(); this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata, @@ -244,8 +261,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null - && response.hasToolCalls()) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. 
@@ -352,8 +368,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); return chatResponseFlux.flatMap(chatResponse -> { - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) - && chatResponse.hasToolCalls()) { + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -919,6 +934,8 @@ public static class Builder { private ToolCallingManager toolCallingManager; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; private Builder() { @@ -939,6 +956,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + public Builder observationRegistry(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; return this; @@ -947,10 +970,10 @@ public Builder observationRegistry(ObservationRegistry observationRegistry) { public AzureOpenAiChatModel build() { if (toolCallingManager != null) { return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, toolCallingManager, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 71f5b4cf96..ff7222ff9d 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -94,8 +94,10 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.tool.definition.ToolDefinition; @@ -151,6 +153,12 @@ public class BedrockProxyChatModel implements ChatModel { private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + /** * Conventions to use for generating observations. 
*/ @@ -159,16 +167,26 @@ public class BedrockProxyChatModel implements ChatModel { public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) { + this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, defaultOptions, observationRegistry, toolCallingManager, + new DefaultToolExecutionEligibilityPredicate()); + } + + public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, + ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null"); Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null"); this.bedrockRuntimeClient = bedrockRuntimeClient; this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; this.defaultOptions = defaultOptions; this.observationRegistry = observationRegistry; this.toolCallingManager = toolCallingManager; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } private static ToolCallingChatOptions from(ChatOptions options) { @@ -221,8 +239,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon return response; }); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse != null - && chatResponse.hasToolCalls() + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); if (toolExecutionResult.returnDirect()) { @@ -654,8 +671,7 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) - && chatResponse.hasToolCalls() + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { // FIXME: bounded elastic needs to be used since tool calling @@ -756,6 +772,8 @@ public static final class Builder { private ToolCallingManager toolCallingManager; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private ToolCallingChatOptions defaultOptions = ToolCallingChatOptions.builder().build(); private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -774,6 +792,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) { Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null."); 
this.credentialsProvider = credentialsProvider; @@ -852,13 +876,13 @@ public BedrockProxyChatModel build() { if (this.toolCallingManager != null) { bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, - this.toolCallingManager); + this.toolCallingManager, this.toolExecutionEligibilityPredicate); } else { bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, - DEFAULT_TOOL_CALLING_MANAGER); + DEFAULT_TOOL_CALLING_MANAGER, this.toolExecutionEligibilityPredicate); } if (this.customObservationConvention != null) { diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index 7c21d30aad..5d73ae14c6 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -47,7 +47,7 @@ /** * @author Christian Tzolov */ -@Disabled +// @Disabled @SpringBootTest(classes = BedrockNovaChatClientIT.Config.class) @RequiresAwsCredentials public class BedrockNovaChatClientIT { diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 796deed002..7736e2a810 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -27,8 +27,11 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.tool.definition.ToolDefinition; import reactor.core.publisher.Flux; @@ -111,6 +114,12 @@ public class MistralAiChatModel implements ChatModel { private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + /** * Conventions to use for generating observations. 
*/ @@ -119,16 +128,25 @@ public class MistralAiChatModel implements ChatModel { public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + this(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(mistralAiApi, "mistralAiApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); Assert.notNull(retryTemplate, "retryTemplate cannot be null"); Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); this.mistralAiApi = mistralAiApi; this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { @@ -210,8 +228,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null - && response.hasToolCalls()) { + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. 
@@ -300,7 +317,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -530,6 +547,8 @@ public static class Builder { private ToolCallingManager toolCallingManager; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -552,6 +571,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + public Builder retryTemplate(RetryTemplate retryTemplate) { this.retryTemplate = retryTemplate; return this; @@ -565,10 +590,10 @@ public Builder observationRegistry(ObservationRegistry observationRegistry) { public MistralAiChatModel build() { if (toolCallingManager != null) { return new MistralAiChatModel(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } return new MistralAiChatModel(mistralAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 72a768258b..5ffae904ca 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -49,8 +49,10 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; @@ -117,20 +119,35 @@ public class OllamaChatModel implements ChatModel { private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. 
+ */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { + this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions, + new DefaultToolExecutionEligibilityPredicate()); + } + + public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, + ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); Assert.notNull(observationRegistry, "observationRegistry must not be null"); Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null"); this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.observationRegistry = observationRegistry; this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); } @@ -245,8 +262,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon }); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null - && response.hasToolCalls()) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. 
@@ -319,7 +335,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -514,6 +530,8 @@ public static final class Builder { private ToolCallingManager toolCallingManager; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults(); @@ -536,6 +554,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + public Builder observationRegistry(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; return this; @@ -549,10 +573,10 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti public OllamaChatModel build() { if (toolCallingManager != null) { return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager, - this.observationRegistry, this.modelManagementOptions); + this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate); } return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, - this.observationRegistry, this.modelManagementOptions); + this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 949bf69731..8c1ca4f2a8 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -58,8 +58,10 @@ import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; @@ -136,6 +138,12 @@ public class OpenAiChatModel implements ChatModel { private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + /** * Conventions to use for generating observations. 
*/ @@ -143,16 +151,25 @@ public class OpenAiChatModel implements ChatModel { public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + this(openAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(openAiApi, "openAiApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); Assert.notNull(retryTemplate, "retryTemplate cannot be null"); Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); this.openAiApi = openAiApi; this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } @Override @@ -221,8 +238,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons }); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null - && response.hasToolCalls()) { + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. 
@@ -345,7 +361,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { return Flux.defer(() -> { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -684,6 +700,8 @@ public static final class Builder { private ToolCallingManager toolCallingManager; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -706,6 +724,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + public Builder retryTemplate(RetryTemplate retryTemplate) { this.retryTemplate = retryTemplate; return this; @@ -719,10 +743,10 @@ public Builder observationRegistry(ObservationRegistry observationRegistry) { public OpenAiChatModel build() { if (toolCallingManager != null) { return new OpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } return new OpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java index bb5f3be554..c7f632a697 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java @@ -267,7 +267,7 @@ void functionCallTest(String modelName) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "mistral-small-latest", "mistral-large-latest" }) + @ValueSource(strings = { "mistral-large-latest" }) void streamFunctionCallTest(String modelName) { UserMessage userMessage = new UserMessage( diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 7ba15f4518..301ece9ab5 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -77,9 +77,11 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import 
org.springframework.ai.model.tool.LegacyToolCallingManager; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.definition.ToolDefinition; @@ -168,6 +170,12 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements */ private final ToolCallingManager toolCallingManager; + /** + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + /** * Conventions to use for generating observations. */ @@ -250,6 +258,24 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions opti public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + this(vertexAI, defaultOptions, toolCallingManager, retryTemplate, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + /** + * Creates a new instance of VertexAiGeminiChatModel. + * @param vertexAI the Vertex AI instance to use + * @param defaultOptions the default options to use + * @param toolCallingManager the tool calling manager to use. It is wrapped in a + * {@link VertexToolCallingManager} to ensure compatibility with Vertex AI's OpenAPI + * schema format. + * @param retryTemplate the retry template to use + * @param observationRegistry the observation registry to use + * @param toolExecutionEligibilityPredicate the tool execution eligibility predicate + */ + public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defaultOptions, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { super(null, VertexAiGeminiChatOptions.builder().build(), List.of()); @@ -258,12 +284,14 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa Assert.notNull(defaultOptions.getModel(), "VertexAiGeminiChatOptions.modelName must not be null"); Assert.notNull(retryTemplate, "RetryTemplate must not be null"); Assert.notNull(toolCallingManager, "ToolCallingManager must not be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "ToolExecutionEligibilityPredicate must not be null"); this.vertexAI = vertexAI; this.defaultOptions = defaultOptions; this.generationConfig = toGenerationConfig(defaultOptions); this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; // Wrap the provided tool calling manager in a VertexToolCallingManager to ensure // compatibility with Vertex AI's OpenAPI schema format. 
@@ -430,8 +458,7 @@ private ChatResponse internalCall(Prompt prompt) { return chatResponse; })); - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null - && response.hasToolCalls()) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -547,7 +574,7 @@ public Flux internalStream(Prompt prompt) { // @formatter:off Flux chatResponseFlux = chatResponse1.flatMap(response -> { - if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -839,7 +866,9 @@ public enum ChatModel implements ChatModelDescription { GEMINI_2_0_FLASH("gemini-2.0-flash"), - GEMINI_2_0_FLASH_LIGHT("gemini-2.0-flash-lite-preview-02-05"); + GEMINI_2_0_FLASH_LIGHT("gemini-2.0-flash-lite"), + + GEMINI_2_5_PRO("gemini-2.5-pro-exp-03-25"); public final String value; @@ -879,6 +908,8 @@ public static class Builder { private ToolCallingManager toolCallingManager; + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + private FunctionCallbackResolver functionCallbackResolver; private List toolFunctionCallbacks; @@ -905,6 +936,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { return this; } + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + @Deprecated public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { this.functionCallbackResolver = functionCallbackResolver; @@ -935,7 +972,7 @@ public VertexAiGeminiChatModel build() { "toolFunctionCallbacks cannot be set when toolCallingManager is set"); return new VertexAiGeminiChatModel(vertexAI, defaultOptions, toolCallingManager, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } if (functionCallbackResolver != null) { @@ -949,7 +986,7 @@ public VertexAiGeminiChatModel build() { } return new VertexAiGeminiChatModel(vertexAI, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, - observationRegistry); + observationRegistry, toolExecutionEligibilityPredicate); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java index 1b6e3e8a00..e2c42ac484 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java @@ -304,7 +304,7 @@ public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi) { return VertexAiGeminiChatModel.builder() .vertexAI(vertexAi) .defaultOptions(VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + 
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_5_PRO) .build()) .build(); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc index 71de37a335..63abe78c69 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc @@ -1090,7 +1090,22 @@ WARNING: Currently, the internal messages exchanged with the model regarding the === User-Controlled Tool Execution -There are cases where you'd rather control the tool execution lifecycle yourself. You can do so by setting the `internalToolExecutionEnabled` attribute of `ToolCallingChatOptions` to `false`. When you invoke a `ChatModel` with this option, the tool execution will be delegated to the caller, giving you full control over the tool execution lifecycle. It's your responsibility checking for tool calls in the `ChatResponse` and executing them using the `ToolCallingManager`. +There are cases where you'd rather control the tool execution lifecycle yourself. You can do so by setting the `internalToolExecutionEnabled` attribute of `ToolCallingChatOptions` to `false`. +Alternatively, you can implement your own `ToolExecutionEligibilityPredicate` to control tool execution eligibility. +The default predicate implementation looks like this: +[source,java] +---- +public class DefaultToolExecutionEligibilityPredicate implements ToolExecutionEligibilityPredicate { + + @Override + public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) { + return ToolCallingChatOptions.isInternalToolExecutionEnabled(promptOptions) && chatResponse != null + && chatResponse.hasToolCalls(); + } +} +---- + +When you invoke a `ChatModel` with internal tool execution disabled, the tool execution will be delegated to the caller, giving you full control over the tool execution lifecycle. It's your responsibility to check for tool calls in the `ChatResponse` and execute them using the `ToolCallingManager`. The following example demonstrates a minimal implementation of the user-controlled tool execution approach: diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicate.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicate.java new file mode 100644 index 0000000000..4cdc3ad8bb --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicate.java @@ -0,0 +1,36 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License.
+*/ +package org.springframework.ai.model.tool; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * Default implementation of {@link ToolExecutionEligibilityPredicate} that checks whether + * tool execution is enabled in the prompt options and if the chat response contains tool + * calls. + * + * @author Christian Tzolov + */ +public class DefaultToolExecutionEligibilityPredicate implements ToolExecutionEligibilityPredicate { + + @Override + public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) { + return ToolCallingChatOptions.isInternalToolExecutionEnabled(promptOptions) && chatResponse != null + && chatResponse.hasToolCalls(); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java new file mode 100644 index 0000000000..171f94aa03 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java @@ -0,0 +1,80 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.model.tool; + +import java.util.function.Function; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.util.Assert; + +/** + * Interface for determining when tool execution should be performed based on model + * responses. + * + * @author Christian Tzolov + */ +public interface ToolExecutionEligibilityChecker extends Function { + + /** + * Determines if tool execution should be performed based on the prompt options and + * chat response. + * @param promptOptions The options from the prompt + * @param chatResponse The response from the chat model + * @return true if tool execution should be performed, false otherwise + */ + default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse) { + Assert.notNull(promptOptions, "promptOptions cannot be null"); + Assert.notNull(chatResponse, "chatResponse cannot be null"); + return this.isInternalToolExecutionEnabled(promptOptions) && this.isToolCallResponse(chatResponse); + } + + /** + * Determines if the response is a tool call message response. + * @param chatResponse The response from the chat model call + * @return true if the response is a tool call message response, false otherwise + */ + default boolean isToolCallResponse(ChatResponse chatResponse) { + Assert.notNull(chatResponse, "chatResponse cannot be null"); + return apply(chatResponse); + } + + /** + * Determines if tool execution should be performed by Spring AI or by the client.
+ * @param chatOptions The options from the chat + * @return true if tool execution should be performed by Spring AI, false if it should + * be performed by the client + */ + default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) { + + Assert.notNull(chatOptions, "chatOptions cannot be null"); + boolean internalToolExecutionEnabled; + if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions + && toolCallingChatOptions.isInternalToolExecutionEnabled() != null) { + internalToolExecutionEnabled = Boolean.TRUE.equals(toolCallingChatOptions.isInternalToolExecutionEnabled()); + } + else if (chatOptions instanceof FunctionCallingOptions functionCallingOptions + && functionCallingOptions.getProxyToolCalls() != null) { + internalToolExecutionEnabled = Boolean.TRUE.equals(!functionCallingOptions.getProxyToolCalls()); + } + else { + internalToolExecutionEnabled = true; + } + return internalToolExecutionEnabled; + } + +} \ No newline at end of file diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java new file mode 100644 index 0000000000..cb8244afd8 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java @@ -0,0 +1,45 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.model.tool; + +import java.util.function.BiPredicate; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.util.Assert; + +/** + * Interface for determining when tool execution should be performed based on model + * responses. + * + * @author Christian Tzolov + */ +public interface ToolExecutionEligibilityPredicate extends BiPredicate { + + /** + * Determines if tool execution should be performed based on the prompt options and + * chat response. 
+ * @param promptOptions The options from the prompt + * @param chatResponse The response from the chat model + * @return true if tool execution should be performed, false otherwise + */ + default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse) { + Assert.notNull(promptOptions, "promptOptions cannot be null"); + Assert.notNull(chatResponse, "chatResponse cannot be null"); + return test(promptOptions, chatResponse); + } + +} \ No newline at end of file diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java new file mode 100644 index 0000000000..44ca434b12 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.tool; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link DefaultToolExecutionEligibilityPredicate}. 
+ * + * @author Christian Tzolov + */ +class DefaultToolExecutionEligibilityPredicateTests { + + private final DefaultToolExecutionEligibilityPredicate predicate = new DefaultToolExecutionEligibilityPredicate(); + + @Test + void whenToolExecutionEnabledAndHasToolCalls() { + // Create a ToolCallingChatOptions with internal tool execution enabled + ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(true).build(); + + // Create a ChatResponse with tool calls + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); + AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isTrue(); + } + + @Test + void whenToolExecutionEnabledAndNoToolCalls() { + // Create a ToolCallingChatOptions with internal tool execution enabled + ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(true).build(); + + // Create a ChatResponse without tool calls + AssistantMessage assistantMessage = new AssistantMessage("test"); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isFalse(); + } + + @Test + void whenToolExecutionDisabledAndHasToolCalls() { + // Create a ToolCallingChatOptions with internal tool execution disabled + ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(false).build(); + + // Create a ChatResponse with tool calls + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); + AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isFalse(); + } + + @Test + void whenToolExecutionDisabledAndNoToolCalls() { + // Create a ToolCallingChatOptions with internal tool execution disabled + ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(false).build(); + + // Create a ChatResponse without tool calls + AssistantMessage assistantMessage = new AssistantMessage("test"); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isFalse(); + } + + @Test + void whenFunctionCallingOptionsAndToolExecutionEnabled() { + // Create a FunctionCallingOptions with proxy tool calls disabled (which means + // internal tool execution is enabled) + FunctionCallingOptions options = FunctionCallingOptions.builder().proxyToolCalls(false).build(); + + // Create a ChatResponse with tool calls + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); + AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isTrue(); + } + + @Test + void 
whenFunctionCallingOptionsAndToolExecutionDisabled() { + // Create a FunctionCallingOptions with proxy tool calls enabled (which means + // internal tool execution is disabled) + FunctionCallingOptions options = FunctionCallingOptions.builder().proxyToolCalls(true).build(); + + // Create a ChatResponse with tool calls + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); + AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isFalse(); + } + + @Test + void whenRegularChatOptionsAndHasToolCalls() { + // Create regular ChatOptions (not ToolCallingChatOptions or + // FunctionCallingOptions) + ChatOptions options = ChatOptions.builder().build(); + + // Create a ChatResponse with tool calls + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); + AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); + + // Test the predicate - should use default value (true) for internal tool + // execution + boolean result = predicate.test(options, chatResponse); + assertThat(result).isTrue(); + } + + @Test + void whenNullChatResponse() { + // Create a ToolCallingChatOptions with internal tool execution enabled + ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(true).build(); + + // Test the predicate with null ChatResponse + boolean result = predicate.test(options, null); + assertThat(result).isFalse(); + } + + @Test + void whenEmptyGenerationsList() { + // Create a ToolCallingChatOptions with internal tool execution enabled + ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(true).build(); + + // Create a ChatResponse with empty generations list + ChatResponse chatResponse = new ChatResponse(List.of()); + + // Test the predicate + boolean result = predicate.test(options, chatResponse); + assertThat(result).isFalse(); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java new file mode 100644 index 0000000000..124b0e8f68 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.model.tool; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ToolExecutionEligibilityPredicate}. + * + * @author Christian Tzolov + */ +class ToolExecutionEligibilityPredicateTests { + + @Test + void whenIsToolExecutionRequiredWithNullPromptOptions() { + ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate(); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test")))); + + assertThatThrownBy(() -> predicate.isToolExecutionRequired(null, chatResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("promptOptions cannot be null"); + } + + @Test + void whenIsToolExecutionRequiredWithNullChatResponse() { + ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate(); + ChatOptions promptOptions = ChatOptions.builder().build(); + + assertThatThrownBy(() -> predicate.isToolExecutionRequired(promptOptions, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("chatResponse cannot be null"); + } + + @Test + void whenIsToolExecutionRequiredWithValidInputs() { + ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate(); + ChatOptions promptOptions = ChatOptions.builder().build(); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test")))); + + boolean result = predicate.isToolExecutionRequired(promptOptions, chatResponse); + assertThat(result).isTrue(); + } + + @Test + void whenTestMethodCalledDirectly() { + ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate(); + ChatOptions promptOptions = ChatOptions.builder().build(); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test")))); + + boolean result = predicate.test(promptOptions, chatResponse); + assertThat(result).isTrue(); + } + + /** + * Test implementation of {@link ToolExecutionEligibilityPredicate} that always + * returns true. + */ + private static class TestToolExecutionEligibilityPredicate implements ToolExecutionEligibilityPredicate { + + @Override + public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) { + return true; + } + + } + +}
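Beyond the default predicate shown in the docs hunk above, the new builder hook makes it possible to plug in a custom eligibility policy. The following is a minimal sketch of what that could look like for the Vertex AI Gemini model touched in this change; the class name `SingleGenerationToolExecutionPredicate`, the single-generation guard, and the `buildChatModel` helper are illustrative assumptions rather than part of this patch, and the `VertexAI` client is assumed to be configured elsewhere.

[source,java]
----
import com.google.cloud.vertexai.VertexAI;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;

/**
 * Illustrative predicate: behaves like the default one, but (arbitrarily) only
 * lets Spring AI execute tools when the response carries a single generation.
 */
public class SingleGenerationToolExecutionPredicate implements ToolExecutionEligibilityPredicate {

	@Override
	public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) {
		// Same checks as DefaultToolExecutionEligibilityPredicate, plus a
		// hypothetical extra guard on the number of generations.
		return ToolCallingChatOptions.isInternalToolExecutionEnabled(promptOptions)
				&& chatResponse != null
				&& chatResponse.hasToolCalls()
				&& chatResponse.getResults().size() == 1;
	}

	/**
	 * Hypothetical helper showing how a custom predicate is wired through the
	 * builder method added in this change.
	 */
	public static VertexAiGeminiChatModel buildChatModel(VertexAI vertexAI) {
		return VertexAiGeminiChatModel.builder()
			.vertexAI(vertexAI)
			.defaultOptions(VertexAiGeminiChatOptions.builder()
				.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH)
				.build())
			.toolExecutionEligibilityPredicate(new SingleGenerationToolExecutionPredicate())
			.build();
	}

}
----

If no predicate is supplied, the builder keeps the `DefaultToolExecutionEligibilityPredicate` it initializes by default, so existing callers are unaffected.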