16
16
17
17
package org .springframework .ai .chat .memory .jdbc ;
18
18
19
- import java .sql .PreparedStatement ;
20
- import java .sql .ResultSet ;
21
- import java .sql .SQLException ;
22
- import java .util .List ;
23
-
24
19
import org .springframework .ai .chat .memory .ChatMemory ;
25
- import org .springframework .ai .chat .messages .AssistantMessage ;
26
- import org .springframework .ai .chat .messages .Message ;
27
- import org .springframework .ai .chat .messages .MessageType ;
28
- import org .springframework .ai .chat .messages .SystemMessage ;
29
- import org .springframework .ai .chat .messages .UserMessage ;
20
+ import org .springframework .ai .chat .messages .*;
21
+ import org .springframework .boot .jdbc .DatabaseDriver ;
30
22
import org .springframework .jdbc .core .BatchPreparedStatementSetter ;
31
23
import org .springframework .jdbc .core .JdbcTemplate ;
32
24
import org .springframework .jdbc .core .RowMapper ;
25
+ import org .springframework .util .Assert ;
26
+
27
+ import java .sql .Connection ;
28
+ import java .sql .PreparedStatement ;
29
+ import java .sql .ResultSet ;
30
+ import java .sql .SQLException ;
31
+ import java .util .List ;
33
32
34
33
/**
35
34
* An implementation of {@link ChatMemory} for JDBC. Creating an instance of
36
35
* JdbcChatMemory example:
37
36
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
38
37
*
39
38
* @author Jonathan Leijendekker
39
+ * @author Xavier Chopin
40
40
* @since 1.0.0
41
41
*/
42
42
public class JdbcChatMemory implements ChatMemory {
@@ -45,14 +45,33 @@ public class JdbcChatMemory implements ChatMemory {
45
45
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""" ;
46
46
47
47
private static final String QUERY_GET = """
48
- SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""" ;
48
+ SELECT content, type \
49
+ FROM ai_chat_memory \
50
+ WHERE conversation_id = ? \
51
+ ORDER BY "timestamp" DESC \
52
+ LIMIT ?
53
+ """ ;
54
+
55
+ private static final String MSSQL_QUERY_GET = """
56
+ SELECT content, type \
57
+ FROM ( \
58
+ SELECT TOP (?) content, type, [timestamp] \
59
+ FROM ai_chat_memory \
60
+ WHERE conversation_id = ? \
61
+ ORDER BY [timestamp] DESC \
62
+ ) AS recent \
63
+ ORDER BY [timestamp] ASC \
64
+ """ ;
49
65
50
66
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?" ;
51
67
52
68
private final JdbcTemplate jdbcTemplate ;
53
69
70
+ private final DatabaseDriver driver ;
71
+
54
72
public JdbcChatMemory (JdbcChatMemoryConfig config ) {
55
73
this .jdbcTemplate = config .getJdbcTemplate ();
74
+ this .driver = this .detectDialect (this .jdbcTemplate );
56
75
}
57
76
58
77
public static JdbcChatMemory create (JdbcChatMemoryConfig config ) {
@@ -66,16 +85,19 @@ public void add(String conversationId, List<Message> messages) {
66
85
67
86
@ Override
68
87
public List <Message > get (String conversationId , int lastN ) {
69
- return this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
88
+ return switch (driver ) {
89
+ case SQLSERVER -> this .jdbcTemplate .query (MSSQL_QUERY_GET , new MessageRowMapper (), lastN , conversationId );
90
+ default -> this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
91
+ };
70
92
}
71
93
72
94
@ Override
73
95
public void clear (String conversationId ) {
74
96
this .jdbcTemplate .update (QUERY_CLEAR , conversationId );
75
97
}
76
98
77
- private record AddBatchPreparedStatement (String conversationId ,
78
- List < Message > messages ) implements BatchPreparedStatementSetter {
99
+ private record AddBatchPreparedStatement (String conversationId , List < Message > messages )
100
+ implements BatchPreparedStatementSetter {
79
101
@ Override
80
102
public void setValues (PreparedStatement ps , int i ) throws SQLException {
81
103
var message = this .messages .get (i );
@@ -108,4 +130,15 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
108
130
109
131
}
110
132
133
+ private DatabaseDriver detectDialect (JdbcTemplate jdbcTemplate ) {
134
+ try {
135
+ Assert .notNull (jdbcTemplate .getDataSource (), "jdbcTemplate.dataSource must not be null" );
136
+ try (Connection conn = jdbcTemplate .getDataSource ().getConnection ()) {
137
+ String url = conn .getMetaData ().getURL ();
138
+ return DatabaseDriver .fromJdbcUrl (url );
139
+ }
140
+ } catch (SQLException ex ) {
141
+ throw new IllegalStateException ("Impossible to detect dialect" , ex );
142
+ }
143
+ }
111
144
}
0 commit comments