Skip to content

Commit 72a271c

Browse files
committed
feat(tool): Add ToolExecutionEligibilityChecker interface
Introduce a new ToolExecutionEligibilityChecker interface to provide a more flexible way to determine when tool execution should be performed based on model responses. This abstraction replaces the hardcoded logic previously scattered across the codebase. - Adds a new ToolExecutionEligibilityChecker interface in spring-ai-core - Integrates the checker into OpenAiChatModel with appropriate defaults - Updates OpenAiChatAutoConfiguration to support the new interface - Provides a default implementation that maintains backward compatibility Signed-off-by: Christian Tzolov <[email protected]>
1 parent 8329402 commit 72a271c

File tree

3 files changed

+133
-23
lines changed

3 files changed

+133
-23
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java

+27-18
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
2626
import org.springframework.ai.model.function.FunctionCallbackResolver;
2727
import org.springframework.ai.model.tool.ToolCallingManager;
28+
import org.springframework.ai.model.tool.ToolExecutionEligibilityChecker;
2829
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
2930
import org.springframework.ai.openai.OpenAiChatModel;
3031
import org.springframework.ai.openai.api.OpenAiApi;
@@ -59,8 +60,7 @@
5960
SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class })
6061
@ConditionalOnClass(OpenAiApi.class)
6162
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class })
62-
@ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.OPENAI,
63-
matchIfMissing = true)
63+
@ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.OPENAI, matchIfMissing = true)
6464
@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class,
6565
WebClientAutoConfiguration.class, ToolCallingAutoConfiguration.class })
6666
public class OpenAiChatAutoConfiguration {
@@ -72,25 +72,34 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti
7272
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ToolCallingManager toolCallingManager,
7373
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
7474
ObjectProvider<ObservationRegistry> observationRegistry,
75-
ObjectProvider<ChatModelObservationConvention> observationConvention) {
75+
ObjectProvider<ChatModelObservationConvention> observationConvention,
76+
ObjectProvider<ToolExecutionEligibilityChecker> openAiToolExecutionEligibilityChecker) {
7677

7778
var openAiApi = openAiApi(chatProperties, commonProperties,
7879
restClientBuilderProvider.getIfAvailable(RestClient::builder),
7980
webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler, "chat");
8081

8182
var chatModel = OpenAiChatModel.builder()
82-
.openAiApi(openAiApi)
83-
.defaultOptions(chatProperties.getOptions())
84-
.toolCallingManager(toolCallingManager)
85-
.retryTemplate(retryTemplate)
86-
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
87-
.build();
83+
.openAiApi(openAiApi)
84+
.defaultOptions(chatProperties.getOptions())
85+
.toolCallingManager(toolCallingManager)
86+
.toolExecutionEligibilityChecker(openAiToolExecutionEligibilityChecker
87+
.getIfUnique(() -> chatResponse -> chatResponse != null && chatResponse.hasToolCalls()))
88+
.retryTemplate(retryTemplate)
89+
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
90+
.build();
8891

8992
observationConvention.ifAvailable(chatModel::setObservationConvention);
9093

9194
return chatModel;
9295
}
9396

97+
@Bean("openAiToolExecutionEligibilityChecker")
98+
@ConditionalOnMissingBean
99+
public ToolExecutionEligibilityChecker openAiToolExecutionEligibilityChecker() {
100+
return chatResponse -> chatResponse != null && chatResponse.hasToolCalls();
101+
}
102+
94103
private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectionProperties commonProperties,
95104
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
96105
ResponseErrorHandler responseErrorHandler, String modelType) {
@@ -99,15 +108,15 @@ private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectio
99108
commonProperties, chatProperties, modelType);
100109

101110
return OpenAiApi.builder()
102-
.baseUrl(resolved.baseUrl())
103-
.apiKey(new SimpleApiKey(resolved.apiKey()))
104-
.headers(resolved.headers())
105-
.completionsPath(chatProperties.getCompletionsPath())
106-
.embeddingsPath(OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH)
107-
.restClientBuilder(restClientBuilder)
108-
.webClientBuilder(webClientBuilder)
109-
.responseErrorHandler(responseErrorHandler)
110-
.build();
111+
.baseUrl(resolved.baseUrl())
112+
.apiKey(new SimpleApiKey(resolved.apiKey()))
113+
.headers(resolved.headers())
114+
.completionsPath(chatProperties.getCompletionsPath())
115+
.embeddingsPath(OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH)
116+
.restClientBuilder(restClientBuilder)
117+
.webClientBuilder(webClientBuilder)
118+
.responseErrorHandler(responseErrorHandler)
119+
.build();
111120
}
112121

113122
@Bean

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

+26-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.concurrent.ConcurrentHashMap;
2525
import java.util.stream.Collectors;
2626

27+
import io.micrometer.common.lang.NonNull;
2728
import io.micrometer.observation.Observation;
2829
import io.micrometer.observation.ObservationRegistry;
2930
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
@@ -60,6 +61,7 @@
6061
import org.springframework.ai.model.function.FunctionCallingOptions;
6162
import org.springframework.ai.model.tool.ToolCallingChatOptions;
6263
import org.springframework.ai.model.tool.ToolCallingManager;
64+
import org.springframework.ai.model.tool.ToolExecutionEligibilityChecker;
6365
import org.springframework.ai.model.tool.ToolExecutionResult;
6466
import org.springframework.ai.openai.api.OpenAiApi;
6567
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
@@ -136,23 +138,34 @@ public class OpenAiChatModel implements ChatModel {
136138

137139
private final ToolCallingManager toolCallingManager;
138140

141+
private final ToolExecutionEligibilityChecker toolExecutionEligibilityChecker;
142+
139143
/**
140144
* Conventions to use for generating observations.
141145
*/
142146
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
143147

144148
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
145149
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
150+
this(openAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry,
151+
chatResponse -> chatResponse != null && chatResponse.hasToolCalls());
152+
}
153+
154+
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
155+
RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
156+
ToolExecutionEligibilityChecker toolExecutionEligibilityChecker) {
146157
Assert.notNull(openAiApi, "openAiApi cannot be null");
147158
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
148159
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
149160
Assert.notNull(retryTemplate, "retryTemplate cannot be null");
150161
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
162+
Assert.notNull(toolExecutionEligibilityChecker, "toolExecutionEligibilityChecker cannot be null");
151163
this.openAiApi = openAiApi;
152164
this.defaultOptions = defaultOptions;
153165
this.toolCallingManager = toolCallingManager;
154166
this.retryTemplate = retryTemplate;
155167
this.observationRegistry = observationRegistry;
168+
this.toolExecutionEligibilityChecker = toolExecutionEligibilityChecker;
156169
}
157170

158171
@Override
@@ -221,8 +234,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
221234

222235
});
223236

224-
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null
225-
&& response.hasToolCalls()) {
237+
if (toolExecutionEligibilityChecker.isToolExecutionRequired(prompt.getOptions(), response)) {
226238
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
227239
if (toolExecutionResult.returnDirect()) {
228240
// Return tool execution result directly to the client.
@@ -345,7 +357,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
345357

346358
// @formatter:off
347359
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
348-
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
360+
if (toolExecutionEligibilityChecker.isToolExecutionRequired(prompt.getOptions(), response)) {
349361
return Flux.defer(() -> {
350362
// FIXME: bounded elastic needs to be used since tool calling
351363
// is currently only synchronous
@@ -684,6 +696,9 @@ public static final class Builder {
684696

685697
private ToolCallingManager toolCallingManager;
686698

699+
private ToolExecutionEligibilityChecker toolExecutionEligibilityChecker = chatResponse -> chatResponse != null
700+
&& chatResponse.hasToolCalls();
701+
687702
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
688703

689704
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
@@ -706,6 +721,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
706721
return this;
707722
}
708723

724+
public Builder toolExecutionEligibilityChecker(
725+
ToolExecutionEligibilityChecker toolExecutionEligibilityChecker) {
726+
this.toolExecutionEligibilityChecker = toolExecutionEligibilityChecker;
727+
return this;
728+
}
729+
709730
public Builder retryTemplate(RetryTemplate retryTemplate) {
710731
this.retryTemplate = retryTemplate;
711732
return this;
@@ -719,10 +740,10 @@ public Builder observationRegistry(ObservationRegistry observationRegistry) {
719740
public OpenAiChatModel build() {
720741
if (toolCallingManager != null) {
721742
return new OpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate,
722-
observationRegistry);
743+
observationRegistry, toolExecutionEligibilityChecker);
723744
}
724745
return new OpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
725-
observationRegistry);
746+
observationRegistry, toolExecutionEligibilityChecker);
726747
}
727748

728749
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2025 - 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model.tool;
17+
18+
import java.util.function.Function;
19+
20+
import org.springframework.ai.chat.model.ChatResponse;
21+
import org.springframework.ai.chat.prompt.ChatOptions;
22+
import org.springframework.ai.model.function.FunctionCallingOptions;
23+
import org.springframework.util.Assert;
24+
25+
/**
26+
* Interface for determining when tool execution should be performed based on model
27+
* responses.
28+
*
29+
* @author Christian Tzolov
30+
*/
31+
public interface ToolExecutionEligibilityChecker extends Function<ChatResponse, Boolean> {
32+
33+
/**
34+
* Determines if tool execution should be performed based on the prompt options and
35+
* chat response.
36+
* @param promptOptions The options from the prompt
37+
* @param chatResponse The response from the chat model
38+
* @return true if tool execution should be performed, false otherwise
39+
*/
40+
default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse) {
41+
Assert.notNull(promptOptions, "promptOptions cannot be null");
42+
Assert.notNull(chatResponse, "chatResponse cannot be null");
43+
return this.isInternalToolExecutionEnabled(promptOptions) && this.isToolCallResponse(chatResponse);
44+
}
45+
46+
/**
47+
* Determines if the response is a tool call message response.
48+
* @param chatResponse The response from the chat model call
49+
* @return true if the response is a tool call message response, false otherwise
50+
*/
51+
default boolean isToolCallResponse(ChatResponse chatResponse) {
52+
Assert.notNull(chatResponse, "chatResponse cannot be null");
53+
return apply(chatResponse);
54+
}
55+
56+
/**
57+
* Determines if tool execution should be performed by the Spring AI or by the client.
58+
* @param chatOptions The options from the chat
59+
* @return true if tool execution should be performed by Spring AI, false if it should
60+
* be performed by the client
61+
*/
62+
default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) {
63+
64+
Assert.notNull(chatOptions, "chatOptions cannot be null");
65+
boolean internalToolExecutionEnabled;
66+
if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions
67+
&& toolCallingChatOptions.isInternalToolExecutionEnabled() != null) {
68+
internalToolExecutionEnabled = Boolean.TRUE.equals(toolCallingChatOptions.isInternalToolExecutionEnabled());
69+
}
70+
else if (chatOptions instanceof FunctionCallingOptions functionCallingOptions
71+
&& functionCallingOptions.getProxyToolCalls() != null) {
72+
internalToolExecutionEnabled = Boolean.TRUE.equals(!functionCallingOptions.getProxyToolCalls());
73+
}
74+
else {
75+
internalToolExecutionEnabled = true;
76+
}
77+
return internalToolExecutionEnabled;
78+
}
79+
80+
}

0 commit comments

Comments
 (0)