Skip to content

Commit 372753e

Browse files
authored
Merge pull request #26 from welsir/master
会话联调
2 parents c987ce9 + 8847d1e commit 372753e

File tree

22 files changed

+1704
-612
lines changed

22 files changed

+1704
-612
lines changed

prompto-lab-app/pom.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
<description>提示词工程</description>
1616
<properties>
1717
<java.version>17</java.version>
18+
<maven.compiler.source>17</maven.compiler.source>
19+
<maven.compiler.target>17</maven.compiler.target>
1820
</properties>
1921
<dependencies>
2022
<dependency>

prompto-lab-app/src/main/java/io/github/timemachinelab/config/CorsConfig.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ public void addCorsMappings(CorsRegistry registry) {
2727
public CorsConfigurationSource corsConfigurationSource() {
2828
CorsConfiguration configuration = new CorsConfiguration();
2929

30-
// 允许的源
31-
configuration.setAllowedOriginPatterns(Arrays.asList("http://localhost:*", "http://127.0.0.1:*"));
30+
// 允许的源 - 使用具体的域名模式而不是通配符
31+
configuration.setAllowedOriginPatterns(Arrays.asList(
32+
"http://localhost:*",
33+
"http://127.0.0.1:*",
34+
"https://localhost:*",
35+
"https://127.0.0.1:*"
36+
));
3237

3338
// 允许的HTTP方法
3439
configuration.setAllowedMethods(Arrays.asList("GET", "POST", "PUT", "DELETE", "OPTIONS"));

prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
11
package io.github.timemachinelab.controller;
22

3+
import io.github.timemachinelab.core.session.application.ConversationService;
4+
import io.github.timemachinelab.core.session.application.MessageProcessingService;
5+
import io.github.timemachinelab.core.session.application.SessionManagementService;
6+
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
7+
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
8+
import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest;
9+
import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
310
import io.github.timemachinelab.entity.req.RetryRequest;
411
import io.github.timemachinelab.entity.resp.ApiResult;
512
import io.github.timemachinelab.entity.resp.RetryResponse;
613
import lombok.extern.slf4j.Slf4j;
14+
import org.springframework.http.MediaType;
715
import org.springframework.http.ResponseEntity;
816
import org.springframework.validation.annotation.Validated;
917
import org.springframework.web.bind.annotation.*;
18+
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
1019

20+
import javax.annotation.Resource;
1121
import javax.validation.Valid;
22+
import java.io.IOException;
23+
import java.util.Map;
24+
import java.util.UUID;
25+
import java.util.concurrent.ConcurrentHashMap;
1226

1327
/**
1428
* 用户交互控制器
@@ -23,6 +37,55 @@
2337
@Validated
2438
public class UserInteractionController {
2539

40+
@Resource
41+
private ConversationService conversationService;
42+
@Resource
43+
private MessageProcessingService messageProcessingService;
44+
@Resource
45+
private SessionManagementService sessionManagementService;
46+
private final Map<String, SseEmitter> sseEmitters = new ConcurrentHashMap<>();
47+
48+
/**
49+
* 建立SSE连接
50+
*/
51+
@GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
52+
public SseEmitter streamConversation(@RequestParam(required = false) String sessionId) {
53+
log.info("建立SSE连接 - 会话ID: {}", sessionId);
54+
55+
if(sessionId == null || sessionId.isEmpty()) {
56+
sessionId = UUID.randomUUID().toString();
57+
}
58+
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
59+
sseEmitters.put(sessionId, emitter);
60+
61+
// 连接建立时发送欢迎消息
62+
try {
63+
emitter.send(SseEmitter.event()
64+
.name("connected")
65+
.data("SSE连接已建立,会话ID: " + sessionId));
66+
} catch (IOException e) {
67+
log.error("发送欢迎消息失败: {}", e.getMessage());
68+
}
69+
70+
// 设置连接事件处理
71+
String finalSessionId = sessionId;
72+
emitter.onCompletion(() -> {
73+
log.info("SSE连接完成: {}", finalSessionId);
74+
});
75+
76+
emitter.onTimeout(() -> {
77+
log.info("SSE连接超时: {}", finalSessionId);
78+
sseEmitters.remove(finalSessionId);
79+
});
80+
81+
emitter.onError((ex) -> {
82+
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
83+
sseEmitters.remove(finalSessionId);
84+
});
85+
86+
return emitter;
87+
}
88+
2689
/**
2790
* 重试接口
2891
*
@@ -34,7 +97,9 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
3497
try {
3598
log.info("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}",
3699
request.getNodeId(), request.getSessionId(), request.getWhyretry());
37-
100+
101+
102+
38103
// 构建响应数据
39104
RetryResponse response = RetryResponse.builder()
40105
.nodeId(request.getNodeId())
@@ -53,4 +118,90 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
53118
return ResponseEntity.badRequest().body(ApiResult.serverError("重试请求处理失败: " + e.getMessage()));
54119
}
55120
}
121+
122+
/**
123+
* 处理统一答案请求
124+
* 支持单选、多选、输入框、表单等多种问题类型的回答
125+
*/
126+
@PostMapping("/message")
127+
public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswerRequest request) {
128+
try {
129+
log.info("接收到答案请求 - 会话ID: {}, 节点ID: {}, 问题类型: {}",
130+
request.getSessionId(),
131+
request.getNodeId(),
132+
request.getQuestionType());
133+
134+
// 1. 会话管理和验证
135+
String userId = request.getUserId();
136+
137+
ConversationSession session = sessionManagementService.getOrCreateSession(userId, request.getSessionId());
138+
139+
// 2. 验证nodeId是否属于该会话
140+
if (request.getNodeId() != null && !sessionManagementService.validateNodeId(session.getSessionId(), request.getNodeId())) {
141+
log.warn("无效的节点ID - 会话: {}, 节点: {}", session.getSessionId(), request.getNodeId());
142+
return ResponseEntity.badRequest().body("无效的节点ID");
143+
}
144+
145+
// 3. 验证答案格式
146+
if (!messageProcessingService.validateAnswer(request)) {
147+
log.warn("答案格式验证失败: {}", request);
148+
return ResponseEntity.badRequest().body("答案格式不正确");
149+
}
150+
151+
// 4. 处理答案并转换为消息
152+
String processedMessage = messageProcessingService.preprocessMessage(
153+
null, // 没有额外的原始消息
154+
request,
155+
session
156+
);
157+
158+
// 5. 发送处理后的消息给AI服务
159+
conversationService.processUserMessage(
160+
session.getUserId(),
161+
processedMessage,
162+
response -> sendSseMessage(session.getSessionId(), response)
163+
);
164+
165+
return ResponseEntity.ok("答案处理成功");
166+
167+
} catch (Exception e) {
168+
log.error("处理答案失败 - 会话ID: {}, 错误: {}", request.getSessionId(), e.getMessage(), e);
169+
return ResponseEntity.internalServerError().body("答案处理失败: " + e.getMessage());
170+
}
171+
}
172+
173+
/**
174+
* 通过SSE发送消息给客户端
175+
*
176+
* @param sessionId 会话ID
177+
* @param response 消息响应对象
178+
*/
179+
private void sendSseMessage(String sessionId, QuestionGenerationOperation.QuestionGenerationResponse response) {
180+
SseEmitter emitter = sseEmitters.get(sessionId);
181+
if (emitter != null) {
182+
try {
183+
emitter.send(SseEmitter.event()
184+
.name("message")
185+
.data(response));
186+
log.info("SSE消息发送成功 - 会话: {}, 消息: {}", sessionId, response);
187+
} catch (IOException e) {
188+
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
189+
sseEmitters.remove(sessionId);
190+
}
191+
} else {
192+
log.warn("SSE连接不存在 - 会话: {}", sessionId);
193+
}
194+
}
195+
196+
/**
197+
* 获取SSE连接状态
198+
*/
199+
@GetMapping("/sse-status")
200+
public ResponseEntity<Map<String, Object>> getSseStatus() {
201+
Map<String, Object> status = new ConcurrentHashMap<>();
202+
status.put("connectedSessions", sseEmitters.keySet());
203+
status.put("totalConnections", sseEmitters.size());
204+
status.put("timestamp", System.currentTimeMillis());
205+
return ResponseEntity.ok(status);
206+
}
56207
}

prompto-lab-app/src/main/java/io/github/timemachinelab/core/question/QuestionParser.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.alibaba.fastjson2.JSON;
44
import com.alibaba.fastjson2.JSONException;
55
import com.alibaba.fastjson2.JSONObject;
6+
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
67
import lombok.extern.slf4j.Slf4j;
78

89
import java.util.Arrays;
@@ -37,7 +38,7 @@ public class QuestionParser {
3738
* @return BaseQuestion对象
3839
* @throws QuestionParseException 解析失败时抛出异常
3940
*/
40-
public static BaseQuestion parseQuestion(String jsonStr) throws QuestionParseException {
41+
public static QuestionGenerationOperation.QuestionGenerationResponse parseQuestion(String jsonStr) throws QuestionParseException {
4142
if (jsonStr == null || jsonStr.trim().isEmpty()) {
4243
throw new QuestionParseException("JSON字符串不能为空", jsonStr, "输入为空或null");
4344
}
@@ -52,7 +53,7 @@ public static BaseQuestion parseQuestion(String jsonStr) throws QuestionParseExc
5253

5354
// 收集所有解析失败的原因
5455
List<String> failureReasons = new ArrayList<>();
55-
56+
String parentId = jsonObject.getString("parentId");
5657
// 依次尝试解析成不同类型
5758
for (Class<? extends BaseQuestion> questionType : QUESTION_TYPES) {
5859
try {
@@ -61,7 +62,7 @@ public static BaseQuestion parseQuestion(String jsonStr) throws QuestionParseExc
6162
String validationResult = validateQuestion(question, jsonObject);
6263
if (validationResult == null) {
6364
log.info("成功解析为: {}", questionType.getSimpleName());
64-
return question;
65+
return new QuestionGenerationOperation.QuestionGenerationResponse(question,parentId);
6566
} else {
6667
failureReasons.add(questionType.getSimpleName() + ": " + validationResult);
6768
}

prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java

Lines changed: 26 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
package io.github.timemachinelab.core.session.application;
22

3+
import com.alibaba.fastjson.JSON;
4+
import com.alibaba.fastjson.JSONObject;
5+
import io.github.timemachinelab.core.qatree.QaTree;
6+
import io.github.timemachinelab.core.qatree.QaTreeDomain;
7+
import io.github.timemachinelab.core.qatree.QaTreeNode;
8+
import io.github.timemachinelab.core.question.BaseQuestion;
39
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
10+
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
411
import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
512
import io.github.timemachinelab.core.session.infrastructure.ai.ConversationOperation;
613
import io.github.timemachinelab.sfchain.core.AIService;
714
import lombok.RequiredArgsConstructor;
815
import lombok.extern.slf4j.Slf4j;
916
import org.springframework.stereotype.Service;
1017

18+
import javax.annotation.Resource;
1119
import java.util.concurrent.ConcurrentHashMap;
1220
import java.util.Map;
1321
import java.util.List;
@@ -18,93 +26,43 @@
1826
@RequiredArgsConstructor
1927
@Slf4j
2028
public class ConversationService {
21-
29+
30+
@Resource
2231
private final AIService aiService;
32+
@Resource
33+
private SessionManagementService sessionManagementService;
34+
private final QaTreeDomain qaTreeDomain;
2335

24-
private final Map<String, ConversationSession> sessions = new ConcurrentHashMap<>();
2536

26-
public ConversationSession createSession(String userId) {
27-
ConversationSession session = new ConversationSession(userId);
28-
sessions.put(session.getSessionId(), session);
29-
return session;
30-
}
31-
32-
public ConversationSession getSession(String sessionId) {
33-
return sessions.get(sessionId);
34-
}
35-
36-
public void processUserMessage(String sessionId, String userMessage, Consumer<MessageResponse> sseCallback) {
37-
ConversationSession session = sessions.get(sessionId);
37+
public void processUserMessage(String userId, String userMessage, Consumer<QuestionGenerationOperation.QuestionGenerationResponse> sseCallback) {
38+
ConversationSession session = sessionManagementService.getUserCurrentSession(userId);
3839
if (session == null) {
39-
log.warn("会话不存在: {}", sessionId);
40+
log.warn("会话不存在");
4041
return;
4142
}
42-
43-
// 1. 添加用户消息到会话历史
44-
session.addMessage(userMessage, "user");
45-
46-
// 2. 发送用户消息确认
47-
sseCallback.accept(MessageResponse.userAnswer("user_" + System.currentTimeMillis(), userMessage));
48-
49-
// 3. 调用AI服务获取回复
50-
processAIResponse(session, userMessage, sseCallback);
43+
44+
processAIResponse(userMessage, sseCallback);
5145
}
5246

53-
private void processAIResponse(ConversationSession session, String userMessage, Consumer<MessageResponse> sseCallback) {
47+
private void processAIResponse(String userMessage, Consumer<QuestionGenerationOperation.QuestionGenerationResponse> sseCallback) {
5448
try {
55-
// 构建对话历史
56-
List<ConversationOperation.ConversationHistory> history = buildConversationHistory(session);
57-
49+
50+
JSONObject object = JSON.parseObject(userMessage);
51+
5852
// 创建AI请求
59-
ConversationOperation.ConversationRequest request = new ConversationOperation.ConversationRequest(
60-
session.getSessionId(),
61-
"current",
62-
userMessage
63-
);
64-
request.setConversationHistory(history);
65-
53+
QuestionGenerationOperation.QuestionGenerationRequest request = new QuestionGenerationOperation.QuestionGenerationRequest(object.getString("prompt"),object.getString("tree"),object.getString("input"));
6654
// 调用AI服务
67-
ConversationOperation.ConversationResponse aiResponse = aiService.execute("CONVERSATION_OP", request);
68-
log.info("AI服务调用成功: {}", aiResponse);
55+
QuestionGenerationOperation.QuestionGenerationResponse aiResponse = aiService.execute("QUESTION_GENERATION_OP", request);
6956

70-
// 添加AI回复到会话历史
71-
session.addMessage(aiResponse.getAnswer(), "assistant");
72-
73-
// 根据响应类型处理AI回复
74-
String nodeId = "ai_" + System.currentTimeMillis();
75-
sseCallback.accept(MessageResponse.aiAnswer("ai_" + System.currentTimeMillis(), aiResponse.getAnswer()));
76-
// if (aiResponse.getResponseType() == ConversationOperation.ResponseType.SELECTION) {
77-
// // 选择题类型
78-
// sseCallback.accept(MessageResponse.aiSelectionQuestion(nodeId, aiResponse.getAnswer(), aiResponse.getOptions()));
79-
// } else {
80-
// // 普通文本回复
81-
// sseCallback.accept(MessageResponse.aiQuestion(nodeId, aiResponse.getAnswer()));
82-
// }
57+
sseCallback.accept(aiResponse);
58+
log.info("AI服务调用成功: {}", aiResponse);
8359

8460
} catch (Exception e) {
8561
log.error("AI服务调用失败: {}", e.getMessage(), e);
8662
// 降级处理
8763
String fallbackResponse = "抱歉,我暂时无法处理您的请求,请稍后再试。";
88-
session.addMessage(fallbackResponse, "assistant");
8964
String nodeId = "ai_" + System.currentTimeMillis();
90-
sseCallback.accept(MessageResponse.aiQuestion(nodeId, fallbackResponse));
91-
}
92-
}
93-
94-
private List<ConversationOperation.ConversationHistory> buildConversationHistory(ConversationSession session) {
95-
List<ConversationOperation.ConversationHistory> history = new ArrayList<>();
96-
97-
// 从会话消息构建对话历史
98-
for (ConversationSession.ConversationMessage message : session.getMessages()) {
99-
ConversationOperation.ConversationHistory historyItem = new ConversationOperation.ConversationHistory(
100-
message.getRole(),
101-
message.getContent(),
102-
message.getRole() + "_" + message.getTimestamp().toString()
103-
);
104-
history.add(historyItem);
10565
}
106-
107-
return history;
10866
}
10967

11068

0 commit comments

Comments
 (0)