From 8c9ad8a12c9dacb50304319f6e19643c9abe660d Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 08:40:29 +0900 Subject: [PATCH 1/9] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20[Refactor]=20GraphRAG?= =?UTF-8?q?=20=EB=A1=9C=EC=A7=81=20=EB=B6=84=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatbot/service/ChatbotServiceImpl.java | 118 ++++-------------- .../server/domain/graph/entity/GraphEdge.java | 20 --- .../server/domain/graph/entity/GraphNode.java | 7 -- .../domain/rag/service/GraphRAGService.java | 72 +++++++++++ .../rag/service/SimilarityFilterService.java | 13 +- .../service/SimilarityFilterServiceImpl.java | 17 --- .../upload/service/UploadServiceImpl.java | 3 + 7 files changed, 111 insertions(+), 139 deletions(-) create mode 100644 src/main/java/com/going/server/domain/rag/service/GraphRAGService.java delete mode 100644 src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java diff --git a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java index 5646ee8..d1739f0 100644 --- a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java +++ b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java @@ -15,6 +15,7 @@ import com.going.server.domain.openai.service.RAGAnswerCreateService; import com.going.server.domain.openai.service.SimpleAnswerCreateService; import com.going.server.domain.openai.service.TextSummaryCreateService; +import com.going.server.domain.rag.service.GraphRAGService; import com.going.server.domain.rag.service.SimilarityFilterService; import com.going.server.domain.rag.util.PromptBuilder; import lombok.RequiredArgsConstructor; @@ -38,8 +39,10 @@ public class ChatbotServiceImpl implements ChatbotService { // openai 관련 service private final TextSummaryCreateService textSummaryCreateService; private final SimpleAnswerCreateService simpleAnswerCreateService; - private final RAGAnswerCreateService ragAnswerCreateService; private final ImageCreateService imageCreateService; + // graphRAG + private final GraphRAGService graphRAGService; + private final RAGAnswerCreateService ragAnswerCreateService; // 원문 반환 @Override @@ -71,106 +74,48 @@ public CreateChatbotResponseDto getSummaryText(String graphId) { .build(); } - - // RAG 챗봇 응답 생성 + // GraphRAG 챗봇 응답 생성 @Override - public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateChatbotRequestDto createChatbotRequestDto) { + public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateChatbotRequestDto requestDto) { Long graphId = Long.valueOf(graphStrId); // 404 : 지식그래프 찾을 수 없음 Graph graph = graphRepository.getByGraph(graphId); - // RAG: 사용자 질문 - String userQuestion = createChatbotRequestDto.getChatContent(); - - // RAG : 키워드 추출 - List keywords = extractKeywords(userQuestion); - System.out.println("[RAG] 추출된 키워드: " + keywords); - - // RAG: 유사 노드 검색 및 문장 추출 - List matchedNodes = graphNodeRepository.findByGraphIdAndKeywords(graphId, keywords); -// List matchedNodes = graphNodeRepository.findByGraphIdAndKeywordsWithEdges(graphId, keywords); - System.out.println("[RAG] matchedNodes: " + matchedNodes); - - List candidateSentences = matchedNodes.stream() - .map(GraphNode::getIncludeSentence) - .filter(Objects::nonNull) - .distinct() - .collect(Collectors.toList()); - - // RAG: 유사 문장 필터링 - List filteredChunks = similarityFilterService.filterRelevantSentences(userQuestion, candidateSentences); - - // RAG: 최종 프롬프트 구성 - String finalPrompt = promptBuilder.buildPrompt(filteredChunks, userQuestion); - System.out.println("finalPrompt: " + finalPrompt); - - // RAG: 메타정보 수집 - List retrievedChunks = new ArrayList<>(filteredChunks); - List sourceNodes = new ArrayList<>( - matchedNodes.stream().map(GraphNode::getLabel).distinct().toList() - ); - Map ragMeta = Map.of( - "chunkCount", String.valueOf(filteredChunks.size()) - ); - // 새로운 대화인 경우 기존 채팅 삭제 - if (createChatbotRequestDto.isNewChat()) { + if (requestDto.isNewChat()) { deletePreviousChat(graphId); } - // 기존 채팅 내역 조회 - List chatHistory = chattingRepository.findAllByGraphId(graphId); - - // 새로운 채팅 - String newChat = createChatbotRequestDto.getChatContent(); - - // 새로운 채팅 repository에 저장 - Chatting chatting = Chatting.builder() + // 사용자 입력 채팅 저장 + Chatting userChat = Chatting.builder() .graph(graph) - .content(newChat) + .content(requestDto.getChatContent()) .sender(Sender.USER) .createdAt(LocalDateTime.now()) .build(); - chattingRepository.save(chatting); - - // 응답 생성 - String chatContent; - - // RAG: 유사 문장이 있을 경우 컨텍스트 활용 - if (retrievedChunks.isEmpty()) { - System.out.println("[INFO] RAG 미적용 - 일반 채팅 기반 응답"); - System.out.println("[INFO] RAG 미적용 - 유사 문장 없음"); - System.out.println("[DEBUG] matchedNodes.size(): " + matchedNodes.size()); - System.out.println("[DEBUG] candidateSentences.size(): " + candidateSentences.size()); - System.out.println("[DEBUG] filteredChunks.size(): " + filteredChunks.size()); - chatContent = ragAnswerCreateService.chat(chatHistory, newChat); - } else { - System.out.println("[INFO] RAG 적용됨 - 유사 문장 " + retrievedChunks.size() + "개 포함"); - chatContent = ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt); - } + chattingRepository.save(userChat); - // 응답 repository에 저장 - Chatting answer = Chatting.builder() + // RAG 응답 생성 (응답 + 메타 포함) + CreateChatbotResponseDto responseDto = graphRAGService.createAnswerWithRAG( + graphId, + requestDto.getChatContent(), + requestDto.isNewChat() + ); + + // 응답 채팅 저장 + Chatting gptChat = Chatting.builder() .graph(graph) - .content(chatContent) + .content(responseDto.getChatContent()) .sender(Sender.GPT) - .createdAt(LocalDateTime.now()) + .createdAt(responseDto.getCreatedAt()) .build(); - chattingRepository.save(answer); + chattingRepository.save(gptChat); - // 반환 - return CreateChatbotResponseDto.builder() - .chatContent(chatContent) - .graphId(graphStrId) - .createdAt(answer.getCreatedAt()) - .retrievedChunks(retrievedChunks) - .sourceNodes(sourceNodes) - .ragMeta(ragMeta) - .build(); + return responseDto; } - // RAG 사용하지 않는 응답 생성 + // 기본 응답 생성 @Override public CreateChatbotResponseDto createSimpleAnswer(String graphStrId, CreateChatbotRequestDto createChatbotRequestDto) { Long graphId = Long.valueOf(graphStrId); @@ -293,17 +238,4 @@ private void deletePreviousChat(Long graphId) { chattingRepository.deleteByGraphId(graphId); } - - // RAG : 키워드 추출 - private List extractKeywords(String text) { - List stopwords = List.of("은", "는", "이", "가", "을", "를", "에", "의", "와", "과", "에서", "하다"); - - return Arrays.stream(text.split("[\\s,.!?]+")) - .map(word -> word.replaceAll("(은|는|이|가|을|를|에|의|와|과|에서)$", "")) // ✅ 조사 제거 - .map(String::toLowerCase) - .filter(word -> word.length() > 1 && !stopwords.contains(word)) - .distinct() - .limit(5) - .collect(Collectors.toList()); - } } diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java index 22f01e9..c8da938 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java @@ -23,24 +23,4 @@ public class GraphEdge { @TargetNode private GraphNode target; // 연결 대상 노드 - -// private GraphEdge createEdge(Long edgeId, String label, GraphNode source, GraphNode target) { -// GraphEdge edge = new GraphEdge(); -// edge.setEdgeId(edgeId); -// edge.setLabel(label); -// edge.setTarget(target); -// -// // outbound edge를 source에 연결 -// if (source.getEdges() == null) { -// source.setEdges(new ArrayList<>()); -// } -// source.getEdges().add(edge); -// return edge; -// } - - // Long → String 변환 (프론트 전송 시) - public String getIdAsString() { - return id != null ? String.valueOf(id) : null; - } - } diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java index 33e19eb..d7a6770 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java @@ -28,13 +28,6 @@ public class GraphNode { private String includeSentence; //해당 노드(단어)가 포함된 문장 private String image; -// @Relationship(type = "HAS_GRAPH", direction = Relationship.Direction.INCOMING) -// private Graph graph; - @Relationship(type = "RELATED", direction = Relationship.Direction.OUTGOING) private Set edges; - - public String getIdAsString() { - return id != null ? String.valueOf(id) : null; - } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java new file mode 100644 index 0000000..952c022 --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java @@ -0,0 +1,72 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.chatbot.dto.CreateChatbotResponseDto; +import com.going.server.domain.chatbot.entity.Chatting; +import com.going.server.domain.chatbot.repository.ChattingRepository; +import com.going.server.domain.graph.entity.Graph; +import com.going.server.domain.graph.entity.GraphNode; +import com.going.server.domain.graph.repository.GraphNodeRepository; +import com.going.server.domain.graph.repository.GraphRepository; +import com.going.server.domain.rag.util.PromptBuilder; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +@Service +@RequiredArgsConstructor +public class GraphRAGService { + + private final GraphRepository graphRepository; + private final GraphNodeRepository graphNodeRepository; + private final SimilarityFilterService similarityFilterService; + private final PromptBuilder promptBuilder; + private final ChattingRepository chattingRepository; + private final KeywordExtractor keywordExtractor; + private final RagAnswerCreateService ragAnswerCreateService; + + public CreateChatbotResponseDto createAnswerWithRAG(Long graphId, String userQuestion, boolean isNewChat) { + Graph graph = graphRepository.getByGraph(graphId); + + // 키워드 추출 + List keywords = keywordExtractor.extract(userQuestion); + + // 관련 노드 및 문장 + List matchedNodes = graphNodeRepository.findByGraphIdAndKeywords(graphId, keywords); + List candidateSentences = matchedNodes.stream() + .map(GraphNode::getIncludeSentence) + .filter(Objects::nonNull) + .distinct() + .collect(Collectors.toList()); + + // 필터링 + List filteredChunks = similarityFilterService.filterRelevantSentences(userQuestion, candidateSentences); + String finalPrompt = promptBuilder.buildPrompt(filteredChunks, userQuestion); + + // 대화 내역 + if (isNewChat) chattingRepository.deleteAllByGraphId(graphId); + List chatHistory = chattingRepository.findAllByGraphId(graphId); + + // 질문 저장 + chattingRepository.save(Chatting.ofUser(graph, userQuestion)); + + // 응답 생성 + String response = filteredChunks.isEmpty() + ? ragAnswerCreateService.chat(chatHistory, userQuestion) + : ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt); + + // 응답 저장 + Chatting answer = Chatting.ofGPT(graph, response); + chattingRepository.save(answer); + + return CreateChatbotResponseDto.of( + response, + graphId.toString(), + answer.getCreatedAt(), + filteredChunks, + matchedNodes.stream().map(GraphNode::getLabel).distinct().toList() + ); + } +} diff --git a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java index cefc705..8e5ac47 100644 --- a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java +++ b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java @@ -1,7 +1,16 @@ package com.going.server.domain.rag.service; +import org.springframework.stereotype.Service; + import java.util.List; -public interface SimilarityFilterService { - List filterRelevantSentences(String question, List candidateSentences); +@Service +// 유사도 검사 +public class SimilarityFilterService { + // TODO : 정확한 문맥 유사도 필터로 개선 필요 + public List filterRelevantSentences(String query, List candidates) { + return candidates.stream() + .filter(sentence -> sentence.toLowerCase().contains(query.toLowerCase())) + .toList(); + } } diff --git a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java deleted file mode 100644 index b4b038b..0000000 --- a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java +++ /dev/null @@ -1,17 +0,0 @@ -package com.going.server.domain.rag.service; - -import org.springframework.stereotype.Service; - -import java.util.List; - -@Service -// 유사도 검사 -public class SimilarityFilterServiceImpl implements SimilarityFilterService { - // TODO : 정확한 문맥 유사도 필터로 개선 필요 - @Override - public List filterRelevantSentences(String query, List candidates) { - return candidates.stream() - .filter(sentence -> sentence.toLowerCase().contains(query.toLowerCase())) - .toList(); - } -} diff --git a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java index a11c6f4..1ea6eae 100644 --- a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java +++ b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java @@ -47,6 +47,7 @@ public class UploadServiceImpl implements UploadService { @Override public UploadResponseDto uploadFile(UploadRequestDto dto) { try { + String jsonResponse = ocrService.processOcr(dto.getFile(), apiUrl, secretKey); log.info("jsonResponse log={}",jsonResponse); Map paresData = pdfOcrService.parse(jsonResponse); @@ -153,6 +154,7 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { } } + // 모델 코드 호출 public String setModelData(String text) { WebClient webClient = WebClient.builder().baseUrl(fastApiUrl).build(); Map requestBody = new HashMap<>(); @@ -166,4 +168,5 @@ public String setModelData(String text) { .bodyToMono(String.class) .block(); } + } From ac3a9a813df8014ff5d45bb17dbd1f91ee82d90a Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 09:05:27 +0900 Subject: [PATCH 2/9] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20[Refactor]=20GraphRAG?= =?UTF-8?q?=EC=97=90=EC=84=9C=20LLM=20=EB=A1=9C=EC=A7=81=20=EB=B6=84?= =?UTF-8?q?=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatbot/dto/CreateChatbotResponseDto.java | 29 ++++-- .../domain/chatbot/entity/Chatting.java | 18 ++++ .../chatbot/service/ChatbotServiceImpl.java | 21 +--- .../domain/openai/service/OpenAIService.java | 31 ++++++ .../service/RAGAnswerCreateService.java | 99 ------------------- .../domain/rag/dto/SimilarityRequestDto.java | 4 - .../domain/rag/dto/SimilarityResponseDto.java | 4 - .../rag/service/CypherQueryGenerator.java | 32 ++++++ .../rag/service/GraphQueryExecutor.java | 32 ++++++ .../domain/rag/service/GraphRAGService.java | 48 ++++----- .../rag/service/RagAnswerCreateService.java | 73 ++++++++++++++ .../rag/service/SimilarityFilterService.java | 13 ++- 12 files changed, 239 insertions(+), 165 deletions(-) create mode 100644 src/main/java/com/going/server/domain/openai/service/OpenAIService.java delete mode 100644 src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java delete mode 100644 src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java delete mode 100644 src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java create mode 100644 src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java create mode 100644 src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java create mode 100644 src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java diff --git a/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java b/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java index 5864666..9e43ebe 100644 --- a/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java +++ b/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java @@ -11,11 +11,28 @@ @Getter @Builder public class CreateChatbotResponseDto { - private String chatContent; // 챗봇 응답 - private String graphId; // 지식그래프 ID + private String chatContent; + private String graphId; @JsonFormat(pattern = "yyyy-MM-dd'T'HH:mm") - private LocalDateTime createdAt; // 응답 생성 시각 - private List retrievedChunks; // RAG: 검색된 문장들 - private List sourceNodes; // RAG: 참조된 지식그래프 노드 ID - private Map ragMeta; // RAG: 점수, 검색 method 등 + private LocalDateTime createdAt; + private List retrievedChunks; + private List sourceNodes; + private Map ragMeta; + + public static CreateChatbotResponseDto of( + String chatContent, + String graphId, + LocalDateTime createdAt, + List retrievedChunks, + List sourceNodes + ) { + return CreateChatbotResponseDto.builder() + .chatContent(chatContent) + .graphId(graphId) + .createdAt(createdAt) + .retrievedChunks(retrievedChunks) + .sourceNodes(sourceNodes) + .ragMeta(Map.of("chunkCount", String.valueOf(retrievedChunks.size()))) + .build(); + } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java b/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java index 2bd6596..510ffa9 100644 --- a/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java +++ b/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java @@ -26,4 +26,22 @@ public class Chatting { private Sender sender; private LocalDateTime createdAt; + + public static Chatting ofUser(Graph graph, String content) { + return Chatting.builder() + .graph(graph) + .content(content) + .sender(Sender.USER) + .createdAt(LocalDateTime.now()) + .build(); + } + + public static Chatting ofGPT(Graph graph, String content) { + return Chatting.builder() + .graph(graph) + .content(content) + .sender(Sender.GPT) + .createdAt(LocalDateTime.now()) + .build(); + } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java index d1739f0..715e783 100644 --- a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java +++ b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java @@ -6,13 +6,11 @@ import com.going.server.domain.chatbot.entity.Sender; import com.going.server.domain.chatbot.repository.ChattingRepository; import com.going.server.domain.graph.entity.Graph; -import com.going.server.domain.graph.entity.GraphNode; import com.going.server.domain.graph.exception.GraphContentNotFoundException; import com.going.server.domain.graph.repository.GraphRepository; import com.going.server.domain.graph.repository.GraphNodeRepository; import com.going.server.domain.openai.dto.ImageCreateRequestDto; import com.going.server.domain.openai.service.ImageCreateService; -import com.going.server.domain.openai.service.RAGAnswerCreateService; import com.going.server.domain.openai.service.SimpleAnswerCreateService; import com.going.server.domain.openai.service.TextSummaryCreateService; import com.going.server.domain.rag.service.GraphRAGService; @@ -25,8 +23,6 @@ import java.time.LocalDateTime; import java.util.*; -import java.util.stream.Collectors; - @Service @RequiredArgsConstructor @Transactional @@ -42,7 +38,6 @@ public class ChatbotServiceImpl implements ChatbotService { private final ImageCreateService imageCreateService; // graphRAG private final GraphRAGService graphRAGService; - private final RAGAnswerCreateService ragAnswerCreateService; // 원문 반환 @Override @@ -78,16 +73,13 @@ public CreateChatbotResponseDto getSummaryText(String graphId) { @Override public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateChatbotRequestDto requestDto) { Long graphId = Long.valueOf(graphStrId); - // 404 : 지식그래프 찾을 수 없음 Graph graph = graphRepository.getByGraph(graphId); - // 새로운 대화인 경우 기존 채팅 삭제 if (requestDto.isNewChat()) { deletePreviousChat(graphId); } - // 사용자 입력 채팅 저장 Chatting userChat = Chatting.builder() .graph(graph) .content(requestDto.getChatContent()) @@ -96,20 +88,17 @@ public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateCha .build(); chattingRepository.save(userChat); + List chatHistory = chattingRepository.findAllByGraphId(graphId); + // RAG 응답 생성 (응답 + 메타 포함) - CreateChatbotResponseDto responseDto = graphRAGService.createAnswerWithRAG( + CreateChatbotResponseDto responseDto = graphRAGService.createAnswerWithGraphRAG( graphId, requestDto.getChatContent(), - requestDto.isNewChat() + chatHistory ); // 응답 채팅 저장 - Chatting gptChat = Chatting.builder() - .graph(graph) - .content(responseDto.getChatContent()) - .sender(Sender.GPT) - .createdAt(responseDto.getCreatedAt()) - .build(); + Chatting gptChat = Chatting.ofGPT(graph, responseDto.getChatContent()); chattingRepository.save(gptChat); return responseDto; diff --git a/src/main/java/com/going/server/domain/openai/service/OpenAIService.java b/src/main/java/com/going/server/domain/openai/service/OpenAIService.java new file mode 100644 index 0000000..ada3955 --- /dev/null +++ b/src/main/java/com/going/server/domain/openai/service/OpenAIService.java @@ -0,0 +1,31 @@ +package com.going.server.domain.openai.service; + +import com.going.server.domain.openai.dto.ChatCompletionRequestDto; +import com.theokanning.openai.OpenAiService; +import com.theokanning.openai.completion.chat.ChatMessage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +@RequiredArgsConstructor +public class OpenAIService { + + private final OpenAiService openAiService; + + public String getCompletionResponse(List messages, String model, double temperature, int maxTokens) { + ChatCompletionRequestDto request = ChatCompletionRequestDto.builder() + .model(model) + .temperature(temperature) + .maxTokens(maxTokens) + .messages(messages) + .build(); + + return openAiService.createChatCompletion(request.toRequest()) + .getChoices() + .get(0) + .getMessage() + .getContent(); + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java b/src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java deleted file mode 100644 index 8afd16d..0000000 --- a/src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java +++ /dev/null @@ -1,99 +0,0 @@ -package com.going.server.domain.openai.service; - -import com.going.server.domain.chatbot.entity.Chatting; -import com.going.server.domain.chatbot.entity.Sender; -import com.going.server.domain.openai.dto.ChatCompletionRequestDto; -import com.theokanning.openai.OpenAiService; -import com.theokanning.openai.completion.chat.ChatMessage; -import lombok.RequiredArgsConstructor; -import org.springframework.stereotype.Service; - -import java.util.ArrayList; -import java.util.List; - -@Service -@RequiredArgsConstructor -public class RAGAnswerCreateService { - private final OpenAiService openAiService; - - // 시스템 역할 설정 - private static final String SYSTEM_PROMPT = """ - 당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다. - - 아래 제공된 데이터를 기반으로 질문에 대해 매우 길고 정확하게 설명해주세요. - - 만약 참고 데이터가 없다면, 교육 도메인의 일반적인 지식을 바탕으로 충실하게 답변해주세요. - - 반드시 한글로만 응답하고, 인사말이나 불필요한 문장은 생략한 대답만 반환하세요. - """; - - // 모델 스펙 정의 - private static final String MODEL_NAME = "gpt-4-0125-preview"; -// private static final String MODEL_NAME = "gpt-4o"; - - private static final double TEMPERATURE = 0.3 ; - private static final int MAX_TOKENS = 3000; - - // 기존 채팅 이력을 기반으로 GPT 응답 생성 - public String chat(List chatHistory, String question) { - - // 메세지 구성 - List messages = new ArrayList<>(); - messages.add(new ChatMessage("system", SYSTEM_PROMPT)); // 프롬프트 설정 -// messages.addAll(convertHistoryToMessages(chatHistory)); // 기존 채팅 - messages.add(new ChatMessage("user", question)); // 새로운 질문 - - // DTO 기반 요청 생성 - ChatCompletionRequestDto request = buildRequest(messages); - - // OpenAI 모델에게 질문 및 응답 생성 - return getResponseText(request); - } - - // RAG 컨텍스트 기반 + 기존 채팅 이력을 함께 사용하는 GPT 응답 생성 - public String chatWithContext(List chatHistory, String finalPrompt) { - - List messages = new ArrayList<>(); - messages.add(new ChatMessage("system", SYSTEM_PROMPT)); - - // 기존 대화 이력 추가 -// messages.addAll(convertHistoryToMessages(chatHistory)); - - // 마지막 질문을 RAG 컨텍스트 기반으로 전달 - messages.add(new ChatMessage("user", finalPrompt)); - - ChatCompletionRequestDto request = buildRequest(messages); - return getResponseText(request); - } - - // 요청 생성 메서드 - private ChatCompletionRequestDto buildRequest(List messages) { - return ChatCompletionRequestDto.builder() - .model(MODEL_NAME) - .temperature(TEMPERATURE) - .maxTokens(MAX_TOKENS) - .messages(messages) - .build(); - } - - // 응답 추출 메서드 - private String getResponseText(ChatCompletionRequestDto request) { - return openAiService.createChatCompletion(request.toRequest()) - .getChoices() - .get(0) - .getMessage() - .getContent(); - } - - // Chatting 엔티티를 OpenAI ChatMessage로 변환 - private List convertHistoryToMessages(List chatHistory) { - return chatHistory.stream() - .map(chat -> new ChatMessage( - convertSenderToRole(chat.getSender()), - chat.getContent() - )) - .toList(); - } - - // Chatting 엔티티의 Sender(Enum type) -> OpenAI 역할 문자열 변환 - private String convertSenderToRole(Sender sender) { - return sender == Sender.USER ? "user" : "assistant"; - } -} diff --git a/src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java b/src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java deleted file mode 100644 index 0f32fce..0000000 --- a/src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java +++ /dev/null @@ -1,4 +0,0 @@ -package com.going.server.domain.rag.dto; - -public class SimilarityRequestDto { -} diff --git a/src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java b/src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java deleted file mode 100644 index 46a6974..0000000 --- a/src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java +++ /dev/null @@ -1,4 +0,0 @@ -package com.going.server.domain.rag.dto; - -public class SimilarityResponseDto { -} diff --git a/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java new file mode 100644 index 0000000..8c56fed --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java @@ -0,0 +1,32 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.openai.service.OpenAIService; +import com.theokanning.openai.OpenAiService; +import com.theokanning.openai.completion.chat.ChatMessage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Component +@RequiredArgsConstructor +public class CypherQueryGenerator { + private final OpenAIService openAIService; + + public String generate(String userQuestion) { + String prompt = """ + 너는 Cypher 쿼리 생성기야. + 아래 사용자 질문에 맞는 Cypher 쿼리를 생성해줘. + 데이터는 (:Concept)-[:REL]->(:Concept) 구조야. + + 질문: %s + + Cypher 쿼리: + """.formatted(userQuestion); + + return openAIService.getCompletionResponse( + List.of(new ChatMessage("user", prompt)), + "gpt-4-0125-preview", 0.2, 1000 + ); + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java new file mode 100644 index 0000000..11b67f6 --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java @@ -0,0 +1,32 @@ +package com.going.server.domain.rag.service; + +import lombok.RequiredArgsConstructor; +import org.neo4j.driver.*; +import org.neo4j.driver.Record; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; + +@Component +@RequiredArgsConstructor +public class GraphQueryExecutor { + + private final Driver neo4jDriver; // Neo4j Java Driver 주입 + + public List runQuery(Long graphId, String cypherQuery) { + List results = new ArrayList<>(); + + try (Session session = neo4jDriver.session()) { + Result result = session.run(cypherQuery); + while (result.hasNext()) { + Record record = result.next(); + results.add(record.toString()); // 필요에 따라 특정 필드만 추출 가능 + } + } catch (Exception e) { + e.printStackTrace(); + } + + return results; + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java index 952c022..3874db4 100644 --- a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java +++ b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java @@ -4,7 +4,6 @@ import com.going.server.domain.chatbot.entity.Chatting; import com.going.server.domain.chatbot.repository.ChattingRepository; import com.going.server.domain.graph.entity.Graph; -import com.going.server.domain.graph.entity.GraphNode; import com.going.server.domain.graph.repository.GraphNodeRepository; import com.going.server.domain.graph.repository.GraphRepository; import com.going.server.domain.rag.util.PromptBuilder; @@ -12,8 +11,6 @@ import org.springframework.stereotype.Service; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; @Service @RequiredArgsConstructor @@ -24,40 +21,33 @@ public class GraphRAGService { private final SimilarityFilterService similarityFilterService; private final PromptBuilder promptBuilder; private final ChattingRepository chattingRepository; - private final KeywordExtractor keywordExtractor; + private final CypherQueryGenerator cypherQueryGenerator; + private final GraphQueryExecutor graphQueryExecutor; private final RagAnswerCreateService ragAnswerCreateService; - public CreateChatbotResponseDto createAnswerWithRAG(Long graphId, String userQuestion, boolean isNewChat) { - Graph graph = graphRepository.getByGraph(graphId); - - // 키워드 추출 - List keywords = keywordExtractor.extract(userQuestion); - // 관련 노드 및 문장 - List matchedNodes = graphNodeRepository.findByGraphIdAndKeywords(graphId, keywords); - List candidateSentences = matchedNodes.stream() - .map(GraphNode::getIncludeSentence) - .filter(Objects::nonNull) - .distinct() - .collect(Collectors.toList()); + // GraphRAG 응답 생성 + public CreateChatbotResponseDto createAnswerWithGraphRAG( + Long graphId, + String userQuestion, + List chatHistory + ){ + Graph graph = graphRepository.getByGraph(graphId); - // 필터링 - List filteredChunks = similarityFilterService.filterRelevantSentences(userQuestion, candidateSentences); - String finalPrompt = promptBuilder.buildPrompt(filteredChunks, userQuestion); + // 1. 질문 → Cypher 쿼리 생성 (LLM) + String cypherQuery = cypherQueryGenerator.generate(userQuestion); - // 대화 내역 - if (isNewChat) chattingRepository.deleteAllByGraphId(graphId); - List chatHistory = chattingRepository.findAllByGraphId(graphId); + // 2. 쿼리 실행 → 결과 추출 + List contextChunks = graphQueryExecutor.runQuery(graphId, cypherQuery); - // 질문 저장 - chattingRepository.save(Chatting.ofUser(graph, userQuestion)); + // 3. 프롬프트 구성 + String finalPrompt = promptBuilder.buildPrompt(contextChunks, userQuestion); - // 응답 생성 - String response = filteredChunks.isEmpty() + // 4. 응답 생성 + String response = contextChunks.isEmpty() ? ragAnswerCreateService.chat(chatHistory, userQuestion) : ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt); - // 응답 저장 Chatting answer = Chatting.ofGPT(graph, response); chattingRepository.save(answer); @@ -65,8 +55,8 @@ public CreateChatbotResponseDto createAnswerWithRAG(Long graphId, String userQue response, graphId.toString(), answer.getCreatedAt(), - filteredChunks, - matchedNodes.stream().map(GraphNode::getLabel).distinct().toList() + contextChunks, + null // sourceNodes: 필요하면 쿼리 결과에서 추출 ); } } diff --git a/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java b/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java new file mode 100644 index 0000000..712b7bc --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java @@ -0,0 +1,73 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.chatbot.entity.Chatting; +import com.going.server.domain.chatbot.entity.Sender; +import com.going.server.domain.graph.entity.Graph; +import com.going.server.domain.openai.dto.ChatCompletionRequestDto; +import com.going.server.domain.openai.service.OpenAIService; +import com.theokanning.openai.completion.chat.ChatMessage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.List; + +@Service +@RequiredArgsConstructor +public class RagAnswerCreateService { + + private final OpenAIService openAIService; + + private static final String SYSTEM_PROMPT = """ + 당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다. + - 아래 제공된 데이터를 기반으로 질문에 대해 매우 길고 정확하게 설명해주세요. + - 만약 참고 데이터가 없다면, 교육 도메인의 일반적인 지식을 바탕으로 충실하게 답변해주세요. + - 반드시 한글로만 응답하고, 인사말이나 불필요한 문장은 생략한 대답만 반환하세요. + """; + + private static final String MODEL_NAME = "gpt-4-0125-preview"; + private static final double TEMPERATURE = 0.3; + private static final int MAX_TOKENS = 3000; + + public String chat(List chatHistory, String question) { + List messages = new ArrayList<>(); + messages.add(new ChatMessage("system", SYSTEM_PROMPT)); + messages.add(new ChatMessage("user", question)); + return openAIService.getCompletionResponse(messages, MODEL_NAME, TEMPERATURE, MAX_TOKENS); + } + + public String chatWithContext(List chatHistory, String finalPrompt) { + List messages = new ArrayList<>(); + messages.add(new ChatMessage("system", SYSTEM_PROMPT)); + messages.add(new ChatMessage("user", finalPrompt)); + return openAIService.getCompletionResponse(messages, MODEL_NAME, TEMPERATURE, MAX_TOKENS); + } + + private List convertHistoryToMessages(List chatHistory) { + return chatHistory.stream() + .map(chat -> new ChatMessage( + chat.getSender() == Sender.USER ? "user" : "assistant", + chat.getContent() + )) + .toList(); + } + + public static Chatting ofUser(Graph graph, String content) { + return Chatting.builder() + .graph(graph) + .content(content) + .sender(Sender.USER) + .createdAt(LocalDateTime.now()) + .build(); + } + + public static Chatting ofGPT(Graph graph, String content) { + return Chatting.builder() + .graph(graph) + .content(content) + .sender(Sender.GPT) + .createdAt(LocalDateTime.now()) + .build(); + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java index 8e5ac47..b0d237d 100644 --- a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java +++ b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java @@ -1,16 +1,15 @@ package com.going.server.domain.rag.service; + import org.springframework.stereotype.Service; import java.util.List; @Service -// 유사도 검사 public class SimilarityFilterService { - // TODO : 정확한 문맥 유사도 필터로 개선 필요 - public List filterRelevantSentences(String query, List candidates) { - return candidates.stream() - .filter(sentence -> sentence.toLowerCase().contains(query.toLowerCase())) - .toList(); + + // 간단히 모든 문장을 통과시키는 기본 구현 (추후 유사도 필터링 적용 가능) + public List filterRelevantSentences(String userQuestion, List sentences) { + return sentences; } -} +} \ No newline at end of file From 7c68cbbb8f9c6c57bed938daf91f5743f498c870 Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 09:16:08 +0900 Subject: [PATCH 3/9] =?UTF-8?q?=E2=9C=A8=20[Feat]=20GraphRAG=20=EB=A1=9C?= =?UTF-8?q?=EC=A7=81=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/rag/dto/GraphQueryResult.java | 13 ++++++++ .../rag/service/CypherQueryGenerator.java | 1 + .../rag/service/GraphQueryExecutor.java | 3 +- .../domain/rag/service/GraphRAGService.java | 30 ++++++++++++++----- .../rag/service/RagAnswerCreateService.java | 30 ++++--------------- 5 files changed, 44 insertions(+), 33 deletions(-) create mode 100644 src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java diff --git a/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java b/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java new file mode 100644 index 0000000..1d729e5 --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java @@ -0,0 +1,13 @@ +package com.going.server.domain.rag.dto; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class GraphQueryResult { + + private final String sentence; // RAG에 활용할 문장 + private final String nodeLabel; // 해당 문장이 포함된 노드의 라벨 또는 ID + +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java index 8c56fed..d56f12e 100644 --- a/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java +++ b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java @@ -8,6 +8,7 @@ import java.util.List; +// 1. 질문 → Cypher 쿼리 생성 (LLM) @Component @RequiredArgsConstructor public class CypherQueryGenerator { diff --git a/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java index 11b67f6..eca924b 100644 --- a/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java +++ b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java @@ -8,11 +8,12 @@ import java.util.ArrayList; import java.util.List; +// 2. 쿼리 실행 → 결과 추출 @Component @RequiredArgsConstructor public class GraphQueryExecutor { - private final Driver neo4jDriver; // Neo4j Java Driver 주입 + private final Driver neo4jDriver; // Neo4j Java Driver public List runQuery(Long graphId, String cypherQuery) { List results = new ArrayList<>(); diff --git a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java index 3874db4..33d8ffc 100644 --- a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java +++ b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java @@ -6,14 +6,17 @@ import com.going.server.domain.graph.entity.Graph; import com.going.server.domain.graph.repository.GraphNodeRepository; import com.going.server.domain.graph.repository.GraphRepository; +import com.going.server.domain.rag.dto.GraphQueryResult; import com.going.server.domain.rag.util.PromptBuilder; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.util.List; @Service @RequiredArgsConstructor +@Slf4j public class GraphRAGService { private final GraphRepository graphRepository; @@ -25,38 +28,49 @@ public class GraphRAGService { private final GraphQueryExecutor graphQueryExecutor; private final RagAnswerCreateService ragAnswerCreateService; - - // GraphRAG 응답 생성 + /** + * 사용자 질문에 대해 Cypher 쿼리 → 그래프 정보 검색 → 프롬프트 생성 → LLM 응답 생성 + * 본 메서드는 LangChain 없이 구현한 Spring 기반 GraphRAG의 핵심 흐름입니다. + */ public CreateChatbotResponseDto createAnswerWithGraphRAG( Long graphId, String userQuestion, List chatHistory - ){ + ) { Graph graph = graphRepository.getByGraph(graphId); + log.info("[GraphRAG] graphId: {}, question: {}", graphId, userQuestion); - // 1. 질문 → Cypher 쿼리 생성 (LLM) + // 1. 질문 → Cypher 쿼리 생성 String cypherQuery = cypherQueryGenerator.generate(userQuestion); + log.info("[GraphRAG] Generated Cypher Query:\n{}", cypherQuery); - // 2. 쿼리 실행 → 결과 추출 - List contextChunks = graphQueryExecutor.runQuery(graphId, cypherQuery); + // 2. 쿼리 실행 → 문맥(context) 및 노드 라벨 추출 + List queryResults = graphQueryExecutor.runQuery(graphId, cypherQuery); + List contextChunks = queryResults.stream().map(GraphQueryResult::sentence).toList(); + List sourceNodes = queryResults.stream().map(GraphQueryResult::nodeLabel).distinct().toList(); + log.info("[GraphRAG] Retrieved {} context chunks", contextChunks.size()); // 3. 프롬프트 구성 String finalPrompt = promptBuilder.buildPrompt(contextChunks, userQuestion); + log.info("[GraphRAG] Final Prompt constructed"); - // 4. 응답 생성 + // 4. RAG 응답 생성 String response = contextChunks.isEmpty() ? ragAnswerCreateService.chat(chatHistory, userQuestion) : ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt); + log.info("[GraphRAG] Response generated by LLM"); + // 5. 응답 저장 Chatting answer = Chatting.ofGPT(graph, response); chattingRepository.save(answer); + log.info("[GraphRAG] Response saved to DB"); return CreateChatbotResponseDto.of( response, graphId.toString(), answer.getCreatedAt(), contextChunks, - null // sourceNodes: 필요하면 쿼리 결과에서 추출 + sourceNodes ); } } diff --git a/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java b/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java index 712b7bc..2da4c4b 100644 --- a/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java +++ b/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java @@ -2,17 +2,15 @@ import com.going.server.domain.chatbot.entity.Chatting; import com.going.server.domain.chatbot.entity.Sender; -import com.going.server.domain.graph.entity.Graph; -import com.going.server.domain.openai.dto.ChatCompletionRequestDto; import com.going.server.domain.openai.service.OpenAIService; import com.theokanning.openai.completion.chat.ChatMessage; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; -import java.time.LocalDateTime; import java.util.ArrayList; import java.util.List; +// 4. GraphRAG 응답 생성 @Service @RequiredArgsConstructor public class RagAnswerCreateService { @@ -22,17 +20,18 @@ public class RagAnswerCreateService { private static final String SYSTEM_PROMPT = """ 당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다. - 아래 제공된 데이터를 기반으로 질문에 대해 매우 길고 정확하게 설명해주세요. - - 만약 참고 데이터가 없다면, 교육 도메인의 일반적인 지식을 바탕으로 충실하게 답변해주세요. + - 만약 참고 데이터가 없다면, 관련정보 없다고 하세요. - 반드시 한글로만 응답하고, 인사말이나 불필요한 문장은 생략한 대답만 반환하세요. """; - private static final String MODEL_NAME = "gpt-4-0125-preview"; + private static final String MODEL_NAME = "gpt-4o"; private static final double TEMPERATURE = 0.3; - private static final int MAX_TOKENS = 3000; + private static final int MAX_TOKENS = 1500; public String chat(List chatHistory, String question) { List messages = new ArrayList<>(); messages.add(new ChatMessage("system", SYSTEM_PROMPT)); + messages.addAll(convertHistoryToMessages(chatHistory)); messages.add(new ChatMessage("user", question)); return openAIService.getCompletionResponse(messages, MODEL_NAME, TEMPERATURE, MAX_TOKENS); } @@ -40,6 +39,7 @@ public String chat(List chatHistory, String question) { public String chatWithContext(List chatHistory, String finalPrompt) { List messages = new ArrayList<>(); messages.add(new ChatMessage("system", SYSTEM_PROMPT)); + messages.addAll(convertHistoryToMessages(chatHistory)); messages.add(new ChatMessage("user", finalPrompt)); return openAIService.getCompletionResponse(messages, MODEL_NAME, TEMPERATURE, MAX_TOKENS); } @@ -52,22 +52,4 @@ private List convertHistoryToMessages(List chatHistory) { )) .toList(); } - - public static Chatting ofUser(Graph graph, String content) { - return Chatting.builder() - .graph(graph) - .content(content) - .sender(Sender.USER) - .createdAt(LocalDateTime.now()) - .build(); - } - - public static Chatting ofGPT(Graph graph, String content) { - return Chatting.builder() - .graph(graph) - .content(content) - .sender(Sender.GPT) - .createdAt(LocalDateTime.now()) - .build(); - } } \ No newline at end of file From 60586c5a06e80c26efebefe9045e5fe6db41e64e Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 09:38:09 +0900 Subject: [PATCH 4/9] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20[Perf]=20=EA=B7=B8?= =?UTF-8?q?=EB=9E=98=ED=94=84=20=EA=B2=80=EC=83=89=20=EC=BF=BC=EB=A6=AC=20?= =?UTF-8?q?=EC=83=9D=EC=84=B1=20=EB=A1=9C=EC=A7=81=20=EA=B0=9C=EC=84=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/rag/dto/GraphQueryResult.java | 6 ++---- .../rag/service/CypherQueryGenerator.java | 19 ++++++++++++------- .../rag/service/GraphQueryExecutor.java | 12 +++++++++--- .../domain/rag/service/GraphRAGService.java | 15 ++++++++++++--- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java b/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java index 1d729e5..ef185d5 100644 --- a/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java +++ b/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java @@ -6,8 +6,6 @@ @Getter @AllArgsConstructor public class GraphQueryResult { - - private final String sentence; // RAG에 활용할 문장 - private final String nodeLabel; // 해당 문장이 포함된 노드의 라벨 또는 ID - + private String sentence; + private String nodeLabel; } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java index d56f12e..e42a735 100644 --- a/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java +++ b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java @@ -16,13 +16,18 @@ public class CypherQueryGenerator { public String generate(String userQuestion) { String prompt = """ - 너는 Cypher 쿼리 생성기야. - 아래 사용자 질문에 맞는 Cypher 쿼리를 생성해줘. - 데이터는 (:Concept)-[:REL]->(:Concept) 구조야. - - 질문: %s - - Cypher 쿼리: + 당신은 Neo4j용 Cypher 쿼리를 생성하는 AI입니다. + 주어진 질문에 대해 Cypher 쿼리만 반환하세요. 코드블록, 설명 없이 오직 쿼리만 출력해야 합니다. + + 예: + 질문: "고래와 관련된 개념들을 알려줘" + → MATCH (n:GraphNode)-[r]->(m:GraphNode)\s + WHERE n.label = '고래'\s + RETURN m.label AS nodeLabel, m.includeSentence AS sentence\s + LIMIT 10 + + 질문: "${userQuestion}" + → """.formatted(userQuestion); return openAIService.getCompletionResponse( diff --git a/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java index eca924b..8b506d1 100644 --- a/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java +++ b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java @@ -1,5 +1,6 @@ package com.going.server.domain.rag.service; +import com.going.server.domain.rag.dto.GraphQueryResult; import lombok.RequiredArgsConstructor; import org.neo4j.driver.*; import org.neo4j.driver.Record; @@ -15,14 +16,19 @@ public class GraphQueryExecutor { private final Driver neo4jDriver; // Neo4j Java Driver - public List runQuery(Long graphId, String cypherQuery) { - List results = new ArrayList<>(); + public List runQuery(Long graphId, String cypherQuery) { + List results = new ArrayList<>(); try (Session session = neo4jDriver.session()) { Result result = session.run(cypherQuery); while (result.hasNext()) { Record record = result.next(); - results.add(record.toString()); // 필요에 따라 특정 필드만 추출 가능 + + // 필드 이름은 Cypher 쿼리 결과와 일치해야 함 + String sentence = record.get("sentence").asString(""); + String nodeLabel = record.get("nodeLabel").asString(""); + + results.add(new GraphQueryResult(sentence, nodeLabel)); } } catch (Exception e) { e.printStackTrace(); diff --git a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java index 33d8ffc..84561d8 100644 --- a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java +++ b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java @@ -41,13 +41,22 @@ public CreateChatbotResponseDto createAnswerWithGraphRAG( log.info("[GraphRAG] graphId: {}, question: {}", graphId, userQuestion); // 1. 질문 → Cypher 쿼리 생성 - String cypherQuery = cypherQueryGenerator.generate(userQuestion); + String cypherQuery = cypherQueryGenerator.generate(userQuestion).trim() + .replaceAll("(?s)```cypher.*?```", "") // 마크다운 제거 + .replaceAll("```", "") // 남은 ``` 제거 + .trim(); log.info("[GraphRAG] Generated Cypher Query:\n{}", cypherQuery); // 2. 쿼리 실행 → 문맥(context) 및 노드 라벨 추출 List queryResults = graphQueryExecutor.runQuery(graphId, cypherQuery); - List contextChunks = queryResults.stream().map(GraphQueryResult::sentence).toList(); - List sourceNodes = queryResults.stream().map(GraphQueryResult::nodeLabel).distinct().toList(); + List contextChunks = queryResults.stream() + .map(GraphQueryResult::getSentence) + .toList(); + + List sourceNodes = queryResults.stream() + .map(GraphQueryResult::getNodeLabel) + .distinct() + .toList(); log.info("[GraphRAG] Retrieved {} context chunks", contextChunks.size()); // 3. 프롬프트 구성 From ec79c0584e838309cb5a426e713184700bb7cd31 Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 09:50:25 +0900 Subject: [PATCH 5/9] =?UTF-8?q?=F0=9F=90=9B=20[Fix]=20=EA=B7=B8=EB=9E=98?= =?UTF-8?q?=ED=94=84=20=EC=88=9C=ED=99=98=20=EC=B0=B8=EC=A1=B0=20=EC=98=A4?= =?UTF-8?q?=EB=A5=98=20=ED=95=B4=EA=B2=B0=EC=9D=84=20=EC=9C=84=ED=95=9C=20?= =?UTF-8?q?=EB=A1=9C=EC=A7=81=20=EC=B6=94=EA=B0=80(eauals,=20hashcode)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/domain/graph/entity/GraphEdge.java | 19 ++++++++++++++++ .../server/domain/graph/entity/GraphNode.java | 3 +++ .../upload/service/UploadServiceImpl.java | 22 +++++++++++++++++-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java index c8da938..5f02f41 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java @@ -1,11 +1,13 @@ package com.going.server.domain.graph.entity; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import org.springframework.data.neo4j.core.schema.*; import java.util.ArrayList; +import java.util.Objects; @RelationshipProperties @Getter @@ -23,4 +25,21 @@ public class GraphEdge { @TargetNode private GraphNode target; // 연결 대상 노드 + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GraphEdge edge = (GraphEdge) o; + return Objects.equals(source, edge.source) + && Objects.equals(label, edge.label) + && edge.target != null && target != null + && Objects.equals(target.getNodeId(), edge.target.getNodeId()); + } + + @Override + public int hashCode() { + return Objects.hash(source, label, target != null ? target.getNodeId() : null); + } + } diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java index d7a6770..15a7a08 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java @@ -4,6 +4,7 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import lombok.ToString; import org.springframework.data.neo4j.core.schema.*; import java.util.ArrayList; @@ -28,6 +29,8 @@ public class GraphNode { private String includeSentence; //해당 노드(단어)가 포함된 문장 private String image; + @ToString.Exclude + @JsonIgnore @Relationship(type = "RELATED", direction = Relationship.Direction.OUTGOING) private Set edges; } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java index 1ea6eae..4d8ae68 100644 --- a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java +++ b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java @@ -13,6 +13,9 @@ import com.going.server.domain.upload.dto.UploadResponseDto; import lombok.RequiredArgsConstructor; +import org.neo4j.driver.Session; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; @@ -33,6 +36,7 @@ public class UploadServiceImpl implements UploadService { private final PdfOcrService pdfOcrService; private final GraphNodeRepository graphNodeRepository; private final GraphRepository graphRepository; + private final Driver neo4jDriver; // Neo4j Java Driver @Value("${ocr.api.url}") private String apiUrl; @@ -96,7 +100,7 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { } String sourceId = edgeNode.get("source").asText(); String targetId = edgeNode.get("target").asText(); - String label = edgeNode.get("label").asText(); + String relationType = edgeNode.get("label").asText(); GraphNode sourceNode = nodeIdToNode.get(sourceId); GraphNode targetNode = nodeIdToNode.get(targetId); @@ -106,10 +110,20 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { continue; } + String dynamicCypher = String.format( + "MATCH (a:GraphNode {nodeId: '%s'}), (b:GraphNode {nodeId: '%s'}) MERGE (a)-[:`%s`]->(b)", + sourceId, targetId, relationType // 한글도 가능: "기능", "포함" 등 + ); + + // session 통해 직접 실행 + try (Session session = neo4jDriver.session()) { + session.run(dynamicCypher); + } + //edge 엔티티 생성 GraphEdge edge = GraphEdge.builder() .source(sourceId) - .label(label) + .label(relationType) .target(targetNode) .build(); @@ -117,6 +131,10 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { if (sourceNode.getEdges() == null) { sourceNode.setEdges(new HashSet<>()); } + + if (!sourceNode.getEdges().contains(edge)) { + sourceNode.getEdges().add(edge); + } sourceNode.getEdges().add(edge); } From e24ac968bc9679bb9eb8d508ab959289601bc2fe Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 09:57:31 +0900 Subject: [PATCH 6/9] =?UTF-8?q?=F0=9F=9A=A7=20[progress]=20=EC=84=9C?= =?UTF-8?q?=EB=B8=8C=EB=AA=A8=EB=93=88=20=EC=B5=9C=EC=8B=A0=ED=99=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../going/server/domain/upload/service/UploadServiceImpl.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java index 4d8ae68..9d56670 100644 --- a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java +++ b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java @@ -45,6 +45,7 @@ public class UploadServiceImpl implements UploadService { @Value("${fastapi.base-url}") private String fastApiUrl; + private final Map translationCache = new HashMap<>(); private final Map imageCache = new HashMap<>(); From 0562c5003b3d5024bdbe9f4cf91009b906bb5329 Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 12:31:19 +0900 Subject: [PATCH 7/9] =?UTF-8?q?=F0=9F=90=9B=20[fix]=20=EC=88=9C=ED=99=98?= =?UTF-8?q?=EC=B0=B8=EC=A1=B0=20=EC=98=A4=EB=A5=98=20=ED=95=B4=EA=B2=B0?= =?UTF-8?q?=EC=9D=84=20=EC=9C=84=ED=95=9C=20Graph=20=EB=82=B4=20=EB=B3=84?= =?UTF-8?q?=EB=8F=84=EC=9D=98=20id=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Backend_Config | 2 +- .../server/domain/graph/entity/Graph.java | 5 +++- .../server/domain/graph/entity/GraphEdge.java | 6 ++++ .../server/domain/graph/entity/GraphNode.java | 9 +++--- .../graph/repository/GraphNodeRepository.java | 1 + .../graph/repository/GraphRepository.java | 29 ++++++++++++++++++- .../graph/service/GraphServiceImpl.java | 15 +++++++--- 7 files changed, 56 insertions(+), 11 deletions(-) diff --git a/Backend_Config b/Backend_Config index 3c9dbc1..d2b85ef 160000 --- a/Backend_Config +++ b/Backend_Config @@ -1 +1 @@ -Subproject commit 3c9dbc1b7a76e502834f0986874e74e00e5682eb +Subproject commit d2b85ef76f4279d58f4f002c2cf06c467e2af103 diff --git a/src/main/java/com/going/server/domain/graph/entity/Graph.java b/src/main/java/com/going/server/domain/graph/entity/Graph.java index b2c9fcb..f36c93c 100644 --- a/src/main/java/com/going/server/domain/graph/entity/Graph.java +++ b/src/main/java/com/going/server/domain/graph/entity/Graph.java @@ -15,7 +15,10 @@ public class Graph extends BaseEntity { @Id @GeneratedValue - private Long id; //그래프 id -> 프론트와 통신에서는 String 값으로 사용 + private Long dbId; // 내부 관리용 elementId와 연결됨 + + @Property("id") + private Long id; // 우리가 직접 사용하는 명시적 ID private String title; diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java index 5f02f41..c6ef65e 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java @@ -1,5 +1,6 @@ package com.going.server.domain.graph.entity; +import com.fasterxml.jackson.annotation.JsonIgnore; import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -19,11 +20,16 @@ public class GraphEdge { @GeneratedValue private Long id; // Neo4j 내부 ID + @EqualsAndHashCode.Include private String source; + @EqualsAndHashCode.Include private String label; // 관계 라벨 + @EqualsAndHashCode.Include @TargetNode + @Relationship(type = "RELATED", direction = Relationship.Direction.INCOMING) + @JsonIgnore private GraphNode target; // 연결 대상 노드 @Override diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java index 15a7a08..630ef2a 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java @@ -1,10 +1,8 @@ package com.going.server.domain.graph.entity; import com.fasterxml.jackson.annotation.JsonIgnore; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; +import lombok.*; +import org.springframework.data.annotation.Transient; import org.springframework.data.neo4j.core.schema.*; import java.util.ArrayList; @@ -15,6 +13,8 @@ @Getter @Setter @Builder +@AllArgsConstructor +@NoArgsConstructor // Neo4j가 생성자 주입 대신 setter 기반으로 생성할 수 있도록 해줌 public class GraphNode { @Id @GeneratedValue @@ -29,6 +29,7 @@ public class GraphNode { private String includeSentence; //해당 노드(단어)가 포함된 문장 private String image; +// @Transient // Neo4j가 매핑하지 않음 @ToString.Exclude @JsonIgnore @Relationship(type = "RELATED", direction = Relationship.Direction.OUTGOING) diff --git a/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java b/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java index 4df135b..e6395d7 100644 --- a/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java +++ b/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java @@ -9,6 +9,7 @@ import java.util.Optional; public interface GraphNodeRepository extends Neo4jRepository { + Optional findByNodeId(Long nodeId); default GraphNode getByNode(Long nodeId) { return findByNodeId(nodeId).orElseThrow(NodeNotFoundException::new); diff --git a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java index fa59140..63c7d42 100644 --- a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java +++ b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java @@ -1,13 +1,40 @@ package com.going.server.domain.graph.repository; import com.going.server.domain.graph.entity.Graph; +import com.going.server.domain.graph.entity.GraphNode; import com.going.server.domain.graph.exception.GraphNotFoundException; import jdk.jfr.Registered; import org.springframework.data.neo4j.repository.Neo4jRepository; +import org.springframework.data.neo4j.repository.query.Query; +import org.springframework.data.repository.query.Param; + +import java.util.List; +import java.util.Optional; @Registered public interface GraphRepository extends Neo4jRepository { +// default Graph getByGraph(Long graphId) { +// return findById(graphId).orElseThrow(GraphNotFoundException::new); +// } + default Graph getByGraph(Long graphId) { - return findById(graphId).orElseThrow(GraphNotFoundException::new); + return findGraphWithEdgesByGraphId(graphId).orElseThrow(GraphNotFoundException::new); } + + // 그래프 + 노드 + 엣지까지 전부 fetch + @Query(""" + MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) + OPTIONAL MATCH (n)-[r:RELATED]->(m:GraphNode) + WITH g, collect(DISTINCT n) as nodes, collect(DISTINCT r) as rels, collect(DISTINCT m) as targets + RETURN g, nodes, rels, targets +""") + Optional findGraphWithEdgesByGraphId(Long graphId); + + + @Query(""" +MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) +OPTIONAL MATCH (n)-[r:RELATED]->(m:GraphNode) +RETURN g, collect(n), collect(r), collect(m) +""") + Graph findGraphWithEdges(Long graphId); } diff --git a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java index 730e9cd..ad23d26 100644 --- a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java +++ b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java @@ -44,6 +44,7 @@ public class GraphServiceImpl implements GraphService { @Override public GraphListDto getGraphList() { List graphs = graphRepository.findAll(); + List graphDtos = new ArrayList<>(); for (Graph graph : graphs) { @@ -78,13 +79,18 @@ public KnowledgeGraphDto getGraph(Long graphId) { List edgeDtoList = new ArrayList<>(); for (GraphNode node : graph.getNodes()) { - NodeDto nodeDto = NodeDto.from(node); - nodeDtoList.add(nodeDto); + nodeDtoList.add(NodeDto.from(node)); if (node.getEdges() != null) { for (GraphEdge edge : node.getEdges()) { - EdgeDto edgeDto = EdgeDto.from(edge.getSource(),edge.getTarget().getNodeId().toString(),edge.getLabel()); - edgeDtoList.add(edgeDto); + // 엣지 대상 노드도 제대로 fetch된 상태여야 함 + if (edge.getTarget() != null) { + edgeDtoList.add(EdgeDto.from( + edge.getSource(), + edge.getTarget().getNodeId().toString(), + edge.getLabel() + )); + } } } } @@ -92,6 +98,7 @@ public KnowledgeGraphDto getGraph(Long graphId) { return KnowledgeGraphDto.of(nodeDtoList, edgeDtoList); } + @Override public NodeDto getNode(Long graphId, Long nodeId) { Graph graph = graphRepository.getByGraph(graphId); From fa0e1175577e8dc4ada48b383ed6756c50e1a648 Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 13:34:01 +0900 Subject: [PATCH 8/9] =?UTF-8?q?=F0=9F=90=9B=20[fix]=20graph=20id=20?= =?UTF-8?q?=EC=B6=94=EA=B0=80=EB=A1=9C=20=EC=9D=B8=ED=95=9C=20=EC=A1=B0?= =?UTF-8?q?=ED=9A=8C=20=EB=A1=9C=EC=A7=81=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../graph/repository/GraphRepository.java | 21 +++++------ .../graph/service/GraphServiceImpl.java | 35 ++++++++++++------- .../upload/service/UploadServiceImpl.java | 2 ++ 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java index 63c7d42..e4ebb96 100644 --- a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java +++ b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java @@ -21,20 +21,17 @@ default Graph getByGraph(Long graphId) { return findGraphWithEdgesByGraphId(graphId).orElseThrow(GraphNotFoundException::new); } - // 그래프 + 노드 + 엣지까지 전부 fetch - @Query(""" - MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) - OPTIONAL MATCH (n)-[r:RELATED]->(m:GraphNode) - WITH g, collect(DISTINCT n) as nodes, collect(DISTINCT r) as rels, collect(DISTINCT m) as targets - RETURN g, nodes, rels, targets -""") - Optional findGraphWithEdgesByGraphId(Long graphId); - + @Query("MATCH (g:Graph) WHERE g.id = $graphId RETURN g") + Optional findByGraphId(@Param("graphId") Long graphId); + // 그래프 + 노드 + 엣지까지 전부 fetch @Query(""" MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) -OPTIONAL MATCH (n)-[r:RELATED]->(m:GraphNode) -RETURN g, collect(n), collect(r), collect(m) +OPTIONAL MATCH (n)-[r]->(m:GraphNode) +RETURN g, collect(DISTINCT n) as nodes, collect(DISTINCT r) as rels, collect(DISTINCT m) as targets """) - Graph findGraphWithEdges(Long graphId); + Optional findGraphWithEdgesByGraphId(@Param("graphId") Long graphId); + + @Query("MATCH (g:Graph) RETURN max(g.id)") + Long findMaxGraphId(); } diff --git a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java index ad23d26..27aeb17 100644 --- a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java +++ b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java @@ -74,24 +74,35 @@ public void deleteGraph(Long graphId) { @Override public KnowledgeGraphDto getGraph(Long graphId) { Graph graph = graphRepository.getByGraph(graphId); + log.info("[getGraph] 조회된 Graph ID: {}, Title: {}", graph.getId(), graph.getTitle()); List nodeDtoList = new ArrayList<>(); List edgeDtoList = new ArrayList<>(); - for (GraphNode node : graph.getNodes()) { + List nodes = graph.getNodes(); + if (nodes == null || nodes.isEmpty()) { + log.warn("[getGraph] 해당 그래프에는 노드가 없습니다."); + return KnowledgeGraphDto.of(nodeDtoList, edgeDtoList); + } + + for (GraphNode node : nodes) { + if (node == null) continue; + + log.debug("[getGraph] 노드 추가 - ID: {}, Label: {}", node.getNodeId(), node.getLabel()); nodeDtoList.add(NodeDto.from(node)); - if (node.getEdges() != null) { - for (GraphEdge edge : node.getEdges()) { - // 엣지 대상 노드도 제대로 fetch된 상태여야 함 - if (edge.getTarget() != null) { - edgeDtoList.add(EdgeDto.from( - edge.getSource(), - edge.getTarget().getNodeId().toString(), - edge.getLabel() - )); - } - } + Set edges = node.getEdges(); + if (edges == null || edges.isEmpty()) continue; + + for (GraphEdge edge : edges) { + if (edge == null || edge.getTarget() == null) continue; + + String sourceId = edge.getSource(); + String targetId = edge.getTarget().getNodeId().toString(); + String label = edge.getLabel(); + + log.debug("[getGraph] 엣지 추가 - Source: {}, Target: {}, Label: {}", sourceId, targetId, label); + edgeDtoList.add(EdgeDto.from(sourceId, targetId, label)); } } diff --git a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java index 9d56670..48b6c6a 100644 --- a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java +++ b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java @@ -144,7 +144,9 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { //그래프 생성 String title = dto.getTitle(); + Long nextGraphId = graphRepository.findMaxGraphId(); Graph graphEntity = Graph.builder() + .id(nextGraphId == null ? 1L : nextGraphId + 1) //id 직접 세팅 .title(title) .content(text) .listenUpPerfect(false) From 130ffed55a5b7c3b781bc97d857156c361e988b5 Mon Sep 17 00:00:00 2001 From: khyaejin Date: Wed, 28 May 2025 13:58:23 +0900 Subject: [PATCH 9/9] =?UTF-8?q?=F0=9F=90=9B=20[fix]=20=EC=A7=80=EC=8B=9D?= =?UTF-8?q?=EA=B7=B8=EB=9E=98=ED=94=84=20=EC=83=9D=EC=84=B1=20=EC=88=9C?= =?UTF-8?q?=ED=99=98=EC=B0=B8=EC=A1=B0=20=EC=98=A4=EB=A5=98=20=ED=95=B4?= =?UTF-8?q?=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 이제 긴 텍스트도 삽입 가능 --- .../chatbot/service/ChatbotServiceImpl.java | 27 +++++++++---------- .../server/domain/graph/entity/Graph.java | 8 +++--- .../server/domain/graph/entity/GraphEdge.java | 27 +++++++++---------- .../server/domain/graph/entity/GraphNode.java | 5 ++-- .../graph/repository/GraphRepository.java | 25 +++++++++-------- .../graph/service/GraphServiceImpl.java | 24 ++++++++++++----- .../domain/quiz/service/QuizServiceImpl.java | 11 +++----- .../domain/rag/service/GraphRAGService.java | 10 +++---- .../upload/service/UploadServiceImpl.java | 3 +-- 9 files changed, 70 insertions(+), 70 deletions(-) diff --git a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java index 715e783..dc91172 100644 --- a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java +++ b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java @@ -42,7 +42,8 @@ public class ChatbotServiceImpl implements ChatbotService { // 원문 반환 @Override public CreateChatbotResponseDto getOriginalText(String graphId) { - Graph graph = graphRepository.getByGraph(Long.valueOf(graphId)); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphId)); + Graph graph = graphRepository.getByGraph(dbId); return CreateChatbotResponseDto.builder() .chatContent(graph.getContent()) // 원문 텍스트 @@ -54,7 +55,8 @@ public CreateChatbotResponseDto getOriginalText(String graphId) { // 요약본 생성 @Override public CreateChatbotResponseDto getSummaryText(String graphId) { - Graph graph = graphRepository.getByGraph(Long.valueOf(graphId)); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphId)); + Graph graph = graphRepository.getByGraph(dbId); String context = Optional.ofNullable(graph.getContent()) .filter(s -> !s.trim().isEmpty()) @@ -72,12 +74,11 @@ public CreateChatbotResponseDto getSummaryText(String graphId) { // GraphRAG 챗봇 응답 생성 @Override public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateChatbotRequestDto requestDto) { - Long graphId = Long.valueOf(graphStrId); - // 404 : 지식그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphStrId)); + Graph graph = graphRepository.getByGraph(dbId); if (requestDto.isNewChat()) { - deletePreviousChat(graphId); + deletePreviousChat(dbId); } Chatting userChat = Chatting.builder() @@ -88,11 +89,11 @@ public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateCha .build(); chattingRepository.save(userChat); - List chatHistory = chattingRepository.findAllByGraphId(graphId); + List chatHistory = chattingRepository.findAllByGraphId(dbId); // RAG 응답 생성 (응답 + 메타 포함) CreateChatbotResponseDto responseDto = graphRAGService.createAnswerWithGraphRAG( - graphId, + dbId, requestDto.getChatContent(), chatHistory ); @@ -107,18 +108,16 @@ public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateCha // 기본 응답 생성 @Override public CreateChatbotResponseDto createSimpleAnswer(String graphStrId, CreateChatbotRequestDto createChatbotRequestDto) { - Long graphId = Long.valueOf(graphStrId); - - // 404 : 지식그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphStrId)); + Graph graph = graphRepository.getByGraph(dbId); // 새로운 대화인 경우 기존 채팅 삭제 if (createChatbotRequestDto.isNewChat()) { - deletePreviousChat(graphId); + deletePreviousChat(dbId); } // 기존 채팅 내역 조회 - List chatHistory = chattingRepository.findAllByGraphId(graphId); + List chatHistory = chattingRepository.findAllByGraphId(dbId); // 사용자 입력 채팅 String newChat = createChatbotRequestDto.getChatContent(); diff --git a/src/main/java/com/going/server/domain/graph/entity/Graph.java b/src/main/java/com/going/server/domain/graph/entity/Graph.java index f36c93c..0a8f974 100644 --- a/src/main/java/com/going/server/domain/graph/entity/Graph.java +++ b/src/main/java/com/going/server/domain/graph/entity/Graph.java @@ -6,6 +6,7 @@ import lombok.Setter; import org.springframework.data.neo4j.core.schema.*; +import java.util.ArrayList; import java.util.List; @Node("Graph") @@ -29,11 +30,8 @@ public class Graph extends BaseEntity { private boolean connectPerfect; //connect 퀴즈 만접 여부 private boolean picturePerfect; //picture 퀴즈 만접 여부 + @Builder.Default @Relationship(type = "HAS_NODE", direction = Relationship.Direction.OUTGOING) - private List nodes; + private List nodes = new ArrayList<>(); - // Long → String 변환 (프론트 전송 시) - public String getIdAsString() { - return id != null ? String.valueOf(id) : null; - } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java index c6ef65e..4dec151 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java @@ -1,13 +1,14 @@ package com.going.server.domain.graph.entity; -import com.fasterxml.jackson.annotation.JsonIgnore; import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; -import org.springframework.data.neo4j.core.schema.*; +import org.springframework.data.neo4j.core.schema.GeneratedValue; +import org.springframework.data.neo4j.core.schema.Id; +import org.springframework.data.neo4j.core.schema.RelationshipProperties; +import org.springframework.data.neo4j.core.schema.TargetNode; -import java.util.ArrayList; import java.util.Objects; @RelationshipProperties @@ -18,34 +19,30 @@ public class GraphEdge { @Id @GeneratedValue - private Long id; // Neo4j 내부 ID + private Long id; @EqualsAndHashCode.Include private String source; @EqualsAndHashCode.Include - private String label; // 관계 라벨 + private String label; @EqualsAndHashCode.Include @TargetNode - @Relationship(type = "RELATED", direction = Relationship.Direction.INCOMING) - @JsonIgnore - private GraphNode target; // 연결 대상 노드 + private GraphNode target; // 여기에 방향 붙이지 마세요 @Override public boolean equals(Object o) { if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - GraphEdge edge = (GraphEdge) o; - return Objects.equals(source, edge.source) - && Objects.equals(label, edge.label) - && edge.target != null && target != null - && Objects.equals(target.getNodeId(), edge.target.getNodeId()); + if (!(o instanceof GraphEdge edge)) return false; + return Objects.equals(source, edge.source) && + Objects.equals(label, edge.label) && + target != null && edge.target != null && + Objects.equals(target.getNodeId(), edge.target.getNodeId()); } @Override public int hashCode() { return Objects.hash(source, label, target != null ? target.getNodeId() : null); } - } diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java index 630ef2a..add5550 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java @@ -6,6 +6,7 @@ import org.springframework.data.neo4j.core.schema.*; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Set; @@ -29,9 +30,9 @@ public class GraphNode { private String includeSentence; //해당 노드(단어)가 포함된 문장 private String image; -// @Transient // Neo4j가 매핑하지 않음 + @Builder.Default @ToString.Exclude @JsonIgnore @Relationship(type = "RELATED", direction = Relationship.Direction.OUTGOING) - private Set edges; + private Set edges = new HashSet<>(); } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java index e4ebb96..f9a2199 100644 --- a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java +++ b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java @@ -13,24 +13,23 @@ @Registered public interface GraphRepository extends Neo4jRepository { + default Graph getByGraph(Long graphId) { + return findById(graphId).orElseThrow(GraphNotFoundException::new); + } // default Graph getByGraph(Long graphId) { -// return findById(graphId).orElseThrow(GraphNotFoundException::new); +// return findGraphWithEdgesByGraphId(graphId).orElseThrow(GraphNotFoundException::new); // } - default Graph getByGraph(Long graphId) { - return findGraphWithEdgesByGraphId(graphId).orElseThrow(GraphNotFoundException::new); - } +// @Query(""" +//MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) +//OPTIONAL MATCH (n)-[r:RELATED]->(m:GraphNode) +//RETURN g, collect(DISTINCT n) as nodes, collect(DISTINCT r) as rels +//""") +// Optional findGraphWithEdgesByGraphId(@Param("graphId") Long graphId); - @Query("MATCH (g:Graph) WHERE g.id = $graphId RETURN g") - Optional findByGraphId(@Param("graphId") Long graphId); - // 그래프 + 노드 + 엣지까지 전부 fetch - @Query(""" -MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) -OPTIONAL MATCH (n)-[r]->(m:GraphNode) -RETURN g, collect(DISTINCT n) as nodes, collect(DISTINCT r) as rels, collect(DISTINCT m) as targets -""") - Optional findGraphWithEdgesByGraphId(@Param("graphId") Long graphId); + @Query("MATCH (g:Graph {id: $graphId}) RETURN id(g)") + Long findDbIdByGraphId(@Param("graphId") Long graphId); @Query("MATCH (g:Graph) RETURN max(g.id)") Long findMaxGraphId(); diff --git a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java index 27aeb17..1d54cc0 100644 --- a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java +++ b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java @@ -63,7 +63,9 @@ public GraphListDto getGraphList() { @Override public void deleteGraph(Long graphId) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphId)); + Graph graph = graphRepository.getByGraph(dbId); + //그래프에 연결된 노드 삭제 if (graph.getNodes() != null) { graph.getNodes().forEach(node -> graphNodeRepository.deleteById(node.getId())); @@ -71,9 +73,11 @@ public void deleteGraph(Long graphId) { graphRepository.deleteById(graph.getId()); } + @Transactional(readOnly = true) @Override public KnowledgeGraphDto getGraph(Long graphId) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); log.info("[getGraph] 조회된 Graph ID: {}, Title: {}", graph.getId(), graph.getTitle()); List nodeDtoList = new ArrayList<>(); @@ -112,7 +116,9 @@ public KnowledgeGraphDto getGraph(Long graphId) { @Override public NodeDto getNode(Long graphId, Long nodeId) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); + //노드 찾기 GraphNode node = null; for (GraphNode n : graph.getNodes()) { @@ -132,7 +138,9 @@ public NodeDto getNode(Long graphId, Long nodeId) { @Override @Transactional public void addNode(Long graphId, NodeAddDto nodeAddDto) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); + GraphNode parentNode = null; for (GraphNode node : graph.getNodes()) { if (node.getNodeId().equals(Long.parseLong(nodeAddDto.getParentId()))) { @@ -168,7 +176,7 @@ public void addNode(Long graphId, NodeAddDto nodeAddDto) { parentNode.getEdges().add(newEdge); graphNodeRepository.save(parentNode); - graph = graphRepository.getByGraph(graphId); + graph = graphRepository.getByGraph(dbId); if (graph.getNodes() == null) { graph.setNodes(new ArrayList<>()); @@ -181,7 +189,8 @@ public void addNode(Long graphId, NodeAddDto nodeAddDto) { @Override public void deleteNode(Long graphId, Long nodeId) { //그래프 검증 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); GraphNode node = null; for (GraphNode n : graph.getNodes()) { if (n.getNodeId().equals(nodeId)) { @@ -199,7 +208,8 @@ public void deleteNode(Long graphId, Long nodeId) { @Override public void modifyNode(Long graphId, Long nodeId, NodeModifyDto nodeModifyDto) { //그래프 검증 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); //노드 찾기 GraphNode node = null; for (GraphNode n : graph.getNodes()) { diff --git a/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java b/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java index 02dee5d..aa3e754 100644 --- a/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java +++ b/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java @@ -23,10 +23,9 @@ public class QuizServiceImpl implements QuizService{ // 모드 별 퀴즈 생성 @Override public QuizCreateResponseDto quizCreate(String graphIdStr, String mode) { - Long graphId = Long.valueOf(graphIdStr); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphIdStr)); + Graph graph = graphRepository.getByGraph(dbId); - // 404 : 지식그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); Object quizDto = switch (mode) { case "listenUp" -> listenUpQuizGenerator.generate(graph); @@ -41,10 +40,8 @@ public QuizCreateResponseDto quizCreate(String graphIdStr, String mode) { // 만점일 경우 Graph Quiz 정보 업데이트 @Override public void updateIfPerfect(String graphIdStr, String mode) { - Long graphId = Long.valueOf(graphIdStr); - - // 404 : 지식 그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphIdStr)); + Graph graph = graphRepository.getByGraph(dbId); switch (mode){ case "listenUp": diff --git a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java index 84561d8..61c685a 100644 --- a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java +++ b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java @@ -33,12 +33,12 @@ public class GraphRAGService { * 본 메서드는 LangChain 없이 구현한 Spring 기반 GraphRAG의 핵심 흐름입니다. */ public CreateChatbotResponseDto createAnswerWithGraphRAG( - Long graphId, + Long dbId, String userQuestion, List chatHistory ) { - Graph graph = graphRepository.getByGraph(graphId); - log.info("[GraphRAG] graphId: {}, question: {}", graphId, userQuestion); + Graph graph = graphRepository.getByGraph(dbId); + log.info("[GraphRAG] dbId: {}, question: {}", dbId, userQuestion); // 1. 질문 → Cypher 쿼리 생성 String cypherQuery = cypherQueryGenerator.generate(userQuestion).trim() @@ -48,7 +48,7 @@ public CreateChatbotResponseDto createAnswerWithGraphRAG( log.info("[GraphRAG] Generated Cypher Query:\n{}", cypherQuery); // 2. 쿼리 실행 → 문맥(context) 및 노드 라벨 추출 - List queryResults = graphQueryExecutor.runQuery(graphId, cypherQuery); + List queryResults = graphQueryExecutor.runQuery(dbId, cypherQuery); List contextChunks = queryResults.stream() .map(GraphQueryResult::getSentence) .toList(); @@ -76,7 +76,7 @@ public CreateChatbotResponseDto createAnswerWithGraphRAG( return CreateChatbotResponseDto.of( response, - graphId.toString(), + dbId.toString(), answer.getCreatedAt(), contextChunks, sourceNodes diff --git a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java index 48b6c6a..c52521e 100644 --- a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java +++ b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java @@ -58,8 +58,7 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { Map paresData = pdfOcrService.parse(jsonResponse); String text = paresData.get("읽기자료"); log.info("text log={}",text); - - //모델에 돌린 값을 받아옴 + //모델에 돌린 값을 받아옴 String response = setModelData(text); ObjectMapper mapper = new ObjectMapper();