diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/GuardrailAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/GuardrailAdvisor.java new file mode 100644 index 00000000000..35c64f91735 --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/GuardrailAdvisor.java @@ -0,0 +1,176 @@ +/* + * The {@code GuardrailAdvisor} class is an implementation of both {@link CallAdvisor} and {@link StreamAdvisor} + * that provides flexible input and output validation for chat client requests and responses. + * + * This advisor allows you to define custom validation logic for both user input and model output + * by supplying {@link Predicate} functions. If the input or output does not pass the specified validation, + * a configurable failure response is returned instead of proceeding with the normal processing chain. + * + * Typical use cases include enforcing content policies, blocking sensitive or inappropriate content, + * or implementing custom guardrails for AI-powered chat applications. + * + * The class also provides a builder for convenient and readable instantiation. + * + * Example usage: + *
+ * GuardrailAdvisor advisor = new GuardrailAdvisor.Builder()
+ *     .inputValidator(input -> !input.contains("forbidden"))
+ *     .outputValidator(output -> !output.contains("restricted"))
+ *     .failureResponse("Your request cannot be processed due to policy restrictions.")
+ *     .order(1)
+ *     .build();
+ * 
+ * + * @author Karson To + * @since 1.0.0 + */ + +package org.springframework.ai.chat.client.advisor; + +import org.springframework.ai.chat.client.ChatClientRequest; + +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +public class GuardrailAdvisor implements CallAdvisor, StreamAdvisor { + + private static final String DEFAULT_FAILURE_RESPONSE = + "Sorry, your request cannot be processed because it contains content that does not comply with our policy. " + + "Please revise your input and try again."; + + private static final int DEFAULT_ORDER = 0; + + private final String failureResponse; + + private final Predicate inputValidator; + + private final Predicate outputValidator; + + private final int order; + + + public GuardrailAdvisor(Predicate inputValidator, Predicate outputValidator, String failureResponse, + int order) { + Assert.notNull(inputValidator, "Input validator must not be null!"); + Assert.notNull(outputValidator, "Output validator must not be null!"); + Assert.notNull(failureResponse, "Failure response must not be null!"); + this.inputValidator = inputValidator; + this.outputValidator = outputValidator; + this.failureResponse = failureResponse; + this.order = order; + } + + private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) { + return ChatClientResponse.builder().chatResponse( + ChatResponse.builder().generations(List.of(new Generation(new AssistantMessage(this.failureResponse)))) + .build()).context(Map.copyOf(chatClientRequest.context())).build(); + } + + @Override + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + String input = chatClientRequest.prompt().getContents(); + if (!inputValidator.test(input)) { + return createFailureResponse(chatClientRequest); + } + ChatClientResponse response = callAdvisorChain.nextCall(chatClientRequest); + String output = null; + if (response != null && response.chatResponse() != null && response.chatResponse().getResults() != null + && !response.chatResponse().getResults().isEmpty()) { + Generation generation = response.chatResponse().getResults().get(0); + if (generation != null && generation.getOutput() != null) { + output = generation.getOutput().getText(); + } + } + if (!outputValidator.test(output != null ? output : "")) { + return createFailureResponse(chatClientRequest); + } + return response; + } + + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + String input = chatClientRequest.prompt().getContents(); + if (!inputValidator.test(input)) { + return Flux.just(createFailureResponse(chatClientRequest)); + } + return streamAdvisorChain.nextStream(chatClientRequest).map(response -> { + String output = null; + if (response != null && response.chatResponse() != null && response.chatResponse().getResults() != null + && !response.chatResponse().getResults().isEmpty()) { + Generation generation = response.chatResponse().getResults().get(0); + if (generation != null && generation.getOutput() != null) { + output = generation.getOutput().getText(); + } + } + if (!outputValidator.test(output != null ? output : "")) { + return createFailureResponse(chatClientRequest); + } + return response; + }); + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return this.order; + } + + public static final class Builder { + + private Predicate inputValidator = s -> true; + + private Predicate outputValidator = s -> true; + + private String failureResponse = DEFAULT_FAILURE_RESPONSE; + + private int order = DEFAULT_ORDER; + + private Builder() { + } + + public static Builder builder() { + return new Builder(); + } + + public Builder inputValidator(Predicate inputValidator) { + this.inputValidator = inputValidator; + return this; + } + + public Builder outputValidator(Predicate outputValidator) { + this.outputValidator = outputValidator; + return this; + } + + public Builder failureResponse(String failureResponse) { + this.failureResponse = failureResponse; + return this; + } + + public Builder order(int order) { + this.order = order; + return this; + } + + public GuardrailAdvisor build() { + return new GuardrailAdvisor(this.inputValidator, this.outputValidator, this.failureResponse, this.order); + } + } +} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/GuardrailAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/GuardrailAdvisorTests.java new file mode 100644 index 00000000000..2e6ea5fb452 --- /dev/null +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/GuardrailAdvisorTests.java @@ -0,0 +1,183 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.UserMessage; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link GuardrailAdvisor}. + *

+ * This test class verifies the input and output validation logic of the GuardrailAdvisor, ensuring that inappropriate + * content is properly blocked and a failure response is returned. + *

+ * Main test coverage includes: + *

    + *
  • Blocking requests when the input does not meet policy requirements, and ensuring the downstream chain is not called.
  • + *
  • Blocking responses when the output does not meet policy requirements, and returning a failure message.
  • + *
  • Allowing requests and responses to pass through when both input and output are valid.
  • + *
  • Validating the same logic for both synchronous (call) and asynchronous (stream) advisor chains.
  • + *
+ *

+ * All dependencies are mocked using Mockito, and both normal and streaming scenarios are covered. + * + * @author Karson To + */ + +class GuardrailAdvisorTests { + + @Test + void testInputBlocked() { + Predicate inputValidator = input -> !input.contains("block"); + Predicate outputValidator = output -> true; + GuardrailAdvisor advisor = GuardrailAdvisor.Builder.builder().inputValidator(inputValidator) + .outputValidator(outputValidator).order(0).build(); + + ChatClientRequest request = mockRequest("this should block"); + CallAdvisorChain chain = mock(CallAdvisorChain.class); + + ChatClientResponse response = advisor.adviseCall(request, chain); + + assertTrue(response.chatResponse().getResults().get(0).getOutput().getText().contains("cannot be processed")); + verify(chain, never()).nextCall(any()); + } + + @Test + void testOutputBlocked() { + Predicate inputValidator = input -> true; + Predicate outputValidator = output -> !output.contains("badword"); + GuardrailAdvisor advisor = GuardrailAdvisor.Builder.builder().inputValidator(inputValidator) + .outputValidator(outputValidator).order(0).build(); + + ChatClientRequest request = mockRequest("normal input"); + CallAdvisorChain chain = mock(CallAdvisorChain.class); + + // 模拟返回带有 badword 的响应 + AssistantMessage msg = new AssistantMessage("this contains badword"); + Generation gen = new Generation(msg); + ChatResponse chatResponse = new ChatResponse(List.of(gen)); + ChatClientResponse clientResponse = ChatClientResponse.builder().chatResponse(chatResponse).context(Map.of()) + .build(); + when(chain.nextCall(any())).thenReturn(clientResponse); + + ChatClientResponse response = advisor.adviseCall(request, chain); + + assertTrue(response.chatResponse().getResults().get(0).getOutput().getText().contains("cannot be processed")); + } + + @Test + void testPassThrough() { + Predicate inputValidator = input -> true; + Predicate outputValidator = output -> true; + GuardrailAdvisor advisor = GuardrailAdvisor.Builder.builder().inputValidator(inputValidator) + .outputValidator(outputValidator).order(0).build(); + + ChatClientRequest request = mockRequest("hello"); + CallAdvisorChain chain = mock(CallAdvisorChain.class); + + AssistantMessage msg = new AssistantMessage("all good"); + Generation gen = new Generation(msg); + ChatResponse chatResponse = new ChatResponse(List.of(gen)); + ChatClientResponse clientResponse = ChatClientResponse.builder().chatResponse(chatResponse).context(Map.of()) + .build(); + when(chain.nextCall(any())).thenReturn(clientResponse); + + ChatClientResponse response = advisor.adviseCall(request, chain); + + assertEquals("all good", response.chatResponse().getResults().get(0).getOutput().getText()); + } + + @Test + void testStreamInputBlocked() { + Predicate inputValidator = input -> input.length() < 5; + Predicate outputValidator = output -> true; + GuardrailAdvisor advisor = GuardrailAdvisor.Builder.builder().inputValidator(inputValidator) + .outputValidator(outputValidator).order(0).build(); + + ChatClientRequest request = mockRequest("toolonginput"); + StreamAdvisorChain chain = mock(StreamAdvisorChain.class); + + Flux flux = advisor.adviseStream(request, chain); + List responses = flux.collectList().block(); + + assertNotNull(responses); + assertEquals(1, responses.size()); + assertTrue(responses.get(0).chatResponse().getResults().get(0).getOutput().getText() + .contains("cannot be processed")); + verify(chain, never()).nextStream(any()); + } + + @Test + void testStreamOutputBlocked() { + Predicate inputValidator = input -> true; + Predicate outputValidator = output -> !output.contains("bad"); + GuardrailAdvisor advisor = GuardrailAdvisor.Builder.builder().inputValidator(inputValidator) + .outputValidator(outputValidator).order(0).build(); + + ChatClientRequest request = mockRequest("ok"); + StreamAdvisorChain chain = mock(StreamAdvisorChain.class); + + AssistantMessage msg1 = new AssistantMessage("good"); + AssistantMessage msg2 = new AssistantMessage("bad output"); + Generation gen1 = new Generation(msg1); + Generation gen2 = new Generation(msg2); + ChatResponse chatResponse1 = new ChatResponse(List.of(gen1)); + ChatResponse chatResponse2 = new ChatResponse(List.of(gen2)); + ChatClientResponse resp1 = ChatClientResponse.builder().chatResponse(chatResponse1).context(Map.of()).build(); + ChatClientResponse resp2 = ChatClientResponse.builder().chatResponse(chatResponse2).context(Map.of()).build(); + + when(chain.nextStream(any())).thenReturn(Flux.just(resp1, resp2)); + + List responses = advisor.adviseStream(request, chain).collectList().block(); + + assertNotNull(responses); + assertEquals(2, responses.size()); + assertEquals("good", responses.get(0).chatResponse().getResults().get(0).getOutput().getText()); + assertTrue(responses.get(1).chatResponse().getResults().get(0).getOutput().getText() + .contains("cannot be processed")); + } + + private ChatClientRequest mockRequest(String content) { + ChatClientRequest request = mock(ChatClientRequest.class); + Prompt prompt = new Prompt(new UserMessage(content)); + when(request.prompt()).thenReturn(prompt); + when(request.context()).thenReturn(Map.of()); + return request; + } +} \ No newline at end of file