Skip to content

Commit 6ddaa46

Browse files
committed
fix: get query for MSSQL Server
Signed-off-by: Xavier Chopin <[email protected]>
1 parent c0bc623 commit 6ddaa46

File tree

8 files changed

+313
-16
lines changed

8 files changed

+313
-16
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml

+6
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@
6666
<artifactId>postgresql</artifactId>
6767
<scope>test</scope>
6868
</dependency>
69+
70+
<dependency>
71+
<groupId>org.testcontainers</groupId>
72+
<artifactId>mssqlserver</artifactId>
73+
<scope>test</scope>
74+
</dependency>
6975
</dependencies>
7076

7177
</project>

memory/spring-ai-model-chat-memory-jdbc/pom.xml

+12
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@
6969
<optional>true</optional>
7070
</dependency>
7171

72+
<dependency>
73+
<groupId>com.microsoft.sqlserver</groupId>
74+
<artifactId>mssql-jdbc</artifactId>
75+
<optional>true</optional>
76+
</dependency>
77+
7278
<!-- TESTING -->
7379
<dependency>
7480
<groupId>org.springframework.boot</groupId>
@@ -82,6 +88,12 @@
8288
<scope>test</scope>
8389
</dependency>
8490

91+
<dependency>
92+
<groupId>org.testcontainers</groupId>
93+
<artifactId>mssqlserver</artifactId>
94+
<scope>test</scope>
95+
</dependency>
96+
8597
<dependency>
8698
<groupId>org.testcontainers</groupId>
8799
<artifactId>postgresql</artifactId>

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java

+50-13
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,31 @@
1616

1717
package org.springframework.ai.chat.memory.jdbc;
1818

19+
import org.springframework.ai.chat.memory.ChatMemory;
20+
import org.springframework.ai.chat.messages.*;
21+
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
22+
import org.springframework.boot.jdbc.DatabaseDriver;
23+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
24+
import org.springframework.jdbc.core.JdbcTemplate;
25+
import org.springframework.jdbc.core.RowMapper;
26+
import org.springframework.stereotype.Component;
27+
import org.springframework.util.Assert;
28+
29+
import java.sql.Connection;
1930
import java.sql.PreparedStatement;
2031
import java.sql.ResultSet;
2132
import java.sql.SQLException;
2233
import java.util.List;
2334

24-
import org.springframework.ai.chat.memory.ChatMemory;
25-
import org.springframework.ai.chat.messages.AssistantMessage;
26-
import org.springframework.ai.chat.messages.Message;
27-
import org.springframework.ai.chat.messages.MessageType;
28-
import org.springframework.ai.chat.messages.SystemMessage;
29-
import org.springframework.ai.chat.messages.UserMessage;
30-
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
31-
import org.springframework.jdbc.core.JdbcTemplate;
32-
import org.springframework.jdbc.core.RowMapper;
35+
import static org.springframework.boot.jdbc.DatabaseDriver.SQLSERVER;
3336

3437
/**
3538
* An implementation of {@link ChatMemory} for JDBC. Creating an instance of
3639
* JdbcChatMemory example:
3740
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
3841
*
3942
* @author Jonathan Leijendekker
43+
* @author Xavier Chopin
4044
* @since 1.0.0
4145
*/
4246
public class JdbcChatMemory implements ChatMemory {
@@ -45,14 +49,33 @@ public class JdbcChatMemory implements ChatMemory {
4549
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
4650

4751
private static final String QUERY_GET = """
48-
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";
52+
SELECT content, type \
53+
FROM ai_chat_memory \
54+
WHERE conversation_id = ? \
55+
ORDER BY "timestamp" DESC \
56+
LIMIT ?
57+
""";
58+
59+
private static final String MSSQL_QUERY_GET = """
60+
SELECT content, type \
61+
FROM ( \
62+
SELECT TOP (?) content, type, [timestamp] \
63+
FROM ai_chat_memory \
64+
WHERE conversation_id = ? \
65+
ORDER BY [timestamp] DESC \
66+
) AS recent \
67+
ORDER BY [timestamp] ASC \
68+
""";
4969

5070
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
5171

5272
private final JdbcTemplate jdbcTemplate;
5373

74+
private final DatabaseDriver driver;
75+
5476
public JdbcChatMemory(JdbcChatMemoryConfig config) {
5577
this.jdbcTemplate = config.getJdbcTemplate();
78+
this.driver = this.detectDialect(this.jdbcTemplate);
5679
}
5780

5881
public static JdbcChatMemory create(JdbcChatMemoryConfig config) {
@@ -66,16 +89,19 @@ public void add(String conversationId, List<Message> messages) {
6689

6790
@Override
6891
public List<Message> get(String conversationId, int lastN) {
69-
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
92+
return switch (driver) {
93+
case SQLSERVER -> this.jdbcTemplate.query(MSSQL_QUERY_GET, new MessageRowMapper(), lastN, conversationId);
94+
default -> this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
95+
};
7096
}
7197

7298
@Override
7399
public void clear(String conversationId) {
74100
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
75101
}
76102

77-
private record AddBatchPreparedStatement(String conversationId,
78-
List<Message> messages) implements BatchPreparedStatementSetter {
103+
private record AddBatchPreparedStatement(String conversationId, List<Message> messages)
104+
implements BatchPreparedStatementSetter {
79105
@Override
80106
public void setValues(PreparedStatement ps, int i) throws SQLException {
81107
var message = this.messages.get(i);
@@ -108,4 +134,15 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
108134

109135
}
110136

137+
private DatabaseDriver detectDialect(JdbcTemplate jdbcTemplate) {
138+
try {
139+
Assert.notNull(jdbcTemplate.getDataSource(), "jdbcTemplate.dataSource must not be null");
140+
try (Connection conn = jdbcTemplate.getDataSource().getConnection()) {
141+
String url = conn.getMetaData().getURL();
142+
return DatabaseDriver.fromJdbcUrl(url);
143+
}
144+
} catch (SQLException ex) {
145+
throw new IllegalStateException("Impossible to detect dialect", ex);
146+
}
147+
}
111148
}

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616

1717
package org.springframework.ai.chat.memory.jdbc;
1818

19+
import org.springframework.boot.jdbc.DatabaseDriver;
1920
import org.springframework.jdbc.core.JdbcTemplate;
2021
import org.springframework.util.Assert;
2122

23+
import java.sql.Connection;
24+
import java.sql.SQLException;
25+
2226
/**
2327
* Configuration for {@link JdbcChatMemory}.
2428
*
@@ -50,7 +54,6 @@ private Builder() {
5054

5155
public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) {
5256
Assert.notNull(jdbcTemplate, "jdbc template must not be null");
53-
5457
this.jdbcTemplate = jdbcTemplate;
5558
return this;
5659
}
@@ -60,7 +63,6 @@ public JdbcChatMemoryConfig build() {
6063

6164
return new JdbcChatMemoryConfig(this);
6265
}
63-
6466
}
6567

6668
}

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
3636

3737
hints.resources()
3838
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql")
39+
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mssql.sql")
3940
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
4041
}
4142

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CREATE TABLE ai_chat_memory (
2+
conversation_id VARCHAR(36) NOT NULL,
3+
content NVARCHAR(MAX) NOT NULL,
4+
type VARCHAR(10) NOT NULL,
5+
[timestamp] DATETIME2 NOT NULL DEFAULT SYSDATETIME(),
6+
CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
7+
);
8+
9+
CREATE INDEX ai_chat_memory_conversation_id_timestamp_idx ON ai_chat_memory(conversation_id, [timestamp] DESC);
+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
* @author Jonathan Leijendekker
5454
*/
5555
@Testcontainers
56-
class JdbcChatMemoryIT {
56+
class JdbcChatMemoryPostgresSQLIT {
5757

5858
@Container
5959
@SuppressWarnings("resource")

0 commit comments

Comments
 (0)