Skip to content

Commit be93230

Browse files
committed
fix: get query for MSSQL Server
Signed-off-by: Xavier Chopin <[email protected]>
1 parent c0bc623 commit be93230

File tree

7 files changed

+306
-15
lines changed

7 files changed

+306
-15
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@
6666
<artifactId>postgresql</artifactId>
6767
<scope>test</scope>
6868
</dependency>
69+
70+
<dependency>
71+
<groupId>org.testcontainers</groupId>
72+
<artifactId>mssqlserver</artifactId>
73+
<scope>test</scope>
74+
</dependency>
6975
</dependencies>
7076

7177
</project>

memory/spring-ai-model-chat-memory-jdbc/pom.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@
6969
<optional>true</optional>
7070
</dependency>
7171

72+
<dependency>
73+
<groupId>com.microsoft.sqlserver</groupId>
74+
<artifactId>mssql-jdbc</artifactId>
75+
<optional>true</optional>
76+
</dependency>
77+
7278
<!-- TESTING -->
7379
<dependency>
7480
<groupId>org.springframework.boot</groupId>
@@ -82,6 +88,12 @@
8288
<scope>test</scope>
8389
</dependency>
8490

91+
<dependency>
92+
<groupId>org.testcontainers</groupId>
93+
<artifactId>mssqlserver</artifactId>
94+
<scope>test</scope>
95+
</dependency>
96+
8597
<dependency>
8698
<groupId>org.testcontainers</groupId>
8799
<artifactId>postgresql</artifactId>

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,27 @@
1616

1717
package org.springframework.ai.chat.memory.jdbc;
1818

19-
import java.sql.PreparedStatement;
20-
import java.sql.ResultSet;
21-
import java.sql.SQLException;
22-
import java.util.List;
23-
2419
import org.springframework.ai.chat.memory.ChatMemory;
25-
import org.springframework.ai.chat.messages.AssistantMessage;
26-
import org.springframework.ai.chat.messages.Message;
27-
import org.springframework.ai.chat.messages.MessageType;
28-
import org.springframework.ai.chat.messages.SystemMessage;
29-
import org.springframework.ai.chat.messages.UserMessage;
20+
import org.springframework.ai.chat.messages.*;
21+
import org.springframework.boot.jdbc.DatabaseDriver;
3022
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
3123
import org.springframework.jdbc.core.JdbcTemplate;
3224
import org.springframework.jdbc.core.RowMapper;
25+
import org.springframework.util.Assert;
26+
27+
import java.sql.Connection;
28+
import java.sql.PreparedStatement;
29+
import java.sql.ResultSet;
30+
import java.sql.SQLException;
31+
import java.util.List;
3332

3433
/**
3534
* An implementation of {@link ChatMemory} for JDBC. Creating an instance of
3635
* JdbcChatMemory example:
3736
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
3837
*
3938
* @author Jonathan Leijendekker
39+
* @author Xavier Chopin
4040
* @since 1.0.0
4141
*/
4242
public class JdbcChatMemory implements ChatMemory {
@@ -45,14 +45,33 @@ public class JdbcChatMemory implements ChatMemory {
4545
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
4646

4747
private static final String QUERY_GET = """
48-
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";
48+
SELECT content, type \
49+
FROM ai_chat_memory \
50+
WHERE conversation_id = ? \
51+
ORDER BY "timestamp" DESC \
52+
LIMIT ?
53+
""";
54+
55+
private static final String MSSQL_QUERY_GET = """
56+
SELECT content, type \
57+
FROM ( \
58+
SELECT TOP (?) content, type, [timestamp] \
59+
FROM ai_chat_memory \
60+
WHERE conversation_id = ? \
61+
ORDER BY [timestamp] DESC \
62+
) AS recent \
63+
ORDER BY [timestamp] ASC \
64+
""";
4965

5066
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
5167

5268
private final JdbcTemplate jdbcTemplate;
5369

70+
private final DatabaseDriver driver;
71+
5472
public JdbcChatMemory(JdbcChatMemoryConfig config) {
5573
this.jdbcTemplate = config.getJdbcTemplate();
74+
this.driver = this.detectDialect(this.jdbcTemplate);
5675
}
5776

5877
public static JdbcChatMemory create(JdbcChatMemoryConfig config) {
@@ -66,16 +85,19 @@ public void add(String conversationId, List<Message> messages) {
6685

6786
@Override
6887
public List<Message> get(String conversationId, int lastN) {
69-
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
88+
return switch (driver) {
89+
case SQLSERVER -> this.jdbcTemplate.query(MSSQL_QUERY_GET, new MessageRowMapper(), lastN, conversationId);
90+
default -> this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
91+
};
7092
}
7193

7294
@Override
7395
public void clear(String conversationId) {
7496
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
7597
}
7698

77-
private record AddBatchPreparedStatement(String conversationId,
78-
List<Message> messages) implements BatchPreparedStatementSetter {
99+
private record AddBatchPreparedStatement(String conversationId, List<Message> messages)
100+
implements BatchPreparedStatementSetter {
79101
@Override
80102
public void setValues(PreparedStatement ps, int i) throws SQLException {
81103
var message = this.messages.get(i);
@@ -108,4 +130,15 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
108130

109131
}
110132

133+
private DatabaseDriver detectDialect(JdbcTemplate jdbcTemplate) {
134+
try {
135+
Assert.notNull(jdbcTemplate.getDataSource(), "jdbcTemplate.dataSource must not be null");
136+
try (Connection conn = jdbcTemplate.getDataSource().getConnection()) {
137+
String url = conn.getMetaData().getURL();
138+
return DatabaseDriver.fromJdbcUrl(url);
139+
}
140+
} catch (SQLException ex) {
141+
throw new IllegalStateException("Impossible to detect dialect", ex);
142+
}
143+
}
111144
}

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
3636

3737
hints.resources()
3838
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql")
39+
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mssql.sql")
3940
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
4041
}
4142

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CREATE TABLE ai_chat_memory (
2+
conversation_id VARCHAR(36) NOT NULL,
3+
content NVARCHAR(MAX) NOT NULL,
4+
type VARCHAR(10) NOT NULL,
5+
[timestamp] DATETIME2 NOT NULL DEFAULT SYSDATETIME(),
6+
CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
7+
);
8+
9+
CREATE INDEX ai_chat_memory_conversation_id_timestamp_idx ON ai_chat_memory(conversation_id, [timestamp] DESC);

0 commit comments

Comments
 (0)