diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index f5a1e8cd11a..28089504925 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -61,6 +61,7 @@ import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; import org.springframework.util.StringUtils; @@ -705,6 +706,7 @@ public TemplateRenderer getTemplateRenderer() { * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose * settings are replicated from this {@code ChatClientRequest}. */ + @Override public Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient .builder(this.chatModel, this.observationRegistry, this.observationConvention) @@ -713,6 +715,10 @@ public Builder mutate() { .defaultToolContext(this.toolContext) .defaultToolNames(StringUtils.toStringArray(this.toolNames)); + if (!CollectionUtils.isEmpty(this.advisors)) { + builder.defaultAdvisors(a -> a.advisors(this.advisors).params(this.advisorParams)); + } + if (StringUtils.hasText(this.userText)) { builder.defaultUser( u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0]))); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index fd795971a2e..1c72596f490 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -48,6 +48,7 @@ import org.springframework.ai.content.Media; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; +import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.ParameterizedTypeReference; @@ -60,6 +61,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link DefaultChatClient}. @@ -124,6 +126,51 @@ void whenPromptWithOptionsThenReturn() { assertThat(spec.getChatOptions()).isEqualTo(chatOptions); } + @Test + void testMutate() { + var media = mock(Media.class); + var toolCallback = mock(ToolCallback.class); + var advisor = mock(Advisor.class); + var templateRenderer = mock(TemplateRenderer.class); + var chatOptions = mock(ChatOptions.class); + var copyChatOptions = mock(ChatOptions.class); + when(chatOptions.copy()).thenReturn(copyChatOptions); + var toolContext = new HashMap(); + var userMessage1 = mock(UserMessage.class); + var userMessage2 = mock(UserMessage.class); + + DefaultChatClientBuilder defaultChatClientBuilder = new DefaultChatClientBuilder(mock(ChatModel.class)); + defaultChatClientBuilder.addMessages(List.of(userMessage1, userMessage2)); + ChatClient originalChatClient = defaultChatClientBuilder.defaultAdvisors(advisor) + .defaultOptions(chatOptions) + .defaultUser(u -> u.text("original user {userParams}").param("userParams", "user value2").media(media)) + .defaultSystem(s -> s.text("original system {sysParams}").param("sysParams", "system value1")) + .defaultTemplateRenderer(templateRenderer) + .defaultToolNames("toolName1", "toolName2") + .defaultToolCallbacks(toolCallback) + .defaultToolContext(toolContext) + .build(); + var originalSpec = (DefaultChatClient.DefaultChatClientRequestSpec) originalChatClient.prompt(); + + ChatClient mutateChatClient = originalChatClient.mutate().build(); + var mutateSpec = (DefaultChatClient.DefaultChatClientRequestSpec) mutateChatClient.prompt(); + + assertThat(mutateSpec).isNotSameAs(originalSpec); + + assertThat(mutateSpec.getMessages()).hasSize(2).containsOnly(userMessage1, userMessage2); + assertThat(mutateSpec.getAdvisors()).hasSize(1).containsOnly(advisor); + assertThat(mutateSpec.getChatOptions()).isEqualTo(copyChatOptions); + assertThat(mutateSpec.getUserText()).isEqualTo("original user {userParams}"); + assertThat(mutateSpec.getUserParams()).containsEntry("userParams", "user value2"); + assertThat(mutateSpec.getMedia()).hasSize(1).containsOnly(media); + assertThat(mutateSpec.getSystemText()).isEqualTo("original system {sysParams}"); + assertThat(mutateSpec.getSystemParams()).containsEntry("sysParams", "system value1"); + assertThat(mutateSpec.getTemplateRenderer()).isEqualTo(templateRenderer); + assertThat(mutateSpec.getToolNames()).containsExactly("toolName1", "toolName2"); + assertThat(mutateSpec.getToolCallbacks()).containsExactly(toolCallback); + assertThat(mutateSpec.getToolContext()).isEqualTo(toolContext); + } + @Test void whenMutateChatClientRequest() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();