Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tool): Add ToolExecutionEligibilityPredicate interface #2585

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
Expand Down Expand Up @@ -79,12 +81,15 @@ public AnthropicApi anthropicApi(AnthropicConnectionProperties connectionPropert
public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, AnthropicChatProperties chatProperties,
RetryTemplate retryTemplate, ToolCallingManager toolCallingManager,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> anthropicToolExecutionEligibilityPredicate) {

var chatModel = AnthropicChatModel.builder()
.anthropicApi(anthropicApi)
.defaultOptions(chatProperties.getOptions())
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(anthropicToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.retryTemplate(retryTemplate)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
Expand Down Expand Up @@ -59,12 +61,15 @@ public class AzureOpenAiChatAutoConfiguration {
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder,
AzureOpenAiChatProperties chatProperties, ToolCallingManager toolCallingManager,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> azureOpenAiToolExecutionEligibilityPredicate) {

var chatModel = AzureOpenAiChatModel.builder()
.openAIClientBuilder(openAIClientBuilder)
.defaultOptions(chatProperties.getOptions())
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(azureOpenAiToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.build();
observationConvention.ifAvailable(chatModel::setObservationConvention);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
Expand Down Expand Up @@ -70,7 +72,8 @@ public BedrockProxyChatModel bedrockProxyChatModel(AwsCredentialsProvider creden
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<BedrockRuntimeClient> bedrockRuntimeClient,
ObjectProvider<BedrockRuntimeAsyncClient> bedrockRuntimeAsyncClient) {
ObjectProvider<BedrockRuntimeAsyncClient> bedrockRuntimeAsyncClient,
ObjectProvider<ToolExecutionEligibilityPredicate> bedrockToolExecutionEligibilityPredicate) {

var chatModel = BedrockProxyChatModel.builder()
.credentialsProvider(credentialsProvider)
Expand All @@ -79,6 +82,8 @@ public BedrockProxyChatModel bedrockProxyChatModel(AwsCredentialsProvider creden
.defaultOptions(chatProperties.getOptions())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(bedrockToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.bedrockRuntimeClient(bedrockRuntimeClient.getIfAvailable())
.bedrockRuntimeAsyncClient(bedrockRuntimeAsyncClient.getIfAvailable())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.beans.factory.ObjectProvider;
Expand Down Expand Up @@ -69,7 +71,8 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro
MistralAiChatProperties chatProperties, ObjectProvider<RestClient.Builder> restClientBuilderProvider,
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> mistralAiToolExecutionEligibilityPredicate) {

var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(),
chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
Expand All @@ -79,6 +82,8 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro
.mistralAiApi(mistralAiApi)
.defaultOptions(chatProperties.getOptions())
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(mistralAiToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.retryTemplate(retryTemplate)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
Expand Down Expand Up @@ -64,14 +66,17 @@ public class OllamaChatAutoConfiguration {
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> ollamaToolExecutionEligibilityPredicate) {
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
: PullModelStrategy.NEVER;

var chatModel = OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(properties.getOptions())
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(ollamaToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.modelManagementOptions(
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.api.OpenAiApi;
Expand Down Expand Up @@ -72,7 +74,8 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ToolCallingManager toolCallingManager,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> openAiToolExecutionEligibilityPredicate) {

var openAiApi = openAiApi(chatProperties, commonProperties,
restClientBuilderProvider.getIfAvailable(RestClient::builder),
Expand All @@ -82,6 +85,8 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti
.openAiApi(openAiApi)
.defaultOptions(chatProperties.getOptions())
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(openAiToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.retryTemplate(retryTemplate)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.model.SpringAIModelProperties;
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
Expand Down Expand Up @@ -93,12 +95,15 @@ public VertexAI vertexAi(VertexAiGeminiConnectionProperties connectionProperties
public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGeminiChatProperties chatProperties,
ToolCallingManager toolCallingManager, ApplicationContext context, RetryTemplate retryTemplate,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> vertexAiGeminiToolExecutionEligibilityPredicate) {

VertexAiGeminiChatModel chatModel = VertexAiGeminiChatModel.builder()
.vertexAI(vertexAi)
.defaultOptions(chatProperties.getOptions())
.toolCallingManager(toolCallingManager)
.toolExecutionEligibilityPredicate(vertexAiGeminiToolExecutionEligibilityPredicate
.getIfUnique(() -> new DefaultToolExecutionEligibilityPredicate()))
.retryTemplate(retryTemplate)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class FunctionCallWithPromptFunctionIT {
void functionCallTest() {
this.contextRunner
.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model="
+ VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH_LIGHT.getValue())
+ VertexAiGeminiChatModel.ChatModel.GEMINI_2_5_PRO.getValue())
.run(context -> {

VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -124,6 +126,12 @@ public class AnthropicChatModel implements ChatModel {

private final ToolCallingManager toolCallingManager;

/**
 * The predicate used to determine whether the tool calls contained in a chat
 * response should be executed, given the prompt options.
 */
private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

/**
* Conventions to use for generating observations.
*/
Expand All @@ -132,18 +140,27 @@ public class AnthropicChatModel implements ChatModel {
/**
 * Creates a chat model that uses a {@link DefaultToolExecutionEligibilityPredicate}.
 * <p>
 * Delegates to the constructor that additionally accepts a
 * {@link ToolExecutionEligibilityPredicate}; that constructor rejects {@code null}
 * for every argument.
 * @param anthropicApi the {@link AnthropicApi} to use, must not be {@code null}
 * @param defaultOptions the default chat options, must not be {@code null}
 * @param toolCallingManager the tool calling manager, must not be {@code null}
 * @param retryTemplate the retry template, must not be {@code null}
 * @param observationRegistry the observation registry, must not be {@code null}
 */
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
// Keeps the pre-existing five-argument signature usable by supplying the default predicate.
this(anthropicApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry,
new DefaultToolExecutionEligibilityPredicate());
}

/**
 * Creates a chat model from fully specified collaborators.
 * <p>
 * Every argument is required; the first {@code null} argument encountered (in
 * declaration order, with the predicate checked last) causes
 * {@link Assert#notNull(Object, String)} to fail with a message naming it.
 * @param anthropicApi the {@link AnthropicApi} to use
 * @param defaultOptions the default chat options
 * @param toolCallingManager the manager that executes tool calls
 * @param retryTemplate the retry template
 * @param observationRegistry the observation registry
 * @param toolExecutionEligibilityPredicate the predicate deciding whether the tool
 * calls in a response are executed
 */
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
        ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
        ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {

    // Each collaborator is validated right before it is stored; the check order
    // is the same as the declaration order, with the predicate validated last.
    Assert.notNull(anthropicApi, "anthropicApi cannot be null");
    this.anthropicApi = anthropicApi;

    Assert.notNull(defaultOptions, "defaultOptions cannot be null");
    this.defaultOptions = defaultOptions;

    Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
    this.toolCallingManager = toolCallingManager;

    Assert.notNull(retryTemplate, "retryTemplate cannot be null");
    this.retryTemplate = retryTemplate;

    Assert.notNull(observationRegistry, "observationRegistry cannot be null");
    this.observationRegistry = observationRegistry;

    Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
    this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
}

@Override
Expand Down Expand Up @@ -184,8 +201,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
return chatResponse;
});

if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null
&& response.hasToolCalls()) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
Expand Down Expand Up @@ -243,7 +259,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse.hasToolCalls() && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
Expand Down Expand Up @@ -540,6 +556,8 @@ public static final class Builder {

private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();

private Builder() {
}

Expand All @@ -563,6 +581,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
return this;
}

/**
 * Sets the predicate that decides whether the tool calls in a chat response are
 * executed by the model being built.
 * <p>
 * No validation happens here; the {@code AnthropicChatModel} constructor that
 * receives the predicate rejects {@code null}.
 * @param toolExecutionEligibilityPredicate the predicate to use
 * @return this builder
 */
// NOTE(review): in the visible part of build(), only the DEFAULT_TOOL_CALLING_MANAGER
// branch passes this predicate to the constructor; the other constructor call ends with
// "observationRegistry)" — confirm it forwards the predicate too, otherwise a custom
// predicate is dropped whenever a tool calling manager is set.
public Builder toolExecutionEligibilityPredicate(
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
return this;
}

public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
Expand All @@ -574,7 +598,7 @@ public AnthropicChatModel build() {
observationRegistry);
}
return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
observationRegistry);
observationRegistry, toolExecutionEligibilityPredicate);
}

}
Expand Down
Loading