diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 270f3bef43d..c97f0aae085 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -90,6 +90,7 @@ * @author Alexandros Pappas * @author Jonghoon Park * @author Soby Chacko + * @author lambochen * @since 1.0.0 */ public class AnthropicChatModel implements ChatModel { @@ -174,6 +175,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatCompletionRequest request = createRequest(prompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -203,7 +208,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -215,7 +220,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -236,6 +241,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -260,7 +269,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, iterations) + && chatResponse.hasFinishReasons(Set.of("tool_use"))) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -274,7 +284,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. 
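// Editor's sketch (not part of this patch): the `iterations` argument threaded through
// internalCall/internalStream above starts at 1 for the first model round and grows by one for
// each internal tool-execution round. This diff does not show how the new three-argument
// isToolExecutionRequired(...) overload uses that counter; one plausible shape, assuming it
// simply combines the existing eligibility check with the per-prompt cap exposed by
// ToolCallingChatOptions#getInternalToolExecutionMaxIterations(), is sketched below. The
// interface name and default-method body here are hypothetical.

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;

interface IterationLimitedEligibilitySketch {

    // Existing two-argument check, e.g. "did the response ask for a tool call?".
    boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse);

    // Hypothetical three-argument variant: the same check, plus the iteration cap.
    default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse, int iterations) {
        Integer max = (promptOptions instanceof ToolCallingChatOptions toolOptions
                && toolOptions.getInternalToolExecutionMaxIterations() != null)
                        ? toolOptions.getInternalToolExecutionMaxIterations()
                        : ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS;
        // Once the configured number of internal rounds is reached, stop executing tools
        // and let the last model response flow back to the caller.
        return isToolExecutionRequired(promptOptions, chatResponse) && iterations <= max;
    }

}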
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -437,6 +447,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -447,6 +460,8 @@ Prompt buildRequestPrompt(Prompt prompt) { else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index dbfbee561c8..030ba12f977 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -44,6 +44,7 @@ * @author Thomas Vitale * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -79,6 +80,9 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); @@ -109,6 +113,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? 
new HashMap<>(fromOptions.getHttpHeaders()) : null) .build(); @@ -226,6 +231,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override @JsonIgnore public Double getFrequencyPenalty() { @@ -281,6 +296,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, that.internalToolExecutionMaxIterations) && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.httpHeaders, that.httpHeaders); } @@ -289,7 +305,7 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP, this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext, this.httpHeaders); + this.internalToolExecutionMaxIterations, this.toolContext, this.httpHeaders); } public static class Builder { @@ -374,6 +390,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java index 62d97b459e4..2364e864f64 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import static org.assertj.core.api.Assertions.assertThat; @@ -29,6 +30,7 @@ * Tests for {@link AnthropicChatOptions}. 
* * @author Alexandros Pappas + * @author lambochen */ class AnthropicChatOptionsTests { @@ -42,10 +44,13 @@ void testBuilderWithAllFields() { .topP(0.8) .topK(50) .metadata(new Metadata("userId_123")) + .internalToolExecutionMaxIterations(3) .build(); - assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata") - .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123")); + assertThat(options) + .extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata", + "internalToolExecutionMaxIterations") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), 3); } @Test @@ -59,6 +64,7 @@ void testCopy() { .topK(50) .metadata(new Metadata("userId_123")) .toolContext(Map.of("key1", "value1")) + .internalToolExecutionMaxIterations(3) .build(); AnthropicChatOptions copied = original.copy(); @@ -67,6 +73,8 @@ void testCopy() { // Ensure deep copy assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + + assertThat(copied.getInternalToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -79,6 +87,7 @@ void testSetters() { options.setTopP(0.8); options.setStopSequences(List.of("stop1", "stop2")); options.setMetadata(new Metadata("userId_123")); + options.setInternalToolExecutionMaxIterations(3); assertThat(options.getModel()).isEqualTo("test-model"); assertThat(options.getMaxTokens()).isEqualTo(100); @@ -87,6 +96,7 @@ void testSetters() { assertThat(options.getTopP()).isEqualTo(0.8); assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123")); + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -99,6 +109,8 @@ void testDefaultValues() { assertThat(options.getTopP()).isNull(); assertThat(options.getStopSequences()).isNull(); assertThat(options.getMetadata()).isNull(); + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 1933f575300..21c0d36a0e1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -122,8 +122,10 @@ * @author Berjan Jonker * @author Andres da Silva Santos * @author Bart Veenstra + * @author lambochen * @see ChatModel * @see com.azure.ai.openai.OpenAIClient + * @see ToolCallingChatOptions * @since 1.0.0 */ public class AzureOpenAiChatModel implements ChatModel { @@ -251,6 +253,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) @@ -270,7 +276,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); 
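// Editor's sketch (not part of this patch): how a caller might opt into the per-prompt cap
// introduced for AnthropicChatOptions above (the other providers in this PR expose the same
// builder method). Only internalToolExecutionMaxIterations(...) comes from this patch; the
// model id and the chatModel parameter are placeholders for an existing AnthropicChatModel.

import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

class MaxIterationsUsageSketch {

    ChatResponse callWithCappedToolLoop(ChatModel chatModel) {
        AnthropicChatOptions options = AnthropicChatOptions.builder()
            .model("claude-3-7-sonnet-latest") // placeholder model id
            .internalToolExecutionMaxIterations(3) // allow at most 3 internal tool-calling rounds
            .build();
        // Tool callbacks would normally be registered on the options as well; omitted for brevity.
        return chatModel.call(new Prompt("What is the weather in Paris?", options));
    }

}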
- if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -282,7 +288,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -298,6 +304,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); @@ -377,7 +387,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); return chatResponseFlux.flatMap(chatResponse -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, + iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -393,7 +404,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // Send the tool execution result back to the model. return this.internalStream( new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -666,6 +677,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(), this.defaultOptions.getStreamUsage())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), @@ -677,6 +691,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index da442b4ad4d..9f130b47e5d 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++
b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -48,6 +48,7 @@ * @author Ilayaperumal Gopinathan * @author Alexandros Pappas * @author Andres da Silva Santos + * @author lambochen */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @@ -200,6 +201,9 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Whether to include token usage information in streaming chat completion responses. * Only applies to streaming responses. @@ -257,6 +261,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + public static Builder builder() { return new Builder(); } @@ -284,6 +298,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .enhancements(fromOptions.getEnhancements()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .streamOptions(fromOptions.getStreamOptions()) .toolCallbacks( fromOptions.getToolCallbacks() != null ? 
new ArrayList<>(fromOptions.getToolCallbacks()) : null) @@ -504,6 +519,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, that.internalToolExecutionMaxIterations) && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs) && Objects.equals(this.enhancements, that.enhancements) && Objects.equals(this.streamOptions, that.streamOptions) @@ -518,10 +534,10 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, - this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs, - this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage, - this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, - this.topP); + this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, + this.internalToolExecutionMaxIterations, this.seed, this.logprobs, this.topLogProbs, this.enhancements, + this.streamOptions, this.reasoningEffort, this.enableStreamUsage, this.toolContext, this.maxTokens, + this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); } public static class Builder { @@ -664,6 +680,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java index 789635d358e..82be30a85d6 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -24,6 +24,7 @@ import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionStreamOptions; import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import static org.assertj.core.api.Assertions.assertThat; @@ -31,6 +32,7 @@ * Tests for {@link AzureOpenAiChatOptions}. 
* * @author Alexandros Pappas + * @author lambochen */ class AzureOpenAiChatOptionsTests { @@ -65,15 +67,16 @@ void testBuilderWithAllFields() { .topLogprobs(5) .enhancements(enhancements) .streamOptions(streamOptions) + .internalToolExecutionMaxIterations(3) .build(); assertThat(options) .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", "temperature", "topP", "user", "responseFormat", "streamUsage", "reasoningEffort", "seed", - "logprobs", "topLogProbs", "enhancements", "streamOptions") + "logprobs", "topLogProbs", "enhancements", "streamOptions", "internalToolExecutionMaxIterations") .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, "low", 12345L, true, 5, - enhancements, streamOptions); + enhancements, streamOptions, 3); } @Test @@ -107,6 +110,7 @@ void testCopy() { .topLogprobs(5) .enhancements(enhancements) .streamOptions(streamOptions) + .internalToolExecutionMaxIterations(3) .build(); AzureOpenAiChatOptions copiedOptions = originalOptions.copy(); @@ -115,6 +119,8 @@ void testCopy() { // Ensure deep copy assertThat(copiedOptions.getStop()).isNotSameAs(originalOptions.getStop()); assertThat(copiedOptions.getToolContext()).isNotSameAs(originalOptions.getToolContext()); + + assertThat(copiedOptions.getInternalToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -145,6 +151,7 @@ void testSetters() { options.setTopLogProbs(5); options.setEnhancements(enhancements); options.setStreamOptions(streamOptions); + options.setInternalToolExecutionMaxIterations(3); assertThat(options.getDeploymentName()).isEqualTo("test-deployment"); options.setModel("test-model"); @@ -168,6 +175,7 @@ void testSetters() { assertThat(options.getEnhancements()).isEqualTo(enhancements); assertThat(options.getStreamOptions()).isEqualTo(streamOptions); assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -193,6 +201,8 @@ void testDefaultValues() { assertThat(options.getEnhancements()).isNull(); assertThat(options.getStreamOptions()).isNull(); assertThat(options.getModel()).isNull(); + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); } } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index d30f2517756..c17e04b565c 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -133,6 +133,7 @@ * @author Alexandros Pappas * @author Jihoon Kim * @author Soby Chacko + * @author lambochen * @since 1.0.0 */ public class BedrockProxyChatModel implements ChatModel { @@ -217,6 +218,10 @@ public ChatResponse call(Prompt prompt) { } private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse) { + return this.internalCall(prompt, perviousChatResponse, 1); + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse, int iterations) { ConverseRequest converseRequest = this.createRequest(prompt); @@ -241,8 +246,8 @@ private ChatResponse internalCall(Prompt 
prompt, ChatResponse perviousChatRespon return response; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) - && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, + iterations) && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -254,7 +259,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } } return chatResponse; @@ -310,6 +315,9 @@ Prompt buildRequestPrompt(Prompt prompt) { .internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null ? runtimeOptions.getInternalToolExecutionEnabled() : this.defaultOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())) .build(); } @@ -644,6 +652,10 @@ public Flux stream(Prompt prompt) { } private Flux internalStream(Prompt prompt, ChatResponse perviousChatResponse) { + return this.internalStream(prompt, perviousChatResponse, 1); + } + + private Flux internalStream(Prompt prompt, ChatResponse perviousChatResponse, int iterations) { Assert.notNull(prompt, "'prompt' must not be null"); return Flux.deferContextual(contextView -> { @@ -676,8 +688,8 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) - && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, + iterations) && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -695,7 +707,7 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh // Send the tool execution result back to the model. return this.internalStream( new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index 4b7607c6e38..5dd16312dda 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -75,6 +75,7 @@ * backed by {@link DeepSeekApi}. 
* * @author Geng Rong + * @author lambochen */ public class DeepSeekChatModel implements ChatModel { @@ -151,6 +152,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatCompletionRequest request = createRequest(prompt, false); @@ -205,7 +210,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -217,7 +222,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -231,6 +236,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -285,7 +294,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { return Flux.defer(() -> { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -299,7 +308,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. 
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -397,6 +406,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -406,6 +418,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java index 0731a1eb6cc..7728a595092 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java @@ -43,6 +43,7 @@ * chat completion * * @author Geng Rong + * @author lambochen */ @JsonInclude(Include.NON_NULL) public class DeepSeekChatOptions implements ToolCallingChatOptions { @@ -122,6 +123,9 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Tool Function Callbacks to register with the ChatModel. * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. 
@@ -289,6 +293,18 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + @JsonIgnore + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + public Boolean getLogprobs() { return this.logprobs; } @@ -332,7 +348,9 @@ public int hashCode() { return Objects.hash(this.model, this.frequencyPenalty, this.logprobs, this.topLogprobs, this.maxTokens, this.presencePenalty, this.responseFormat, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, - this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext); + this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.internalToolExecutionMaxIterations, + this.toolContext); } @@ -357,7 +375,9 @@ public boolean equals(Object o) { && Objects.equals(this.toolCallbacks, other.toolCallbacks) && Objects.equals(this.toolNames, other.toolNames) && Objects.equals(this.toolContext, other.toolContext) - && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled); + && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, other.internalToolExecutionMaxIterations) + ; } public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { @@ -378,6 +398,7 @@ public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? 
new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -487,6 +508,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatOptionsTest.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatOptionsTest.java new file mode 100644 index 00000000000..cd6e8854806 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatOptionsTest.java @@ -0,0 +1,38 @@ +package org.springframework.ai.deepseek; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author lambochen + */ +class DeepSeekChatOptionsTest { + + @Test + void fromOptions() { + var original = new DeepSeekChatOptions(); + original.setInternalToolExecutionMaxIterations(3); + + var copy = DeepSeekChatOptions.fromOptions(original); + assertNotSame(original, copy); + assertSame(original.getInternalToolExecutionMaxIterations(), copy.getInternalToolExecutionMaxIterations()); + } + + @Test + void optionsDefault() { + var options = new DeepSeekChatOptions(); + + assertEquals(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, + options.getInternalToolExecutionMaxIterations()); + } + + @Test + void optionsBuilder() { + var options = DeepSeekChatOptions.builder().internalToolExecutionMaxIterations(3).build(); + + assertEquals(3, options.getInternalToolExecutionMaxIterations()); + } + +} diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index e5a774cacf9..c9d4acaff5c 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -79,9 +79,11 @@ * @author Geng Rong * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @see ChatModel * @see StreamingChatModel * @see MiniMaxApi + * @see ToolCallingChatOptions * @since 1.0.0 M1 */ public class MiniMaxChatModel implements ChatModel { @@ -236,6 +238,10 @@ public ChatResponse call(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. 
Prompt requestPrompt = buildRequestPrompt(prompt); + return internalCall(requestPrompt, 1); + } + + private ChatResponse internalCall(Prompt requestPrompt, int iterations) { ChatCompletionRequest request = createRequest(requestPrompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -292,7 +298,8 @@ else if (!CollectionUtils.isEmpty(choice.messages())) { return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, + iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -303,7 +310,9 @@ else if (!CollectionUtils.isEmpty(choice.messages())) { } else { // Send the tool execution result back to the model. - return this.call(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalCall( + new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), + iterations + 1); } } @@ -320,6 +329,10 @@ public Flux stream(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. Prompt requestPrompt = buildRequestPrompt(prompt); + return internalStream(requestPrompt, 1); + } + + private Flux internalStream(Prompt requestPrompt, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(requestPrompt, true); @@ -361,15 +374,15 @@ public Flux stream(Prompt prompt) { return buildGeneration(choice, metadata); }).toList(); return new ChatResponse(generations, from(chatCompletion2)); - } - catch (Exception e) { + } + catch (Exception e) { logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } })); Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, iterations)) { return Flux.defer(() -> { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -382,7 +395,7 @@ public Flux stream(Prompt prompt) { } else { // Send the tool execution result back to the model. 
- return this.stream(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -472,6 +485,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -481,6 +497,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 9d2614396c5..84369a5f736 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -45,6 +45,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -153,6 +154,9 @@ public class MiniMaxChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + // @formatter:on public static Builder builder() { @@ -176,6 +180,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -349,6 +354,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override public Map getToolContext() { return this.toolContext; @@ -380,6 +395,9 @@ public int hashCode() { result = prime * result + ((this.toolNames == null) ? 
0 : this.toolNames.hashCode()); result = prime * result + ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode()); + result = prime * result + ((this.internalToolExecutionMaxIterations == null) + ? ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS + : this.internalToolExecutionMaxIterations.hashCode()); result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); return result; } @@ -509,6 +527,15 @@ else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEn return false; } + if (this.internalToolExecutionMaxIterations == null) { + if (other.internalToolExecutionMaxIterations != null) { + return false; + } + } + else if (!this.internalToolExecutionMaxIterations.equals(other.internalToolExecutionMaxIterations)) { + return false; + } + if (this.toolNames == null) { if (other.toolNames != null) { return false; @@ -649,6 +676,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java index fa3a0489409..7371e7373c0 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -116,4 +117,26 @@ void testToolCallingStream() { assertThat(content).contains("15"); } + @Test + void testOptionsDefaultValue() { + var options = new MiniMaxChatOptions(); + + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + } + + @Test + void testOptionsSetter() { + var options = new MiniMaxChatOptions(); + options.setInternalToolExecutionMaxIterations(3); + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); + } + + @Test + void testOptionsBuilder() { + var options = MiniMaxChatOptions.builder().internalToolExecutionMaxIterations(3).build(); + + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); + } + } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index b9838dcedf1..c87eb0afa4a 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -83,7 +83,9 @@ * @author luocongqiu * @author Ilayaperumal Gopinathan * @author Alexandros Pappas + * @author lambochen * @since 1.0.0 + * @see ToolCallingChatOptions */ public 
class MistralAiChatModel implements ChatModel { @@ -181,6 +183,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false); @@ -225,7 +231,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -237,7 +243,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -253,6 +259,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { var request = createRequest(prompt, true); @@ -313,7 +323,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -327,7 +337,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. 
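// Editor's sketch (not part of this patch): the buildRequestPrompt(...) changes throughout this
// PR follow one pattern — ModelOptionsUtils.mergeOption(runtimeValue, defaultValue) is assumed to
// prefer the runtime value when it is non-null and to fall back to the default otherwise, which
// is what the MistralAiChatCompletionRequestTest further below asserts (a runtime value of 3
// overrides the configured default). A condensed illustration, with hypothetical values:

import org.springframework.ai.mistralai.MistralAiChatOptions;
import org.springframework.ai.model.ModelOptionsUtils;

class MergeBehaviourSketch {

    Integer mergedMaxIterations() {
        MistralAiChatOptions defaults = MistralAiChatOptions.builder()
            .internalToolExecutionMaxIterations(10) // hypothetical application-wide default
            .build();
        MistralAiChatOptions runtime = MistralAiChatOptions.builder()
            .internalToolExecutionMaxIterations(3) // per-request override
            .build();
        // Expected to yield 3: the per-request value wins over the configured default.
        return ModelOptionsUtils.mergeOption(runtime.getInternalToolExecutionMaxIterations(),
                defaults.getInternalToolExecutionMaxIterations());
    }

}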
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -394,6 +404,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -403,6 +416,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 2b392d5176a..8fb26aab885 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -45,6 +45,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Alexandros Pappas + * @author lambochen * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) @@ -156,6 +157,9 @@ public class MistralAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); @@ -180,6 +184,7 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -347,6 +352,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override @JsonIgnore public Integer getTopK() { @@ -374,7 +389,8 @@ public MistralAiChatOptions copy() { public int hashCode() { return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, 
this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools, - this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext); + this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, + this.internalToolExecutionMaxIterations, this.toolContext); } @Override @@ -400,6 +416,7 @@ public boolean equals(Object obj) { && Objects.equals(this.toolCallbacks, other.toolCallbacks) && Objects.equals(this.toolNames, other.toolNames) && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, other.internalToolExecutionMaxIterations) && Objects.equals(this.toolContext, other.toolContext); } @@ -505,6 +522,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java index e6bf2490cc0..8cd0bc6478c 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java @@ -35,6 +35,7 @@ * @author Ricken Bazolo * @author Alexandros Pappas * @author Thomas Vitale + * @author lambochen * @since 0.8.1 */ @SpringBootTest(classes = MistralAiTestConfiguration.class) @@ -73,6 +74,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder() .model("DEFAULT_MODEL") .internalToolExecutionEnabled(true) + .internalToolExecutionMaxIterations(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") .toolContext(Map.of("key1", "value1", "key2", "valueA")) @@ -85,6 +87,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { MistralAiChatOptions runtimeOptions = MistralAiChatOptions.builder() .internalToolExecutionEnabled(false) + .internalToolExecutionMaxIterations(3) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") .toolContext(Map.of("key2", "valueB")) @@ -93,6 +96,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionMaxIterations()).isEqualTo(3); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java new file 
mode 100644 index 00000000000..2b35d0cb989 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java @@ -0,0 +1,37 @@ +package org.springframework.ai.mistralai; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author lambochen + */ +class MistralAiChatOptionsTests { + + @Test + void testOptionsDefault() { + var options = new MistralAiChatOptions(); + + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + } + + @Test + void testOptionsCustom() { + var options = new MistralAiChatOptions(); + + options.setInternalToolExecutionMaxIterations(3); + + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); + } + + @Test + void testBuilder() { + var options = MistralAiChatOptions.builder().internalToolExecutionMaxIterations(3).build(); + + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); + } + +} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 8d22df6ddcc..2a9393a068c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -85,7 +85,9 @@ * @author Alexandros Pappas * @author Ilayaperumal Gopinathan * @author Sun Yuhan + * @author lambochen * @since 1.0.0 + * @see ToolCallingChatOptions */ public class OllamaChatModel implements ChatModel { @@ -220,10 +222,10 @@ public ChatResponse call(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. Prompt requestPrompt = buildRequestPrompt(prompt); - return this.internalCall(requestPrompt, null); + return this.internalCall(requestPrompt, null, 1); } - private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false); @@ -266,7 +268,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -278,7 +280,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -290,10 +292,10 @@ public Flux stream(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. 
Prompt requestPrompt = buildRequestPrompt(prompt); - return this.internalStream(requestPrompt, null); + return this.internalStream(requestPrompt, null, 1); } - private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true); @@ -338,7 +340,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -352,7 +354,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh else { // Send the tool execution result back to the model. return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -394,6 +396,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -403,6 +408,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..68ce993b3bf 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -44,6 +44,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 0.8.0 * @see Ollama @@ -321,6 +322,9 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Tool Function Callbacks to register with the ChatModel. * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. 
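Illustrative usage sketch (not part of the patch; the chat model wiring, the "llama3" model name, and the prompt text are assumptions): with the option added above, a caller can bound the internal tool-calling loop per request through the options builder. Per the changes above, once the cap is exceeded the eligibility check returns false and the last model response is returned to the caller instead of triggering another tool round trip.

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;

class ToolIterationCapSketch {

	// Assumes an OllamaChatModel that already has tool callbacks registered via its default options.
	static ChatResponse askWithBoundedToolLoop(OllamaChatModel chatModel) {
		OllamaOptions options = OllamaOptions.builder()
			.model("llama3") // hypothetical model name
			.internalToolExecutionEnabled(true)
			.internalToolExecutionMaxIterations(3) // at most 3 model -> tool -> model round trips
			.build();
		return chatModel.call(new Prompt("What is the weather in Paris right now?", options));
	}

}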
@@ -397,6 +401,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .stop(fromOptions.getStop()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolCallbacks(fromOptions.getToolCallbacks()) .toolContext(fromOptions.getToolContext()).build(); } @@ -746,6 +751,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override @JsonIgnore public Integer getDimensions() { @@ -809,6 +824,7 @@ public boolean equals(Object o) { && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, that.internalToolExecutionMaxIterations) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext); } @@ -820,7 +836,7 @@ public int hashCode() { this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext); + this.internalToolExecutionMaxIterations, this.toolContext); } public static class Builder { @@ -1029,6 +1045,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 59baa37bec2..1892db91aa9 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -47,6 +47,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OllamaOptions defaultOptions = OllamaOptions.builder() .model("MODEL_NAME") .internalToolExecutionEnabled(true) + .internalToolExecutionMaxIterations(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") .toolContext(Map.of("key1", "value1", "key2", "valueA")) @@ -58,6 +59,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OllamaOptions runtimeOptions = OllamaOptions.builder() .internalToolExecutionEnabled(false) + .internalToolExecutionMaxIterations(3) .toolCallbacks(new 
TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") .toolContext(Map.of("key2", "valueB")) @@ -66,6 +68,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionMaxIterations()).isEqualTo(3); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 7da34176c15..64262522ded 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -105,9 +105,11 @@ * @author Alexandros Pappas * @author Soby Chacko * @author Jonghoon Park + * @author lambochen * @see ChatModel * @see StreamingChatModel * @see OpenAiApi + * @see ToolCallingChatOptions */ public class OpenAiChatModel implements ChatModel { @@ -182,6 +184,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatCompletionRequest request = createRequest(prompt, false); @@ -240,7 +246,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -252,7 +258,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. 
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -268,6 +274,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -362,7 +372,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { return Flux.defer(() -> { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -376,7 +386,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } @@ -520,6 +530,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -530,6 +543,8 @@ Prompt buildRequestPrompt(Prompt prompt) { else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index a1a9fede77e..16f1c572e1d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -49,6 +49,7 @@ * @author Mariusz Bernacki * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) @@ -218,6 +219,9 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Optional HTTP 
headers to be added to the chat completion request. */ @@ -262,6 +266,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .store(fromOptions.getStore()) .metadata(fromOptions.getMetadata()) @@ -504,6 +509,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + public Map getHttpHeaders() { return this.httpHeaders; } @@ -573,8 +588,9 @@ public int hashCode() { this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, - this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, - this.store, this.metadata, this.reasoningEffort, this.webSearchOptions); + this.internalToolExecutionEnabled, this.internalToolExecutionMaxIterations, this.toolContext, + this.outputModalities, this.outputAudio, this.store, this.metadata, this.reasoningEffort, + this.webSearchOptions); } @Override @@ -603,6 +619,7 @@ public boolean equals(Object o) { && Objects.equals(this.httpHeaders, other.httpHeaders) && Objects.equals(this.toolContext, other.toolContext) && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, other.internalToolExecutionMaxIterations) && Objects.equals(this.outputModalities, other.outputModalities) && Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store) && Objects.equals(this.metadata, other.metadata) @@ -765,6 +782,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder httpHeaders(Map httpHeaders) { this.options.httpHeaders = httpHeaders; return this; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 3d7623c96f4..e5885569cdf 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -36,6 +36,7 @@ /** * @author Christian Tzolov * @author Thomas Vitale + * @author lambochen */ class 
ChatCompletionRequestTests { @@ -44,6 +45,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() .model("DEFAULT_MODEL") .internalToolExecutionEnabled(true) + .internalToolExecutionMaxIterations(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") .toolContext(Map.of("key1", "value1", "key2", "valueA")) @@ -56,6 +58,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OpenAiChatOptions runtimeOptions = OpenAiChatOptions.builder() .internalToolExecutionEnabled(false) + .internalToolExecutionMaxIterations(10) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") .toolContext(Map.of("key2", "valueB")) @@ -64,6 +67,8 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionMaxIterations()) + .isEqualTo(10); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java index d09808f1a31..e47bbc79506 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; @@ -80,6 +81,7 @@ void testBuilderWithAllFields() { .metadata(metadata) .reasoningEffort("medium") .internalToolExecutionEnabled(false) + .internalToolExecutionMaxIterations(10) .httpHeaders(Map.of("header1", "value1")) .toolContext(toolContext) .build(); @@ -89,10 +91,10 @@ void testBuilderWithAllFields() { "maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat", "streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user", "parallelToolCalls", "store", "metadata", "reasoningEffort", "internalToolExecutionEnabled", - "httpHeaders", "toolContext") + "internalToolExecutionMaxIterations", "httpHeaders", "toolContext") .containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8, responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true, - false, metadata, "medium", false, Map.of("header1", "value1"), toolContext); + false, metadata, "medium", false, 10, Map.of("header1", "value1"), toolContext); assertThat(options.getStreamUsage()).isTrue(); assertThat(options.getStreamOptions()).isEqualTo(StreamOptions.INCLUDE_USAGE); @@ -139,6 +141,7 @@ void testCopy() { .metadata(metadata) .reasoningEffort("low") .internalToolExecutionEnabled(true) + .internalToolExecutionMaxIterations(3) .httpHeaders(Map.of("header1", 
"value1")) .build(); @@ -187,6 +190,7 @@ void testSetters() { options.setMetadata(metadata); options.setReasoningEffort("high"); options.setInternalToolExecutionEnabled(false); + options.setInternalToolExecutionMaxIterations(3); options.setHttpHeaders(Map.of("header2", "value2")); assertThat(options.getModel()).isEqualTo("test-model"); @@ -214,6 +218,7 @@ void testSetters() { assertThat(options.getMetadata()).isEqualTo(metadata); assertThat(options.getReasoningEffort()).isEqualTo("high"); assertThat(options.getInternalToolExecutionEnabled()).isFalse(); + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2")); assertThat(options.getStreamUsage()).isTrue(); options.setStreamUsage(false); @@ -253,6 +258,8 @@ void testDefaultValues() { assertThat(options.getReasoningEffort()).isNull(); assertThat(options.getToolCallbacks()).isNotNull().isEmpty(); assertThat(options.getInternalToolExecutionEnabled()).isNull(); + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); assertThat(options.getHttpHeaders()).isNotNull().isEmpty(); assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); assertThat(options.getStreamUsage()).isFalse(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index e7401d9d81b..e10dfdd2307 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -32,22 +32,48 @@ public class OpenAiTestConfiguration { @Bean public OpenAiApi openAiApi() { - return OpenAiApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiApi.builder().apiKey(getApiKey()); + + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + String completionsPath = getCompletionsPath(); + if (StringUtils.hasText(completionsPath)) { + builder.completionsPath(completionsPath); + } + + return builder.build(); } @Bean public OpenAiImageApi openAiImageApi() { - return OpenAiImageApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiImageApi.builder().apiKey(getApiKey()); + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + return builder.build(); } @Bean public OpenAiAudioApi openAiAudioApi() { - return OpenAiAudioApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiAudioApi.builder().apiKey(getApiKey()); + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + return builder.build(); } @Bean public OpenAiModerationApi openAiModerationApi() { - return OpenAiModerationApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiModerationApi.builder().apiKey(getApiKey()); + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + return builder.build(); } private ApiKey getApiKey() { @@ -59,6 +85,22 @@ private ApiKey getApiKey() { return new SimpleApiKey(apiKey); } + private String getBaseUrl() { + String baseUrl = System.getenv("OPENAI_BASE_URL"); + if (StringUtils.hasText(baseUrl)) { + return baseUrl; + } + return null; + } + + private String getCompletionsPath() { + String path = 
System.getenv("OPENAI_COMPLETIONS_PATH"); + if (StringUtils.hasText(path)) { + return path; + } + return null; + } + @Bean public OpenAiChatModel openAiChatModel(OpenAiApi api) { return OpenAiChatModel.builder() diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 01ab8b96c02..ed709e58061 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -136,10 +136,12 @@ * @author Jihoon Kim * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @since 0.8.1 * @see VertexAiGeminiChatOptions * @see ToolCallingManager * @see ChatModel + * @see ToolCallingChatOptions */ public class VertexAiGeminiChatModel implements ChatModel, DisposableBean { @@ -393,6 +395,10 @@ public ChatResponse call(Prompt prompt) { } private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalCall(prompt, previousChatResponse, 1); + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) @@ -425,7 +431,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon return chatResponse; })); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -437,7 +443,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon else { // Send the tool execution result back to the model. 
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } @@ -469,6 +475,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -483,6 +492,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); @@ -503,6 +514,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -538,7 +553,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponseFlux.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.defer(() -> { @@ -551,7 +566,10 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha } else { // Send the tool execution result back to the model. 
- return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response); + return this.internalStream( + new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response, + iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 68ae24a92e2..46315a9dec2 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -45,6 +45,7 @@ * @author Grogdunn * @author Ilayaperumal Gopinathan * @author Soby Chacko + * @author lambochen * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -126,6 +127,9 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); @@ -161,6 +165,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); options.setSafetySettings(fromOptions.getSafetySettings()); options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); + options.setInternalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()); options.setToolContext(fromOptions.getToolContext()); return options; } @@ -281,6 +286,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override public Double getFrequencyPenalty() { return this.frequencyPenalty; @@ -346,6 +361,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.internalToolExecutionMaxIterations, that.internalToolExecutionMaxIterations) && Objects.equals(this.toolContext, that.toolContext); } @@ -354,7 +370,7 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, - this.internalToolExecutionEnabled, this.toolContext); + this.internalToolExecutionEnabled, this.internalToolExecutionMaxIterations, this.toolContext); } @Override @@ -478,6 +494,11 @@ public Builder internalToolExecutionEnabled(boolean internalToolExecutionEnabled return this; } + public Builder internalToolExecutionMaxIterations(Integer internalToolExecutionMaxIterations) { + 
this.options.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptionsTest.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptionsTest.java new file mode 100644 index 00000000000..d313890c942 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptionsTest.java @@ -0,0 +1,43 @@ +package org.springframework.ai.vertexai.gemini; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +import static org.junit.jupiter.api.Assertions.*; + +class VertexAiGeminiChatOptionsTest { + + @Test + void optionsDefault() { + var options = new VertexAiGeminiChatOptions(); + + assertEquals(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, + options.getInternalToolExecutionMaxIterations()); + } + + @Test + void builderDefault() { + var options = VertexAiGeminiChatOptions.builder().build(); + + assertEquals(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, + options.getInternalToolExecutionMaxIterations()); + } + + @Test + void testBuilder() { + var options = VertexAiGeminiChatOptions.builder().internalToolExecutionMaxIterations(3).build(); + + assertEquals(3, options.getInternalToolExecutionMaxIterations()); + } + + @Test + void fromOptions() { + var original = new VertexAiGeminiChatOptions(); + original.setInternalToolExecutionMaxIterations(3); + + var copied = VertexAiGeminiChatOptions.fromOptions(original); + + assertEquals(original.getInternalToolExecutionMaxIterations(), copied.getInternalToolExecutionMaxIterations()); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 408666fdc34..3f2ff1242ee 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -82,9 +82,11 @@ * @author Geng Rong * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @see ChatModel * @see StreamingChatModel * @see ZhiPuAiApi + * @see ToolCallingChatOptions * @since 1.0.0 M1 */ public class ZhiPuAiChatModel implements ChatModel { @@ -237,6 +239,10 @@ public ChatResponse call(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. 
Prompt requestPrompt = buildRequestPrompt(prompt); + return internalCall(requestPrompt, 1); + } + + private ChatResponse internalCall(Prompt requestPrompt, int iterations) { ChatCompletionRequest request = createRequest(requestPrompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -255,7 +261,7 @@ public ChatResponse call(Prompt prompt) { var chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); + logger.warn("No chat completion returned for prompt: {}", requestPrompt); return new ChatResponse(List.of()); } @@ -263,12 +269,12 @@ public ChatResponse call(Prompt prompt) { List generations = choices.stream().map(choice -> { // @formatter:off - Map metadata = Map.of( - "id", chatCompletion.id(), - "role", choice.message().role() != null ? choice.message().role().name() : "", - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" - ); - // @formatter:on + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", choice.message().role() != null ? choice.message().role().name() : "", + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" + ); + // @formatter:on return buildGeneration(choice, metadata); }).toList(); @@ -278,7 +284,8 @@ public ChatResponse call(Prompt prompt) { return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, + iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -289,7 +296,9 @@ public ChatResponse call(Prompt prompt) { } else { // Send the tool execution result back to the model. - return this.call(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalCall( + new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), + iterations + 1); } } return response; @@ -302,6 +311,10 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return internalStream(prompt, 1); + } + + private Flux internalStream(Prompt prompt, int iterations) { return Flux.deferContextual(contextView -> { // Before moving any further, build the final request Prompt, // merging runtime and default options. @@ -332,18 +345,18 @@ public Flux stream(Prompt prompt) { String id = chatCompletion2.id(); // @formatter:off - List generations = chatCompletion2.choices().stream().map(choice -> { - if (choice.message().role() != null) { - roleMap.putIfAbsent(id, choice.message().role().name()); - } - Map metadata = Map.of( - "id", chatCompletion2.id(), - "role", roleMap.getOrDefault(id, ""), - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" - ); - return buildGeneration(choice, metadata); - }).toList(); - // @formatter:on + List generations = chatCompletion2.choices().stream().map(choice -> { + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + Map metadata = Map.of( + "id", chatCompletion2.id(), + "role", roleMap.getOrDefault(id, ""), + "finishReason", choice.finishReason() != null ? 
choice.finishReason().name() : "" + ); + return buildGeneration(choice, metadata); + }).toList(); + // @formatter:on return new ChatResponse(generations, from(chatCompletion2)); } @@ -356,7 +369,7 @@ public Flux stream(Prompt prompt) { // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, iterations)) { return Flux.defer(() -> { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -369,15 +382,17 @@ public Flux stream(Prompt prompt) { } else { // Send the tool execution result back to the model. - return this.stream(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalStream( + new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), + iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } return Flux.just(response); - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on return new MessageAggregator().aggregate(flux, observationContext::setResponse); @@ -449,6 +464,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setInternalToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(), + this.defaultOptions.getInternalToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -458,6 +476,8 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions + .setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index 8b8d3974413..1b306d619f3 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -42,6 +42,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -125,6 +126,9 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer internalToolExecutionMaxIterations = 
ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); // @formatter:on @@ -148,6 +152,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -307,6 +312,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override public Map getToolContext() { return this.toolContext; @@ -331,6 +346,8 @@ public int hashCode() { result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); result = prime * result + ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode()); + result = prime * result + ((this.internalToolExecutionMaxIterations == null) ? 0 + : this.internalToolExecutionMaxIterations.hashCode()); result = prime * result + ((this.toolCallbacks == null) ? 0 : this.toolCallbacks.hashCode()); result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode()); result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); @@ -437,6 +454,14 @@ else if (!this.doSample.equals(other.doSample)) { else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEnabled)) { return false; } + if (this.internalToolExecutionMaxIterations == null) { + if (other.internalToolExecutionMaxIterations != null) { + return false; + } + } + else if (!this.internalToolExecutionMaxIterations.equals(other.internalToolExecutionMaxIterations)) { + return false; + } if (this.toolContext == null) { if (other.toolContext != null) { return false; @@ -468,6 +493,10 @@ public ToolCallingChatOptions merge(ChatOptions options) { builder.internalToolExecutionEnabled(toolCallingChatOptions.getInternalToolExecutionEnabled() != null ? (toolCallingChatOptions).getInternalToolExecutionEnabled() : this.getInternalToolExecutionEnabled()); + builder.internalToolExecutionMaxIterations( + toolCallingChatOptions.getInternalToolExecutionMaxIterations() != null + ? toolCallingChatOptions.getInternalToolExecutionMaxIterations() + : this.getInternalToolExecutionMaxIterations()); Set toolNames = new HashSet<>(); if (this.toolNames != null) { @@ -498,6 +527,7 @@ public ToolCallingChatOptions merge(ChatOptions options) { } else { builder.internalToolExecutionEnabled(this.internalToolExecutionEnabled); + builder.internalToolExecutionMaxIterations(this.internalToolExecutionMaxIterations); builder.toolNames(this.toolNames != null ? new HashSet<>(this.toolNames) : null); builder.toolCallbacks(this.toolCallbacks != null ? new ArrayList<>(this.toolCallbacks) : null); builder.toolContext(this.toolContext != null ? 
new HashMap<>(this.toolContext) : null); @@ -603,6 +633,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatOptionsTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatOptionsTests.java new file mode 100644 index 00000000000..0ba7ad99a76 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatOptionsTests.java @@ -0,0 +1,36 @@ +package org.springframework.ai.zhipuai.chat; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** + * @author lambochen + */ +class ZhiPuAiChatOptionsTests { + + @Test + void testDefaultValue() { + var options = new ZhiPuAiChatOptions(); + + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + } + + @Test + void testSetter() { + var options = new ZhiPuAiChatOptions(); + options.setInternalToolExecutionMaxIterations(3); + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); + } + + @Test + void testBuilder() { + var options = ZhiPuAiChatOptions.builder().internalToolExecutionMaxIterations(3).build(); + + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 870db6931b9..ac551fea104 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -33,6 +33,7 @@ * Default implementation of {@link ToolCallingChatOptions}. 
* * @author Thomas Vitale + * @author lambochen + * @since 1.0.0 */ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @@ -46,6 +47,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Boolean internalToolExecutionEnabled; + @Nullable + private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @Nullable private String model; @@ -118,6 +122,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getInternalToolExecutionMaxIterations() { + return this.internalToolExecutionMaxIterations; + } + + @Override + public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) { + this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations; + } + @Override @Nullable public String getModel() { @@ -206,6 +220,7 @@ public T copy() { options.setToolNames(getToolNames()); options.setToolContext(getToolContext()); options.setInternalToolExecutionEnabled(getInternalToolExecutionEnabled()); + options.setInternalToolExecutionMaxIterations(getInternalToolExecutionMaxIterations()); options.setModel(getModel()); options.setFrequencyPenalty(getFrequencyPenalty()); options.setMaxTokens(getMaxTokens()); @@ -277,6 +292,13 @@ public ToolCallingChatOptions.Builder internalToolExecutionEnabled( return this; } + @Override + public ToolCallingChatOptions.Builder internalToolExecutionMaxIterations( + @Nullable Integer internalToolExecutionMaxIterations) { + this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations); + return this; + } + @Override public ToolCallingChatOptions.Builder model(@Nullable String model) { this.options.setModel(model); return this; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..583f2f0d791 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -37,12 +37,20 @@ * * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 */ public interface ToolCallingChatOptions extends ChatOptions { boolean DEFAULT_TOOL_EXECUTION_ENABLED = true; + /** + * No limit for tool execution iterations. + */ + int TOOL_EXECUTION_NO_LIMIT = Integer.MAX_VALUE; + + int DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS = TOOL_EXECUTION_NO_LIMIT; + /** * ToolCallbacks to be registered with the ChatModel. */ @@ -76,6 +84,20 @@ public interface ToolCallingChatOptions extends ChatOptions { */ void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled); + /** + * Get the maximum number of iterations for tool execution. + * @return the maximum number of iterations. + * @see #getInternalToolExecutionEnabled() + */ + @Nullable + Integer getInternalToolExecutionMaxIterations(); + + /** + * Set the maximum number of iterations for tool execution. + * @param internalToolExecutionMaxIterations the maximum number of iterations. + */ + void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations); + /** * Get the configured tool context. * @return the tool context map. 
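To make the semantics of the new constants and of the iteration-aware check added in the next hunk concrete, here is a minimal sketch (the class and variable names are invented for illustration; this snippet is not part of the patch). Iterations are counted from 1, matching the counter the chat models start with.

import org.springframework.ai.model.tool.ToolCallingChatOptions;

class IterationGuardSketch {

	static void demo() {
		// Explicit cap: the first two tool round trips are allowed, the third is not.
		ToolCallingChatOptions capped = ToolCallingChatOptions.builder()
			.internalToolExecutionEnabled(true)
			.internalToolExecutionMaxIterations(2)
			.build();
		boolean first = ToolCallingChatOptions.isInternalToolExecutionEnabled(capped, 1); // true
		boolean third = ToolCallingChatOptions.isInternalToolExecutionEnabled(capped, 3); // false

		// No explicit cap: DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS equals TOOL_EXECUTION_NO_LIMIT
		// (Integer.MAX_VALUE), so the loop stays effectively unbounded, as before this change.
		ToolCallingChatOptions unlimited = ToolCallingChatOptions.builder().build();
		boolean stillEnabled = ToolCallingChatOptions.isInternalToolExecutionEnabled(unlimited, 100); // true

		System.out.println(first + " " + third + " " + stillEnabled);
	}

}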
@@ -109,6 +131,21 @@ static boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) { return internalToolExecutionEnabled; } + static boolean isInternalToolExecutionEnabled(ChatOptions chatOptions, int toolExecutionIterations) { + boolean isInternalToolExecutionEnabled = isInternalToolExecutionEnabled(chatOptions); + if (!isInternalToolExecutionEnabled) { + return false; + } + + if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions + && toolCallingChatOptions.getInternalToolExecutionMaxIterations() != null) { + int maxIterations = toolCallingChatOptions.getInternalToolExecutionMaxIterations(); + return toolExecutionIterations <= maxIterations; + } + + return DEFAULT_TOOL_EXECUTION_ENABLED; + } + static Set mergeToolNames(Set runtimeToolNames, Set defaultToolNames) { Assert.notNull(runtimeToolNames, "runtimeToolNames cannot be null"); Assert.notNull(defaultToolNames, "defaultToolNames cannot be null"); @@ -178,6 +215,13 @@ interface Builder extends ChatOptions.Builder { */ Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled); + /** + * Set the maximum number of iterations for tool execution. + * @param internalToolExecutionMaxIterations the maximum number of iterations. + * @return the {@link ToolCallingChatOptions} Builder. + */ + Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations); + /** * Add a {@link Map} of context values into tool context. * @param context the map representing the tool context. diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java index 6ba92766929..5c12046e1c4 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java @@ -27,6 +27,7 @@ * responses. * * @author Christian Tzolov + * @author lambochen */ public interface ToolExecutionEligibilityChecker extends Function { @@ -43,6 +44,23 @@ default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse return this.isInternalToolExecutionEnabled(promptOptions) && this.isToolCallResponse(chatResponse); } + /** + * Determines if tool execution should be performed based on the prompt options, the + * chat response, and the current tool execution iteration. + * @param promptOptions The options from the prompt + * @param chatResponse The response from the chat model + * @param toolExecutionIterations The current tool execution iteration (1-based) + * @return true if tool execution should be performed, false otherwise + */ + default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse, + int toolExecutionIterations) { + Assert.notNull(promptOptions, "promptOptions cannot be null"); + Assert.notNull(chatResponse, "chatResponse cannot be null"); + return this.isInternalToolExecutionEnabled(promptOptions, toolExecutionIterations) + && this.isToolCallResponse(chatResponse); + } + /** * Determines if the response is a tool call message response. * @param chatResponse The response from the chat model call * @@ -74,4 +92,25 @@ default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) { return internalToolExecutionEnabled; } + /** + * Determines if tool execution should be performed by Spring AI or by the client. 
+ * @param chatOptions The options from the chat + * @param toolExecutionIterations The current tool execution iteration (1-based) + * @return true if tool execution should be performed by Spring AI, false if it should + * be performed by the client + */ + default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions, int toolExecutionIterations) { + boolean internalToolExecutionEnabled = isInternalToolExecutionEnabled(chatOptions); + if (!internalToolExecutionEnabled) { + return internalToolExecutionEnabled; + } + + if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { + return toolCallingChatOptions.getInternalToolExecutionMaxIterations() == null + || toolExecutionIterations <= toolCallingChatOptions.getInternalToolExecutionMaxIterations(); + } + return internalToolExecutionEnabled; + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java index e3f048ebd41..a1181017986 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java @@ -43,4 +43,28 @@ default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse return test(promptOptions, chatResponse); } + /** + * Determines if tool execution should be performed based on the prompt options, the + * chat response, and the current tool execution iteration. + * @param promptOptions The options from the prompt + * @param chatResponse The response from the chat model + * @param toolExecutionIterations The current tool execution iteration (1-based) + * @return true if tool execution should be performed, false otherwise + * @see ToolCallingChatOptions#getInternalToolExecutionMaxIterations() + * @see #isToolExecutionRequired(ChatOptions, ChatResponse) + */ + default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse, + int toolExecutionIterations) { + boolean isToolExecutionRequired = isToolExecutionRequired(promptOptions, chatResponse); + if (!isToolExecutionRequired) { + return isToolExecutionRequired; + } + + if (promptOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { + return toolCallingChatOptions.getInternalToolExecutionMaxIterations() == null + || toolExecutionIterations <= toolCallingChatOptions.getInternalToolExecutionMaxIterations(); + } + return isToolExecutionRequired; + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java index 45557f23a6d..309c74b92e8 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -33,6 +33,7 @@ * Unit tests for {@link DefaultToolCallingChatOptions}. 
* * @author Thomas Vitale + * @author lambochen */ class DefaultToolCallingChatOptionsTests { @@ -140,6 +141,7 @@ void copyShouldCreateNewInstanceWithSameValues() { original.setToolNames(Set.of("tool1")); original.setToolContext(Map.of("key", "value")); original.setInternalToolExecutionEnabled(true); + original.setInternalToolExecutionMaxIterations(ToolCallingChatOptions.TOOL_EXECUTION_NO_LIMIT); original.setModel("gpt-4"); original.setTemperature(0.7); @@ -150,6 +152,8 @@ void copyShouldCreateNewInstanceWithSameValues() { assertThat(c.getToolNames()).isEqualTo(original.getToolNames()); assertThat(c.getToolContext()).isEqualTo(original.getToolContext()); assertThat(c.getInternalToolExecutionEnabled()).isEqualTo(original.getInternalToolExecutionEnabled()); + assertThat(c.getInternalToolExecutionMaxIterations()) + .isEqualTo(original.getInternalToolExecutionMaxIterations()); assertThat(c.getModel()).isEqualTo(original.getModel()); assertThat(c.getTemperature()).isEqualTo(original.getTemperature()); }); @@ -180,6 +184,7 @@ void builderShouldCreateOptionsWithAllProperties() { .toolNames(Set.of("tool1")) .toolContext(context) .internalToolExecutionEnabled(true) + .internalToolExecutionMaxIterations(3) .model("gpt-4") .temperature(0.7) .maxTokens(100) @@ -195,6 +200,7 @@ void builderShouldCreateOptionsWithAllProperties() { assertThat(o.getToolNames()).containsExactly("tool1"); assertThat(o.getToolContext()).isEqualTo(context); assertThat(o.getInternalToolExecutionEnabled()).isTrue(); + assertThat(o.getInternalToolExecutionMaxIterations()).isEqualTo(3); assertThat(o.getModel()).isEqualTo("gpt-4"); assertThat(o.getTemperature()).isEqualTo(0.7); assertThat(o.getMaxTokens()).isEqualTo(100); @@ -233,6 +239,13 @@ void deprecatedMethodsShouldWorkCorrectly() { options.setInternalToolExecutionEnabled(true); assertThat(options.getInternalToolExecutionEnabled()).isTrue(); + + // default value check + assertThat(options.getInternalToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + + options.setInternalToolExecutionMaxIterations(3); + assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index 6d5d599dccd..8e17c5252d7 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -33,6 +33,7 @@ * Unit tests for {@link ToolCallingChatOptions}. 
* * @author Thomas Vitale + * @author lambochen */ class ToolCallingChatOptionsTests { @@ -50,6 +51,20 @@ void whenToolCallingChatOptionsAndExecutionEnabledFalse() { assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options)).isFalse(); } + @Test + void whenToolCallingChatOptionsAndMaxIterationsOver() { + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + options.setInternalToolExecutionMaxIterations(1); + // the third iteration exceeds the configured maximum of 1 + assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 3)).isFalse(); + } + + @Test + void whenToolCallingChatOptionsAndMaxIterationsDefault() { + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 1)).isTrue(); + } + @Test void whenToolCallingChatOptionsAndExecutionEnabledDefault() { ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityCheckerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityCheckerTest.java new file mode 100644 index 00000000000..21b79860f4c --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityCheckerTest.java @@ -0,0 +1,54 @@ +package org.springframework.ai.model.tool; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +class ToolExecutionEligibilityCheckerTest { + + @Test + void isToolExecutionRequired() { + ToolExecutionEligibilityChecker checker = new TestToolExecutionEligibilityChecker(); + + ToolCallingChatOptions promptOptions = ToolCallingChatOptions.builder().build(); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test")))); + promptOptions.setInternalToolExecutionMaxIterations(2); + + assertThat(checker.isToolExecutionRequired(promptOptions, chatResponse, 1)).isTrue(); + assertThat(checker.isToolExecutionRequired(promptOptions, chatResponse, 2)).isTrue(); + + // the third iteration exceeds the configured maximum of 2 + assertThat(checker.isToolExecutionRequired(promptOptions, chatResponse, 3)).isFalse(); + } + + @Test + void isInternalToolExecutionEnabled() { + + ToolExecutionEligibilityChecker checker = new TestToolExecutionEligibilityChecker(); + + ToolCallingChatOptions promptOptions = ToolCallingChatOptions.builder().build(); + promptOptions.setInternalToolExecutionMaxIterations(2); + + assertThat(checker.isInternalToolExecutionEnabled(promptOptions, 1)).isTrue(); + assertThat(checker.isInternalToolExecutionEnabled(promptOptions, 2)).isTrue(); + + // the third iteration exceeds the configured maximum of 2 + assertThat(checker.isInternalToolExecutionEnabled(promptOptions, 3)).isFalse(); + + } + + class TestToolExecutionEligibilityChecker implements ToolExecutionEligibilityChecker { + + @Override + public Boolean apply(ChatResponse chatResponse) { + return true; + } + + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java index d347f9190f1..7e1c15849f0 100644 ---
a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java @@ -45,6 +45,20 @@ void whenIsToolExecutionRequiredWithNullPromptOptions() { .hasMessageContaining("promptOptions cannot be null"); } + @Test + void whenIsToolExecutionRequiredWithAttempts() { + ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate(); + ToolCallingChatOptions promptOptions = ToolCallingChatOptions.builder().build(); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test")))); + promptOptions.setInternalToolExecutionMaxIterations(2); + + assertThat(predicate.isToolExecutionRequired(promptOptions, chatResponse, 1)).isTrue(); + assertThat(predicate.isToolExecutionRequired(promptOptions, chatResponse, 2)).isTrue(); + + // the third iteration exceeds the configured maximum of 2 + assertThat(predicate.isToolExecutionRequired(promptOptions, chatResponse, 3)).isFalse(); + } + @Test void whenIsToolExecutionRequiredWithNullChatResponse() { ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate();
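
Usage sketch (not part of the diff): the snippet below illustrates how the new internalToolExecutionMaxIterations option could be used to bound Spring AI's internal tool-execution loop. It assumes an already configured AnthropicChatModel passed in by the caller; the method name askWithBoundedToolLoop, the prompt text, and the limit of 3 are illustrative values only.

import org.springframework.ai.anthropic.AnthropicChatModel;
import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingChatOptions;

class BoundedToolLoopExample {

	ChatResponse askWithBoundedToolLoop(AnthropicChatModel chatModel) {
		// Cap the internal tool-execution loop at three iterations. If the model still
		// requests tools beyond that, isToolExecutionRequired(options, response, iterations)
		// returns false and the last ChatResponse is returned to the caller instead of
		// triggering another tool round trip.
		AnthropicChatOptions options = AnthropicChatOptions.builder()
			.internalToolExecutionEnabled(true)
			.internalToolExecutionMaxIterations(3)
			.build();

		return chatModel.call(new Prompt("What is the weather in Paris?", options));
	}

	void eligibilityCheckBehaviour(AnthropicChatOptions options) {
		// The static helper added in this change reflects the configured limit directly:
		// the third iteration is still allowed (3 <= 3), the fourth is not (4 > 3).
		boolean thirdAllowed = ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 3);
		boolean fourthAllowed = ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 4);
		System.out.println(thirdAllowed + " " + fourthAllowed); // prints "true false"
	}

}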