24
24
import java .util .concurrent .ConcurrentHashMap ;
25
25
import java .util .stream .Collectors ;
26
26
27
+ import io .micrometer .common .lang .NonNull ;
27
28
import io .micrometer .observation .Observation ;
28
29
import io .micrometer .observation .ObservationRegistry ;
29
30
import io .micrometer .observation .contextpropagation .ObservationThreadLocalAccessor ;
60
61
import org .springframework .ai .model .function .FunctionCallingOptions ;
61
62
import org .springframework .ai .model .tool .ToolCallingChatOptions ;
62
63
import org .springframework .ai .model .tool .ToolCallingManager ;
64
+ import org .springframework .ai .model .tool .ToolExecutionEligibilityChecker ;
63
65
import org .springframework .ai .model .tool .ToolExecutionResult ;
64
66
import org .springframework .ai .openai .api .OpenAiApi ;
65
67
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletion ;
@@ -136,23 +138,34 @@ public class OpenAiChatModel implements ChatModel {
136
138
137
139
private final ToolCallingManager toolCallingManager ;
138
140
141
+ private final ToolExecutionEligibilityChecker toolExecutionEligibilityChecker ;
142
+
139
143
/**
140
144
* Conventions to use for generating observations.
141
145
*/
142
146
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION ;
143
147
144
148
public OpenAiChatModel (OpenAiApi openAiApi , OpenAiChatOptions defaultOptions , ToolCallingManager toolCallingManager ,
145
149
RetryTemplate retryTemplate , ObservationRegistry observationRegistry ) {
150
+ this (openAiApi , defaultOptions , toolCallingManager , retryTemplate , observationRegistry ,
151
+ chatResponse -> chatResponse != null && chatResponse .hasToolCalls ());
152
+ }
153
+
154
+ public OpenAiChatModel (OpenAiApi openAiApi , OpenAiChatOptions defaultOptions , ToolCallingManager toolCallingManager ,
155
+ RetryTemplate retryTemplate , ObservationRegistry observationRegistry ,
156
+ ToolExecutionEligibilityChecker toolExecutionEligibilityChecker ) {
146
157
Assert .notNull (openAiApi , "openAiApi cannot be null" );
147
158
Assert .notNull (defaultOptions , "defaultOptions cannot be null" );
148
159
Assert .notNull (toolCallingManager , "toolCallingManager cannot be null" );
149
160
Assert .notNull (retryTemplate , "retryTemplate cannot be null" );
150
161
Assert .notNull (observationRegistry , "observationRegistry cannot be null" );
162
+ Assert .notNull (toolExecutionEligibilityChecker , "toolExecutionEligibilityChecker cannot be null" );
151
163
this .openAiApi = openAiApi ;
152
164
this .defaultOptions = defaultOptions ;
153
165
this .toolCallingManager = toolCallingManager ;
154
166
this .retryTemplate = retryTemplate ;
155
167
this .observationRegistry = observationRegistry ;
168
+ this .toolExecutionEligibilityChecker = toolExecutionEligibilityChecker ;
156
169
}
157
170
158
171
@ Override
@@ -221,8 +234,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
221
234
222
235
});
223
236
224
- if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ()) && response != null
225
- && response .hasToolCalls ()) {
237
+ if (toolExecutionEligibilityChecker .isToolExecutionRequired (prompt .getOptions (), response )) {
226
238
var toolExecutionResult = this .toolCallingManager .executeToolCalls (prompt , response );
227
239
if (toolExecutionResult .returnDirect ()) {
228
240
// Return tool execution result directly to the client.
@@ -345,7 +357,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
345
357
346
358
// @formatter:off
347
359
Flux <ChatResponse > flux = chatResponse .flatMap (response -> {
348
- if (ToolCallingChatOptions . isInternalToolExecutionEnabled (prompt .getOptions ()) && response . hasToolCalls ( )) {
360
+ if (toolExecutionEligibilityChecker . isToolExecutionRequired (prompt .getOptions (), response )) {
349
361
return Flux .defer (() -> {
350
362
// FIXME: bounded elastic needs to be used since tool calling
351
363
// is currently only synchronous
@@ -684,6 +696,9 @@ public static final class Builder {
684
696
685
697
private ToolCallingManager toolCallingManager ;
686
698
699
+ private ToolExecutionEligibilityChecker toolExecutionEligibilityChecker = chatResponse -> chatResponse != null
700
+ && chatResponse .hasToolCalls ();
701
+
687
702
private RetryTemplate retryTemplate = RetryUtils .DEFAULT_RETRY_TEMPLATE ;
688
703
689
704
private ObservationRegistry observationRegistry = ObservationRegistry .NOOP ;
@@ -706,6 +721,12 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
706
721
return this ;
707
722
}
708
723
724
+ public Builder toolExecutionEligibilityChecker (
725
+ ToolExecutionEligibilityChecker toolExecutionEligibilityChecker ) {
726
+ this .toolExecutionEligibilityChecker = toolExecutionEligibilityChecker ;
727
+ return this ;
728
+ }
729
+
709
730
public Builder retryTemplate (RetryTemplate retryTemplate ) {
710
731
this .retryTemplate = retryTemplate ;
711
732
return this ;
@@ -719,10 +740,10 @@ public Builder observationRegistry(ObservationRegistry observationRegistry) {
719
740
public OpenAiChatModel build () {
720
741
if (toolCallingManager != null ) {
721
742
return new OpenAiChatModel (openAiApi , defaultOptions , toolCallingManager , retryTemplate ,
722
- observationRegistry );
743
+ observationRegistry , toolExecutionEligibilityChecker );
723
744
}
724
745
return new OpenAiChatModel (openAiApi , defaultOptions , DEFAULT_TOOL_CALLING_MANAGER , retryTemplate ,
725
- observationRegistry );
746
+ observationRegistry , toolExecutionEligibilityChecker );
726
747
}
727
748
728
749
}
0 commit comments