16
16
17
17
package org .springframework .ai .chat .memory .jdbc ;
18
18
19
+ import org .springframework .ai .chat .memory .ChatMemory ;
20
+ import org .springframework .ai .chat .messages .*;
21
+ import org .springframework .boot .autoconfigure .condition .ConditionalOnMissingClass ;
22
+ import org .springframework .boot .jdbc .DatabaseDriver ;
23
+ import org .springframework .jdbc .core .BatchPreparedStatementSetter ;
24
+ import org .springframework .jdbc .core .JdbcTemplate ;
25
+ import org .springframework .jdbc .core .RowMapper ;
26
+ import org .springframework .stereotype .Component ;
27
+ import org .springframework .util .Assert ;
28
+
29
+ import java .sql .Connection ;
19
30
import java .sql .PreparedStatement ;
20
31
import java .sql .ResultSet ;
21
32
import java .sql .SQLException ;
22
33
import java .util .List ;
23
34
24
- import org .springframework .ai .chat .memory .ChatMemory ;
25
- import org .springframework .ai .chat .messages .AssistantMessage ;
26
- import org .springframework .ai .chat .messages .Message ;
27
- import org .springframework .ai .chat .messages .MessageType ;
28
- import org .springframework .ai .chat .messages .SystemMessage ;
29
- import org .springframework .ai .chat .messages .UserMessage ;
30
- import org .springframework .jdbc .core .BatchPreparedStatementSetter ;
31
- import org .springframework .jdbc .core .JdbcTemplate ;
32
- import org .springframework .jdbc .core .RowMapper ;
35
+ import static org .springframework .boot .jdbc .DatabaseDriver .SQLSERVER ;
33
36
34
37
/**
35
38
* An implementation of {@link ChatMemory} for JDBC. Creating an instance of
36
39
* JdbcChatMemory example:
37
40
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
38
41
*
39
42
* @author Jonathan Leijendekker
43
+ * @author Xavier Chopin
40
44
* @since 1.0.0
41
45
*/
42
46
public class JdbcChatMemory implements ChatMemory {
@@ -45,14 +49,33 @@ public class JdbcChatMemory implements ChatMemory {
45
49
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""" ;
46
50
47
51
private static final String QUERY_GET = """
48
- SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""" ;
52
+ SELECT content, type \
53
+ FROM ai_chat_memory \
54
+ WHERE conversation_id = ? \
55
+ ORDER BY "timestamp" DESC \
56
+ LIMIT ?
57
+ """ ;
58
+
59
+ private static final String MSSQL_QUERY_GET = """
60
+ SELECT content, type \
61
+ FROM ( \
62
+ SELECT TOP (?) content, type, [timestamp] \
63
+ FROM ai_chat_memory \
64
+ WHERE conversation_id = ? \
65
+ ORDER BY [timestamp] DESC \
66
+ ) AS recent \
67
+ ORDER BY [timestamp] ASC \
68
+ """ ;
49
69
50
70
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?" ;
51
71
52
72
private final JdbcTemplate jdbcTemplate ;
53
73
74
+ private final DatabaseDriver driver ;
75
+
54
76
public JdbcChatMemory (JdbcChatMemoryConfig config ) {
55
77
this .jdbcTemplate = config .getJdbcTemplate ();
78
+ this .driver = this .detectDialect (this .jdbcTemplate );
56
79
}
57
80
58
81
public static JdbcChatMemory create (JdbcChatMemoryConfig config ) {
@@ -66,16 +89,19 @@ public void add(String conversationId, List<Message> messages) {
66
89
67
90
@ Override
68
91
public List <Message > get (String conversationId , int lastN ) {
69
- return this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
92
+ return switch (driver ) {
93
+ case SQLSERVER -> this .jdbcTemplate .query (MSSQL_QUERY_GET , new MessageRowMapper (), lastN , conversationId );
94
+ default -> this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
95
+ };
70
96
}
71
97
72
98
@ Override
73
99
public void clear (String conversationId ) {
74
100
this .jdbcTemplate .update (QUERY_CLEAR , conversationId );
75
101
}
76
102
77
- private record AddBatchPreparedStatement (String conversationId ,
78
- List < Message > messages ) implements BatchPreparedStatementSetter {
103
+ private record AddBatchPreparedStatement (String conversationId , List < Message > messages )
104
+ implements BatchPreparedStatementSetter {
79
105
@ Override
80
106
public void setValues (PreparedStatement ps , int i ) throws SQLException {
81
107
var message = this .messages .get (i );
@@ -108,4 +134,15 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
108
134
109
135
}
110
136
137
+ private DatabaseDriver detectDialect (JdbcTemplate jdbcTemplate ) {
138
+ try {
139
+ Assert .notNull (jdbcTemplate .getDataSource (), "jdbcTemplate.dataSource must not be null" );
140
+ try (Connection conn = jdbcTemplate .getDataSource ().getConnection ()) {
141
+ String url = conn .getMetaData ().getURL ();
142
+ return DatabaseDriver .fromJdbcUrl (url );
143
+ }
144
+ } catch (SQLException ex ) {
145
+ throw new IllegalStateException ("Impossible to detect dialect" , ex );
146
+ }
147
+ }
111
148
}
0 commit comments