diff --git a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java index 3d4183507c9..1feadc9b72b 100644 --- a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java +++ b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java @@ -56,6 +56,8 @@ */ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor { + public static final String SIMILARITY_THRESHOLD = "chat_memory_vector_store_similarity_threshold"; + public static final String TOP_K = "chat_memory_vector_store_top_k"; private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId"; @@ -64,6 +66,8 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor private static final int DEFAULT_TOP_K = 20; + private static final double DEFAULT_SIMILARITY_THRESHOLD = 0; + private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate(""" {instructions} @@ -79,6 +83,8 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor private final int defaultTopK; + private final double defaultSimilarityThreshold; + private final String defaultConversationId; private final int order; @@ -88,14 +94,17 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor private final VectorStore vectorStore; private VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultTopK, - String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) { + double defaultSimilarityThreshold, String defaultConversationId, int order, Scheduler scheduler, + VectorStore vectorStore) { Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null"); Assert.isTrue(defaultTopK > 0, "topK must be greater than 0"); + Assert.isTrue(defaultSimilarityThreshold >= 0, "similarityThreshold must be equal to or greater than 0"); Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty"); Assert.notNull(scheduler, "scheduler cannot be null"); Assert.notNull(vectorStore, "vectorStore cannot be null"); this.systemPromptTemplate = systemPromptTemplate; this.defaultTopK = defaultTopK; + this.defaultSimilarityThreshold = defaultSimilarityThreshold; this.defaultConversationId = defaultConversationId; this.order = order; this.scheduler = scheduler; @@ -121,10 +130,12 @@ public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorC String conversationId = getConversationId(request.context(), this.defaultConversationId); String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : ""; int topK = getChatMemoryTopK(request.context()); + double similarityThreshold = getChatMemorySimilarityThreshold(request.context()); String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'"; var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder() .query(query) .topK(topK) + .similarityThreshold(similarityThreshold) .filterExpression(filter) .build(); java.util.List documents = this.vectorStore @@ -156,6 +167,11 @@ private int getChatMemoryTopK(Map context) { return context.containsKey(TOP_K) ? Integer.parseInt(context.get(TOP_K).toString()) : this.defaultTopK; } + private double getChatMemorySimilarityThreshold(Map context) { + return context.containsKey(SIMILARITY_THRESHOLD) + ? Double.parseDouble(context.get(SIMILARITY_THRESHOLD).toString()) : this.defaultSimilarityThreshold; + } + @Override public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { List assistantMessages = new ArrayList<>(); @@ -221,6 +237,8 @@ public static class Builder { private Integer defaultTopK = DEFAULT_TOP_K; + private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; + private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID; private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER; @@ -257,6 +275,17 @@ public Builder defaultTopK(int defaultTopK) { return this; } + /** + * Set the similarity threshold for retrieving relevant documents. + * @param defaultSimilarityThreshold the required similarity for documents to + * retrieve + * @return this builder + */ + public Builder defaultSimilarityThreshold(Double defaultSimilarityThreshold) { + this.defaultSimilarityThreshold = defaultSimilarityThreshold; + return this; + } + /** * Set the conversation id. * @param conversationId the conversation id @@ -287,8 +316,8 @@ public Builder order(int order) { * @return the advisor */ public VectorStoreChatMemoryAdvisor build() { - return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK, this.conversationId, - this.order, this.scheduler, this.vectorStore); + return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK, + this.defaultSimilarityThreshold, this.conversationId, this.order, this.scheduler, this.vectorStore); } } diff --git a/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisorTests.java b/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisorTests.java index 749a4ffeef9..fd44de5d845 100644 --- a/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisorTests.java +++ b/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisorTests.java @@ -91,4 +91,14 @@ void whenDefaultTopKIsNegativeThenThrow() { .hasMessageContaining("topK must be greater than 0"); } + @Test + void whenDefaultSimilarityThresholdIsLessThanZeroThenThrow() { + VectorStore vectorStore = Mockito.mock(VectorStore.class); + + assertThatThrownBy( + () -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultSimilarityThreshold(-0.1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("similarityThreshold must be equal to or greater than 0"); + } + }