Skip to content

Commit 63b20f1

Browse files
authored
Merge pull request #50 from welsir/master
前端修改
2 parents 66ae912 + d741071 commit 63b20f1

File tree

7 files changed

+1242
-151
lines changed

7 files changed

+1242
-151
lines changed

prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import io.github.timemachinelab.core.session.application.RetryProcessingService;
1414
import io.github.timemachinelab.core.qatree.QaTreeDomain;
1515
import io.github.timemachinelab.core.qatree.QaTree;
16+
import io.github.timemachinelab.core.qatree.QaTreeNode;
17+
import io.github.timemachinelab.core.question.InputQuestion;
1618

1719
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
1820
import io.github.timemachinelab.core.session.infrastructure.web.dto.*;
@@ -21,6 +23,7 @@
2123
import io.github.timemachinelab.entity.resp.ApiResult;
2224
import io.github.timemachinelab.entity.resp.RetryResponse;
2325
import lombok.extern.slf4j.Slf4j;
26+
import org.apache.commons.lang.StringUtils;
2427
import org.springframework.http.MediaType;
2528
import org.springframework.http.ResponseEntity;
2629
import org.springframework.validation.annotation.Validated;
@@ -32,6 +35,7 @@
3235
import javax.validation.Valid;
3336
import java.io.IOException;
3437
import java.time.LocalDateTime;
38+
import java.util.HashMap;
3539
import java.util.List;
3640
import java.util.Map;
3741
import java.util.Objects;
@@ -253,6 +257,17 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
253257
// 2. 检查sessionId和answer的逻辑
254258
String sessionId = request.getSessionId();
255259
Object answer = request.getAnswer();
260+
261+
// 将Object类型的answer转换为String类型
262+
String answerStr = "";
263+
if (answer != null) {
264+
if (answer instanceof String) {
265+
answerStr = (String) answer;
266+
} else {
267+
// 对于其他类型,转换为JSON字符串
268+
answerStr = JSONObject.toJSONString(answer);
269+
}
270+
}
256271

257272
if (sessionId == null || sessionId.trim().isEmpty()) {
258273
// 如果没有sessionId,必须检查answer是否为空
@@ -269,11 +284,6 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
269284
if (sessionId == null || sessionId.trim().isEmpty()) {
270285
// 新建会话
271286
session = sessionManagementService.createNewSession(fingerprint);
272-
if (session == null) {
273-
log.error("会话创建失败 - 指纹: {}", fingerprint);
274-
sseNotificationService.sendErrorMessage(fingerprint, "会话创建失败,请重试"); // 保持原样,因为错误消息的发送方式未改变
275-
return ResponseEntity.internalServerError().body("会话处理失败");
276-
}
277287
} else {
278288
// 3. 如果存在sessionId,获取conversation的currentNodeId,表示当前node节点需要过滤
279289
session = sessionManagementService.validateAndGetSession(fingerprint, sessionId);
@@ -283,15 +293,37 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
283293
return ResponseEntity.badRequest().body("会话不存在或已失效");
284294
}
285295

286-
// 获取当前节点ID并过滤qaTree
296+
// 获取当前节点ID
287297
String currentNodeId = session.getCurrentNode();
288298
QaTree originalQaTree = session.getQaTree();
299+
300+
// answerStr已在方法开始处定义,这里直接使用
301+
302+
// 检查是否是在回答当前问题
303+
if (!StringUtils.isBlank(answerStr)) {
304+
// 用户提供了答案,说明是在回答当前问题
305+
// 先将答案插入到qaTree中
306+
try {
307+
boolean updateSuccess = qaTreeDomain.updateNodeAnswer(originalQaTree, currentNodeId, answerStr);
308+
if (updateSuccess) {
309+
log.info("已将用户答案插入qaTree - 会话: {}, 节点: {}, 答案: {}", sessionId, currentNodeId, answerStr);
310+
} else {
311+
log.warn("更新节点答案失败,节点可能不存在 - 会话: {}, 节点: {}", sessionId, currentNodeId);
312+
sseNotificationService.sendErrorMessage(fingerprint, "当前问题节点不存在,请刷新页面重试");
313+
return ResponseEntity.badRequest().body("当前问题节点不存在");
314+
}
315+
} catch (Exception e) {
316+
log.error("插入用户答案失败 - 会话: {}, 节点: {}", sessionId, currentNodeId, e);
317+
sseNotificationService.sendErrorMessage(fingerprint, "处理用户答案失败,请重试");
318+
return ResponseEntity.internalServerError().body("处理用户答案失败");
319+
}
320+
}
321+
289322
// 4. 在qaTreeDomain里过滤qaNode(如果answer不存在则过滤),返回整个qaTree
290323
filteredQaTree = qaTreeDomain.filterQaTreeByAnswer(originalQaTree, currentNodeId);
291324
log.info("已过滤qaTree - 会话: {}, 过滤节点: {}", sessionId, currentNodeId);
292325
}
293326

294-
// 5. 走现在有的逻辑(从创建会话开始) - 调用AI服务生成提示词
295327
// 如果有过滤后的qaTree,临时替换session中的qaTree
296328
QaTree originalQaTree = null;
297329
if (filteredQaTree != null) {
@@ -300,16 +332,7 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
300332
}
301333

302334
final QaTree finalOriginalQaTree = originalQaTree;
303-
// 将Object类型的answer转换为String类型
304-
String answerStr = "";
305-
if (request.getAnswer() != null) {
306-
if (request.getAnswer() instanceof String) {
307-
answerStr = (String) request.getAnswer();
308-
} else {
309-
// 对于其他类型,转换为JSON字符串
310-
answerStr = JSONObject.toJSONString(request.getAnswer());
311-
}
312-
}
335+
// answerStr已在上面定义,这里不需要重复定义
313336

314337
conversationService.genPrompt(session.getSessionId(), answerStr, response -> {
315338
try {
@@ -318,15 +341,34 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
318341
session.setQaTree(finalOriginalQaTree);
319342
}
320343

321-
// 更新currentNode - 在AI回答后创建新节点
322-
String newNodeId = session.getNextNodeId();
323-
session.setCurrentNode(newNodeId);
344+
String parentId = session.getCurrentNode();
345+
346+
// 创建一个文本类型的问题,内容是生成的提示词
347+
InputQuestion promptQuestion = new InputQuestion();
348+
promptQuestion.setQuestion(response.getGenPrompt());
349+
promptQuestion.setAnswer(""); // 初始无答案,等待用户回答
350+
351+
// 添加到qaTree
352+
QaTreeNode promptNode = qaTreeDomain.appendNode(session.getQaTree(), parentId, promptQuestion, session);
353+
354+
// 更新currentNode为新创建的提示词节点
355+
session.setCurrentNode(promptNode.getId());
324356
session.setUpdateTime(LocalDateTime.now());
325357

326-
// 发送AI生成的提示词
327-
sseNotificationService.sendSuccessMessage(fingerprint, response.getGenPrompt()); // 保持原样,因为成功消息的发送方式未改变
328-
329-
log.info("genPrompt处理完成 - 会话: {}, 新节点: {}", session.getSessionId(), newNodeId);
358+
// 发送question格式的SSE消息,就像普通的AI回答一样
359+
Map<String, Object> questionResponse = new HashMap<>();
360+
questionResponse.put("question", Map.of(
361+
"type", "input",
362+
"question", response.getGenPrompt(),
363+
"desc", "这是为您生成的提示词,您可以基于此内容继续对话"
364+
));
365+
questionResponse.put("sessionId", session.getSessionId());
366+
questionResponse.put("currentNodeId", promptNode.getId());
367+
questionResponse.put("parentNodeId", session.getCurrentNode());
368+
369+
sseNotificationService.sendSuccessMessage(fingerprint, JSONObject.toJSONString(questionResponse));
370+
371+
log.info("genPrompt处理完成 - 会话: {}, 新节点: {}", session.getSessionId(), promptNode.getId());
330372
} catch (Exception e) {
331373
// 恢复原始qaTree(异常情况下)
332374
if (finalOriginalQaTree != null) {

prompto-lab-app/src/main/java/io/github/timemachinelab/util/QaTreeSerializeUtil.java

Lines changed: 190 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
import io.github.timemachinelab.core.qatree.QaTree;
66
import io.github.timemachinelab.core.qatree.QaTreeNode;
77
import io.github.timemachinelab.core.serializable.JsonNode;
8+
import io.github.timemachinelab.core.question.*;
9+
import io.github.timemachinelab.core.serializable.TempFormQuestion;
810

911
import java.util.ArrayList;
1012
import java.util.List;
13+
import java.util.Map;
14+
import java.util.HashMap;
1115

1216
public class QaTreeSerializeUtil {
1317

@@ -16,33 +20,210 @@ public static String serialize(QaTree t) throws JsonProcessingException {
1620
return "[]";
1721
}
1822

19-
List<JsonNode> result = new ArrayList<>();
23+
List<Map<String, Object>> result = new ArrayList<>();
2024

21-
firstOrderTraversal(t.getRoot(), null, result);
25+
firstOrderTraversalEnhanced(t.getRoot(), null, result);
2226

2327
return JSONObject.toJSONString(result);
2428
}
2529

26-
private static void firstOrderTraversal(QaTreeNode node, String parentId, List<JsonNode> result) throws JsonProcessingException {
30+
/**
31+
* 增强版遍历方法,返回SSE兼容的格式
32+
*/
33+
private static void firstOrderTraversalEnhanced(QaTreeNode node, String parentId, List<Map<String, Object>> result) throws JsonProcessingException {
2734
if (node == null) {
2835
return;
2936
}
3037

3138
// 获取子节点列表
3239
List<QaTreeNode> children = new ArrayList<>();
33-
3440
if (node.getChildren() != null) {
3541
children.addAll(node.getChildren().values());
3642
}
3743

38-
// 访问当前节点
39-
JsonNode jsonNode = JsonNode.Convert2JsonNode(node, parentId);
40-
41-
result.add(jsonNode);
44+
// 创建增强的节点数据
45+
Map<String, Object> enhancedNode = createEnhancedNode(node, parentId);
46+
result.add(enhancedNode);
4247

4348
// 先序遍历
4449
for (QaTreeNode child : children) {
45-
firstOrderTraversal(child, node.getId(), result);
50+
firstOrderTraversalEnhanced(child, node.getId(), result);
4651
}
4752
}
53+
54+
/**
55+
* 创建SSE兼容的增强节点数据
56+
*/
57+
private static Map<String, Object> createEnhancedNode(QaTreeNode node, String parentId) {
58+
Map<String, Object> enhancedNode = new HashMap<>();
59+
enhancedNode.put("nodeId", node.getId());
60+
enhancedNode.put("parentId", parentId);
61+
62+
String answer = "";
63+
Map<String, Object> questionData = null;
64+
65+
BaseQuestion qa = node.getQa();
66+
if (qa != null) {
67+
// 根据问题类型创建questionData
68+
QuestionType type = QuestionType.fromString(qa.getType());
69+
switch (type) {
70+
case INPUT:
71+
InputQuestion inputQA = (InputQuestion) qa;
72+
questionData = createInputQuestionData(inputQA);
73+
answer = inputQA.getAnswer() != null ? inputQA.getAnswer() : "";
74+
break;
75+
case SINGLE:
76+
SingleChoiceQuestion singleQA = (SingleChoiceQuestion) qa;
77+
questionData = createSingleQuestionData(singleQA);
78+
answer = formatSingleAnswer(singleQA);
79+
break;
80+
case MULTI:
81+
MultipleChoiceQuestion multiQA = (MultipleChoiceQuestion) qa;
82+
questionData = createMultiQuestionData(multiQA);
83+
answer = formatMultiAnswer(multiQA);
84+
break;
85+
case FORM:
86+
FormQuestion formQA = (FormQuestion) qa;
87+
questionData = createFormQuestionData(formQA);
88+
answer = formQA.getAnswer() != null ? JSONObject.toJSONString(formQA.getAnswer()) : "";
89+
break;
90+
default:
91+
// 普通文本问题
92+
questionData = createTextQuestionData(qa.getQuestion());
93+
break;
94+
}
95+
}
96+
97+
enhancedNode.put("questionData", questionData);
98+
enhancedNode.put("answer", answer);
99+
100+
return enhancedNode;
101+
}
102+
103+
/**
104+
* 创建输入问题数据
105+
*/
106+
private static Map<String, Object> createInputQuestionData(InputQuestion inputQA) {
107+
Map<String, Object> questionData = new HashMap<>();
108+
questionData.put("type", "input");
109+
questionData.put("question", inputQA.getQuestion() != null ? inputQA.getQuestion() : "");
110+
questionData.put("desc", inputQA.getDesc() != null ? inputQA.getDesc() : "");
111+
return questionData;
112+
}
113+
114+
/**
115+
* 创建单选问题数据
116+
*/
117+
private static Map<String, Object> createSingleQuestionData(SingleChoiceQuestion singleQA) {
118+
Map<String, Object> questionData = new HashMap<>();
119+
questionData.put("type", "single");
120+
questionData.put("question", singleQA.getQuestion() != null ? singleQA.getQuestion() : "");
121+
questionData.put("desc", singleQA.getDesc() != null ? singleQA.getDesc() : "");
122+
questionData.put("options", singleQA.getOptions() != null ? singleQA.getOptions() : new ArrayList<>());
123+
return questionData;
124+
}
125+
126+
/**
127+
* 创建多选问题数据
128+
*/
129+
private static Map<String, Object> createMultiQuestionData(MultipleChoiceQuestion multiQA) {
130+
Map<String, Object> questionData = new HashMap<>();
131+
questionData.put("type", "multi");
132+
questionData.put("question", multiQA.getQuestion() != null ? multiQA.getQuestion() : "");
133+
questionData.put("desc", multiQA.getDesc() != null ? multiQA.getDesc() : "");
134+
questionData.put("options", multiQA.getOptions() != null ? multiQA.getOptions() : new ArrayList<>());
135+
return questionData;
136+
}
137+
138+
/**
139+
* 创建表单问题数据
140+
*/
141+
private static Map<String, Object> createFormQuestionData(FormQuestion formQA) {
142+
Map<String, Object> questionData = new HashMap<>();
143+
questionData.put("type", "form");
144+
questionData.put("question", formQA.getQuestion() != null ? formQA.getQuestion() : "");
145+
questionData.put("desc", formQA.getDesc() != null ? formQA.getDesc() : "");
146+
questionData.put("fields", formQA.getFields() != null ? formQA.getFields() : new ArrayList<>());
147+
return questionData;
148+
}
149+
150+
/**
151+
* 创建普通文本问题数据
152+
*/
153+
private static Map<String, Object> createTextQuestionData(String question) {
154+
Map<String, Object> questionData = new HashMap<>();
155+
questionData.put("type", "text");
156+
questionData.put("question", question != null ? question : "");
157+
questionData.put("desc", "");
158+
return questionData;
159+
}
160+
161+
/**
162+
* 格式化单选答案
163+
*/
164+
private static String formatSingleAnswer(SingleChoiceQuestion singleQA) {
165+
if (singleQA.getAnswer() != null && !singleQA.getAnswer().isEmpty()) {
166+
List<String> answerLabels = new ArrayList<>();
167+
for (String answerId : singleQA.getAnswer()) {
168+
String label = findOptionLabel(singleQA.getOptions(), answerId);
169+
answerLabels.add(label != null ? label : answerId);
170+
}
171+
return String.join(",", answerLabels);
172+
}
173+
return "";
174+
}
175+
176+
/**
177+
* 格式化多选答案
178+
*/
179+
private static String formatMultiAnswer(MultipleChoiceQuestion multiQA) {
180+
if (multiQA.getAnswer() != null && !multiQA.getAnswer().isEmpty()) {
181+
List<String> answerLabels = new ArrayList<>();
182+
for (String answerId : multiQA.getAnswer()) {
183+
String label = findOptionLabel(multiQA.getOptions(), answerId);
184+
answerLabels.add(label != null ? label : answerId);
185+
}
186+
return String.join(",", answerLabels);
187+
}
188+
return "";
189+
}
190+
191+
/**
192+
* 根据选项id查找对应的标签
193+
*/
194+
private static String findOptionLabel(List<Option> options, String id) {
195+
if (options == null || id == null) {
196+
return null;
197+
}
198+
for (Option option : options) {
199+
if (id.equals(option.getId())) {
200+
return option.getLabel();
201+
}
202+
}
203+
return null;
204+
}
205+
206+
// 保留原有的序列化方法作为备用
207+
private static void firstOrderTraversal(QaTreeNode node, String parentId, List<JsonNode> result) throws JsonProcessingException {
208+
if (node == null) {
209+
return;
210+
}
211+
212+
// 获取子节点列表
213+
List<QaTreeNode> children = new ArrayList<>();
214+
215+
if (node.getChildren() != null) {
216+
children.addAll(node.getChildren().values());
217+
}
218+
219+
// 访问当前节点
220+
JsonNode jsonNode = JsonNode.Convert2JsonNode(node, parentId);
221+
222+
result.add(jsonNode);
223+
224+
// 先序遍历
225+
for (QaTreeNode child : children) {
226+
firstOrderTraversal(child, node.getId(), result);
227+
}
228+
}
48229
}

0 commit comments

Comments
 (0)