diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml index 19e4f0e7bb7..bc7a8c566cc 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml @@ -55,6 +55,13 @@ <scope>test</scope> </dependency> + + <dependency> + <groupId>com.microsoft.sqlserver</groupId> + <artifactId>mssql-jdbc</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.testcontainers</groupId> <artifactId>junit-jupiter</artifactId> @@ -66,6 +73,12 @@ <artifactId>postgresql</artifactId> <scope>test</scope> </dependency> + + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>mssqlserver</artifactId> + <scope>test</scope> + </dependency> </dependencies> </project> diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationMSSQLServerIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationMSSQLServerIT.java new file mode 100644 index 00000000000..91dab38cc51 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationMSSQLServerIT.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.containers.MSSQLServerContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Xavier Chopin + */ +@Testcontainers +class JdbcChatMemoryAutoConfigurationMSSQLServerIT { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mcr.microsoft.com/mssql/server:2022-latest"); + + @Container + @SuppressWarnings("resource") + static MSSQLServerContainer<?> mssqlContainer = new MSSQLServerContainer<>(DEFAULT_IMAGE_NAME) + .acceptLicense() + .withEnv("MSSQL_DATABASE", "chat_memory_auto_configuration_test") + .withPassword("Strong!NotR34LLyPassword"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues(String.format("spring.datasource.url=%s", mssqlContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", mssqlContainer.getUsername()), + String.format("spring.datasource.password=%s", mssqlContainer.getPassword())); + + @Test + void jdbcChatMemoryScriptDatabaseInitializer_shouldBeLoaded() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue()); + } + + @Test + void jdbcChatMemoryScriptDatabaseInitializer_shouldNotBeLoaded() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=false") + .run(context -> assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isFalse()); + } + + @Test + void addGetAndClear_shouldAllExecute() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true").run(context -> { + var chatMemory = context.getBean(JdbcChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemory.add(conversationId, userMessage); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); + + var multipleMessages = List.<Message>of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemory.add(conversationId, multipleMessages); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(multipleMessages.size()); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(multipleMessages); + }); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationPostgresIT.java similarity index 98% rename from auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java rename to auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationPostgresIT.java index df9a49d85b9..83016aeaa1e 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationPostgresIT.java @@ -16,15 +16,7 @@ package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; -import java.util.List; -import java.util.UUID; - import org.junit.jupiter.api.Test; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; - import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -33,6 +25,13 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.util.List; +import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -40,7 +39,7 @@ * @author Jonathan Leijendekker */ @Testcontainers -class JdbcChatMemoryAutoConfigurationIT { +class JdbcChatMemoryAutoConfigurationPostgresIT { static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseMSSQLServerIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseMSSQLServerIT.java new file mode 100644 index 00000000000..a093dd1f160 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseMSSQLServerIT.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.containers.MSSQLServerContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import javax.sql.DataSource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Xavier Chopin + */ +@Testcontainers +class JdbcChatMemoryDataSourceScriptDatabaseMSSQLServerIT { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mcr.microsoft.com/mssql/server:2022-latest"); + + @Container + @SuppressWarnings("resource") + static MSSQLServerContainer<?> mssqlContainer = new MSSQLServerContainer<>(DEFAULT_IMAGE_NAME) + .acceptLicense() + .withEnv("MSSQL_DATABASE", "chat_memory_test") + .withPassword("Strong!NotR34LLyPassword"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues(String.format("spring.datasource.url=%s", mssqlContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", mssqlContainer.getUsername()), + String.format("spring.datasource.password=%s", mssqlContainer.getPassword())); + + @Test + void getSettings_shouldHaveSchemaLocations() { + this.contextRunner.run(context -> { + var dataSource = context.getBean(DataSource.class); + var settings = JdbcChatMemoryDataSourceScriptDatabaseInitializer.getSettings(dataSource); + + assertThat(settings.getSchemaLocations()) + .containsOnly("classpath:org/springframework/ai/chat/memory/jdbc/schema-sqlserver.sql"); + }); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabasePostgresIT.java similarity index 97% rename from auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java rename to auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabasePostgresIT.java index f563c67cdf1..2b765037ff7 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabasePostgresIT.java @@ -16,18 +16,17 @@ package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; -import javax.sql.DataSource; - import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.DockerImageName; -import org.springframework.boot.autoconfigure.AutoConfigurations; -import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; -import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; -import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import javax.sql.DataSource; import static org.assertj.core.api.Assertions.assertThat; @@ -35,7 +34,7 @@ * @author Jonathan Leijendekker */ @Testcontainers -class JdbcChatMemoryDataSourceScriptDatabaseInitializerTests { +class JdbcChatMemoryDataSourceScriptDatabasePostgresIT { static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); diff --git a/memory/spring-ai-model-chat-memory-jdbc/pom.xml b/memory/spring-ai-model-chat-memory-jdbc/pom.xml index c8a734d38b8..4c2811e1e77 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/pom.xml +++ b/memory/spring-ai-model-chat-memory-jdbc/pom.xml @@ -46,8 +46,8 @@ </dependency> <dependency> - <groupId>org.springframework</groupId> - <artifactId>spring-jdbc</artifactId> + <groupId>org.springframework.boot</groupId> + <artifactId>spring-boot-starter-data-jdbc</artifactId> </dependency> <dependency> @@ -69,6 +69,12 @@ <optional>true</optional> </dependency> + <dependency> + <groupId>com.microsoft.sqlserver</groupId> + <artifactId>mssql-jdbc</artifactId> + <optional>true</optional> + </dependency> + <!-- TESTING --> <dependency> <groupId>org.springframework.boot</groupId> @@ -82,6 +88,12 @@ <scope>test</scope> </dependency> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>mssqlserver</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.testcontainers</groupId> <artifactId>postgresql</artifactId> diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java index 6c9825bac1b..0261f11aa93 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java @@ -16,20 +16,19 @@ package org.springframework.ai.chat.memory.jdbc; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.List; - import org.springframework.ai.chat.memory.ChatMemory; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.*; +import org.springframework.boot.jdbc.DatabaseDriver; import org.springframework.jdbc.core.BatchPreparedStatementSetter; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; +import org.springframework.util.Assert; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; /** * An implementation of {@link ChatMemory} for JDBC. Creating an instance of @@ -37,6 +36,7 @@ * <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code> * * @author Jonathan Leijendekker + * @author Xavier Chopin * @since 1.0.0 */ public class JdbcChatMemory implements ChatMemory { @@ -47,12 +47,18 @@ public class JdbcChatMemory implements ChatMemory { private static final String QUERY_GET = """ SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?"""; + private static final String MSSQL_QUERY_GET = """ + SELECT TOP (?) content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC"""; + private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; private final JdbcTemplate jdbcTemplate; + private final DatabaseDriver driver; + public JdbcChatMemory(JdbcChatMemoryConfig config) { this.jdbcTemplate = config.getJdbcTemplate(); + this.driver = this.detectDatabaseDriver(this.jdbcTemplate); } public static JdbcChatMemory create(JdbcChatMemoryConfig config) { @@ -66,7 +72,10 @@ public void add(String conversationId, List<Message> messages) { @Override public List<Message> get(String conversationId, int lastN) { - return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + return switch (driver) { + case SQLSERVER -> this.jdbcTemplate.query(MSSQL_QUERY_GET, new MessageRowMapper(), lastN, conversationId); + default -> this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + }; } @Override @@ -74,8 +83,8 @@ public void clear(String conversationId) { this.jdbcTemplate.update(QUERY_CLEAR, conversationId); } - private record AddBatchPreparedStatement(String conversationId, - List<Message> messages) implements BatchPreparedStatementSetter { + private record AddBatchPreparedStatement(String conversationId, List<Message> messages) + implements BatchPreparedStatementSetter { @Override public void setValues(PreparedStatement ps, int i) throws SQLException { var message = this.messages.get(i); @@ -108,4 +117,14 @@ public Message mapRow(ResultSet rs, int i) throws SQLException { } + private DatabaseDriver detectDatabaseDriver(JdbcTemplate jdbcTemplate) { + Assert.notNull(jdbcTemplate.getDataSource(), "jdbcTemplate.dataSource must not be null"); + try { + Connection conn = jdbcTemplate.getDataSource().getConnection(); + String url = conn.getMetaData().getURL(); + return DatabaseDriver.fromJdbcUrl(url); + } catch (SQLException ex) { + throw new IllegalStateException("Impossible to detect the database driver", ex); + } + } } diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java index 6740602e3f8..4df4b39f8ab 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java @@ -16,16 +16,17 @@ package org.springframework.ai.chat.memory.jdbc.aot.hint; -import javax.sql.DataSource; - import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; +import javax.sql.DataSource; + /** * A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints * * @author Jonathan Leijendekker + * @author Xavier Chopin */ class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar { @@ -36,6 +37,7 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { hints.resources() .registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql") + .registerPattern("org/springframework/ai/chat/memory/jdbc/schema-sqlserver.sql") .registerPattern("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"); } diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-sqlserver.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-sqlserver.sql new file mode 100644 index 00000000000..1d5c95d6e3e --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-sqlserver.sql @@ -0,0 +1,9 @@ +CREATE TABLE ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content NVARCHAR(MAX) NOT NULL, + type VARCHAR(10) NOT NULL, + [timestamp] DATETIME2 NOT NULL DEFAULT SYSDATETIME(), + CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')) +); + +CREATE INDEX ai_chat_memory_conversation_id_timestamp_idx ON ai_chat_memory(conversation_id, [timestamp] DESC); \ No newline at end of file diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryMSSQLServerIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryMSSQLServerIT.java new file mode 100644 index 00000000000..0f1a2323c24 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryMSSQLServerIT.java @@ -0,0 +1,246 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.*; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.MSSQLServerContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import javax.sql.DataSource; +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +/** + * @author Xavier Chopin + */ +@Testcontainers +class JdbcChatMemoryMSSQLServerIT { + + @Container + @SuppressWarnings("resource") + static MSSQLServerContainer<?> mssqlContainer = new MSSQLServerContainer<>("mcr.microsoft.com/mssql/server:2022-latest") + .acceptLicense() + .withEnv("MSSQL_DATABASE", "chat_memory_test") + .withPassword("Strong!NotR34LLyPassword") + .withInitScript("org/springframework/ai/chat/memory/jdbc/schema-sqlserver.sql"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues(String.format("app.datasource.url=%s", mssqlContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", mssqlContainer.getUsername()), + String.format("app.datasource.password=%s", mssqlContainer.getPassword())); + + @Test + void correctChatMemoryInstance() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + + assertThat(chatMemory).isInstanceOf(JdbcChatMemory.class); + }); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) + void add_shouldInsertSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + case SYSTEM -> new SystemMessage(content + " - " + conversationId); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemory.add(conversationId, message); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var result = jdbcTemplate.queryForMap(query, conversationId); + + assertThat(result.size()).isEqualTo(4); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + }); + } + + @Test + void add_shouldInsertMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var results = jdbcTemplate.queryForList(query, conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.get("conversation_id")).isNotNull(); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + }); + } + + @Test + void get_shouldReturnMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var results = chatMemory.get(conversationId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); + }); + } + + @Test + void givenLimitN_shouldReturnNMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + Message expected = new AssistantMessage("Message from assistant 1 - " + conversationId); + + var messages = List.<Message>of(expected, + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var results = chatMemory.get(conversationId, 1); + + assertThat(results.size()).isEqualTo(1); + assertThat(results).isEqualTo(List.of(expected)); + }); + } + + @Test + void givenNonExistingId_shouldReturnEmptyList() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + + List<Message> messages = List.of( + new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId) + ); + + chatMemory.add(conversationId, messages); + + var nonExistingUUID = UUID.randomUUID().toString(); + + assertDoesNotThrow(() -> { + List<Message> actual = chatMemory.get(nonExistingUUID, Integer.MAX_VALUE); + assertThat(actual).isEmpty(); + }); + }); + } + + @Test + void clear_shouldDeleteMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + chatMemory.clear(conversationId); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?", + Integer.class, conversationId); + + assertThat(count).isZero(); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public ChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + + return JdbcChatMemory.create(config); + } + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public DataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().build(); + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryPostgresIT.java similarity index 82% rename from memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java rename to memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryPostgresIT.java index 96b0e7ca5f8..7890e9b72e6 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryPostgresIT.java @@ -16,27 +16,12 @@ package org.springframework.ai.chat.memory.jdbc; -import java.sql.Timestamp; -import java.util.List; -import java.util.UUID; - -import javax.sql.DataSource; - import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.MountableFile; - import org.springframework.ai.chat.memory.ChatMemory; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.*; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -46,14 +31,25 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + +import javax.sql.DataSource; +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; /** * @author Jonathan Leijendekker + * @author Xavier Chopin */ @Testcontainers -class JdbcChatMemoryIT { +class JdbcChatMemoryPostgresIT { @Container @SuppressWarnings("resource") @@ -161,6 +157,51 @@ void get_shouldReturnMessages() { }); } + @Test + void givenLimitN_shouldReturnNMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + Message expected = new AssistantMessage("Message from assistant 1 - " + conversationId); + + var messages = List.<Message>of(expected, + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var results = chatMemory.get(conversationId, 1); + + assertThat(results.size()).isEqualTo(1); + assertThat(results).isEqualTo(List.of(expected)); + }); + } + + @Test + void givenNonExistingId_shouldReturnEmptyList() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + + List<Message> messages = List.of( + new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId) + ); + + chatMemory.add(conversationId, messages); + + var nonExistingUUID = UUID.randomUUID().toString(); + + assertDoesNotThrow(() -> { + List<Message> actual = chatMemory.get(nonExistingUUID, Integer.MAX_VALUE); + assertThat(actual).isEmpty(); + }); + }); + } + @Test void clear_shouldDeleteMessages() { this.contextRunner.run(context -> {