From 51d216d0cf44b325288e586a583ae80cb9625dac Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 8 Sep 2023 14:01:18 +0200 Subject: [PATCH] Spring AI RAG basic integration --- .gitignore | 2 + build.gradle | 3 + .../music/config/ai/AiConfiguration.java | 54 ++++++++++++++ .../music/config/ai/MessageRetriever.java | 74 +++++++++++++++++++ .../config/ai/VectorStoreInitializer.java | 52 +++++++++++++ .../samples/music/web/AlbumController.java | 16 +++- src/main/resources/application.yml | 3 + src/main/resources/prompts/system-qa.st | 8 ++ 8 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/cloudfoundry/samples/music/config/ai/AiConfiguration.java create mode 100644 src/main/java/org/cloudfoundry/samples/music/config/ai/MessageRetriever.java create mode 100644 src/main/java/org/cloudfoundry/samples/music/config/ai/VectorStoreInitializer.java create mode 100644 src/main/resources/prompts/system-qa.st diff --git a/.gitignore b/.gitignore index 3bd5cb9ef..6cc546905 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ build/ *.log* /classes/ + +.vscode \ No newline at end of file diff --git a/build.gradle b/build.gradle index df581fbb2..5bd7ee474 100644 --- a/build.gradle +++ b/build.gradle @@ -8,6 +8,7 @@ plugins { repositories { mavenCentral() + mavenLocal() } ext { @@ -23,6 +24,8 @@ dependencies { implementation "org.springframework.boot:spring-boot-starter-data-redis" implementation "org.springframework.boot:spring-boot-starter-validation" + implementation "org.springframework.experimental.ai:spring-ai-openai-spring-boot-starter:0.2.0-SNAPSHOT" + // Java CfEnv implementation "io.pivotal.cfenv:java-cfenv-boot:${javaCfEnvVersion}" diff --git a/src/main/java/org/cloudfoundry/samples/music/config/ai/AiConfiguration.java b/src/main/java/org/cloudfoundry/samples/music/config/ai/AiConfiguration.java new file mode 100644 index 000000000..85b6ecf32 --- /dev/null +++ b/src/main/java/org/cloudfoundry/samples/music/config/ai/AiConfiguration.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.cloudfoundry.samples.music.config.ai; + +import org.springframework.ai.client.AiClient; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.retriever.impl.VectorStoreRetriever; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.impl.InMemoryVectorStore; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * + * @author Christian Tzolov + */ +@Configuration +public class AiConfiguration { + + @Bean + public VectorStore vectorStore(EmbeddingClient embeddingClient) { + return new InMemoryVectorStore(embeddingClient); + } + + @Bean + public VectorStoreRetriever vectorStoreRetriever(VectorStore vectorStore) { + return new VectorStoreRetriever(vectorStore); + } + + @Bean + public VectorStoreInitializer vectorStoreInitializer(VectorStore vectorStore) { + return new VectorStoreInitializer(vectorStore); + } + + @Bean + public MessageRetriever messageRetriever(VectorStoreRetriever vectorStoreRetriever, AiClient aiClient) { + return new MessageRetriever(vectorStoreRetriever, aiClient); + } + +} diff --git a/src/main/java/org/cloudfoundry/samples/music/config/ai/MessageRetriever.java b/src/main/java/org/cloudfoundry/samples/music/config/ai/MessageRetriever.java new file mode 100644 index 000000000..26ec00fec --- /dev/null +++ b/src/main/java/org/cloudfoundry/samples/music/config/ai/MessageRetriever.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.cloudfoundry.samples.music.config.ai; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.springframework.ai.client.AiClient; +import org.springframework.ai.client.AiResponse; +import org.springframework.ai.client.Generation; +import org.springframework.ai.document.Document; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.SystemPromptTemplate; +import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.retriever.impl.VectorStoreRetriever; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; + +/** + * + * @author Christian Tzolov + */ +public class MessageRetriever { + + @Value("classpath:/prompts/system-qa.st") + private Resource systemPrompt; + + private VectorStoreRetriever vectorStoreRetriever; + + private AiClient aiClient; + + public MessageRetriever(VectorStoreRetriever vectorStoreRetriever, AiClient aiClient) { + this.vectorStoreRetriever = vectorStoreRetriever; + this.aiClient = aiClient; + } + + public Generation retrieve(String message) { + List relatedDocuments = this.vectorStoreRetriever.retrieve(message); + + Message systemMessage = getSystemMessage(relatedDocuments); + UserMessage userMessage = new UserMessage(message); + + Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); + + AiResponse response = aiClient.generate(prompt); + + return response.getGeneration(); + } + + private Message getSystemMessage(List relatedDocuments) { + + String documents = relatedDocuments.stream().map(entry -> entry.getContent()).collect(Collectors.joining("\n")); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("documents", documents)); + return systemMessage; + + } +} diff --git a/src/main/java/org/cloudfoundry/samples/music/config/ai/VectorStoreInitializer.java b/src/main/java/org/cloudfoundry/samples/music/config/ai/VectorStoreInitializer.java new file mode 100644 index 000000000..d23561e02 --- /dev/null +++ b/src/main/java/org/cloudfoundry/samples/music/config/ai/VectorStoreInitializer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.cloudfoundry.samples.music.config.ai; + +import java.util.List; + +import org.springframework.ai.document.Document; +import org.springframework.ai.loader.impl.JsonLoader; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.context.event.ApplicationReadyEvent; +import org.springframework.context.ApplicationListener; +import org.springframework.core.io.Resource; + +/** + * + * @author Christian Tzolov + */ +public class VectorStoreInitializer implements ApplicationListener { + + private VectorStore vectorStore; + + @Value("classpath:/albums.json") + private Resource albumsResource; + + public VectorStoreInitializer(VectorStore vectorStore) { + this.vectorStore = vectorStore; + } + + @Override + public void onApplicationEvent(ApplicationReadyEvent event) { + JsonLoader jsonLoader = new JsonLoader(this.albumsResource, + "artist", "title", "releaseYear", "genre"); + List documents = jsonLoader.load(); + this.vectorStore.add(documents); + } + +} diff --git a/src/main/java/org/cloudfoundry/samples/music/web/AlbumController.java b/src/main/java/org/cloudfoundry/samples/music/web/AlbumController.java index 7f98276b9..2ce6e56f7 100644 --- a/src/main/java/org/cloudfoundry/samples/music/web/AlbumController.java +++ b/src/main/java/org/cloudfoundry/samples/music/web/AlbumController.java @@ -1,8 +1,11 @@ package org.cloudfoundry.samples.music.web; +import org.cloudfoundry.samples.music.config.ai.MessageRetriever; import org.cloudfoundry.samples.music.domain.Album; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import org.springframework.ai.client.Generation; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.repository.CrudRepository; import org.springframework.web.bind.annotation.*; @@ -14,10 +17,13 @@ public class AlbumController { private static final Logger logger = LoggerFactory.getLogger(AlbumController.class); private CrudRepository repository; + private MessageRetriever messageRetriever; @Autowired - public AlbumController(CrudRepository repository) { + public AlbumController(CrudRepository repository, MessageRetriever messageRetriever) { this.repository = repository; + this.messageRetriever = messageRetriever; + } @RequestMapping(method = RequestMethod.GET) @@ -48,4 +54,12 @@ public void deleteById(@PathVariable String id) { logger.info("Deleting album " + id); repository.deleteById(id); } + + // + @GetMapping("/ai/rag") + public Generation generate( + @RequestParam(value = "message", defaultValue = "Suggest rock music albums?") String message) { + return messageRetriever.retrieve(message); + } + } \ No newline at end of file diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index b1fe20272..ba78d66f3 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -1,4 +1,7 @@ spring: + ai: + openai: + api-key: "YOUR KEY" jpa: generate-ddl: true diff --git a/src/main/resources/prompts/system-qa.st b/src/main/resources/prompts/system-qa.st new file mode 100644 index 000000000..c3ddc552b --- /dev/null +++ b/src/main/resources/prompts/system-qa.st @@ -0,0 +1,8 @@ +You're assisting with questions about music albums and artists. +Use the information from the DOCUMENTS section to provide accurate answers. +The the answer involves referring to the artist, title, genre or release year of the album, include the album name in the response. +In addition for each album write a short paragraph, describing the album, critical receptions, Influence and legacy and Track listing. +If unsure, simply state that you don't know. + +DOCUMENTS: +{documents} \ No newline at end of file