diff --git a/Backend_Config b/Backend_Config index 3c9dbc1..d2b85ef 160000 --- a/Backend_Config +++ b/Backend_Config @@ -1 +1 @@ -Subproject commit 3c9dbc1b7a76e502834f0986874e74e00e5682eb +Subproject commit d2b85ef76f4279d58f4f002c2cf06c467e2af103 diff --git a/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java b/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java index 5864666..9e43ebe 100644 --- a/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java +++ b/src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java @@ -11,11 +11,28 @@ @Getter @Builder public class CreateChatbotResponseDto { - private String chatContent; // 챗봇 응답 - private String graphId; // 지식그래프 ID + private String chatContent; + private String graphId; @JsonFormat(pattern = "yyyy-MM-dd'T'HH:mm") - private LocalDateTime createdAt; // 응답 생성 시각 - private List retrievedChunks; // RAG: 검색된 문장들 - private List sourceNodes; // RAG: 참조된 지식그래프 노드 ID - private Map ragMeta; // RAG: 점수, 검색 method 등 + private LocalDateTime createdAt; + private List retrievedChunks; + private List sourceNodes; + private Map ragMeta; + + public static CreateChatbotResponseDto of( + String chatContent, + String graphId, + LocalDateTime createdAt, + List retrievedChunks, + List sourceNodes + ) { + return CreateChatbotResponseDto.builder() + .chatContent(chatContent) + .graphId(graphId) + .createdAt(createdAt) + .retrievedChunks(retrievedChunks) + .sourceNodes(sourceNodes) + .ragMeta(Map.of("chunkCount", String.valueOf(retrievedChunks.size()))) + .build(); + } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java b/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java index 2bd6596..510ffa9 100644 --- a/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java +++ b/src/main/java/com/going/server/domain/chatbot/entity/Chatting.java @@ -26,4 +26,22 @@ public class Chatting { private Sender sender; private LocalDateTime createdAt; + + public static Chatting ofUser(Graph graph, String content) { + return Chatting.builder() + .graph(graph) + .content(content) + .sender(Sender.USER) + .createdAt(LocalDateTime.now()) + .build(); + } + + public static Chatting ofGPT(Graph graph, String content) { + return Chatting.builder() + .graph(graph) + .content(content) + .sender(Sender.GPT) + .createdAt(LocalDateTime.now()) + .build(); + } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java index 5646ee8..dc91172 100644 --- a/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java +++ b/src/main/java/com/going/server/domain/chatbot/service/ChatbotServiceImpl.java @@ -6,15 +6,14 @@ import com.going.server.domain.chatbot.entity.Sender; import com.going.server.domain.chatbot.repository.ChattingRepository; import com.going.server.domain.graph.entity.Graph; -import com.going.server.domain.graph.entity.GraphNode; import com.going.server.domain.graph.exception.GraphContentNotFoundException; import com.going.server.domain.graph.repository.GraphRepository; import com.going.server.domain.graph.repository.GraphNodeRepository; import com.going.server.domain.openai.dto.ImageCreateRequestDto; import com.going.server.domain.openai.service.ImageCreateService; -import com.going.server.domain.openai.service.RAGAnswerCreateService; import com.going.server.domain.openai.service.SimpleAnswerCreateService; import com.going.server.domain.openai.service.TextSummaryCreateService; +import com.going.server.domain.rag.service.GraphRAGService; import com.going.server.domain.rag.service.SimilarityFilterService; import com.going.server.domain.rag.util.PromptBuilder; import lombok.RequiredArgsConstructor; @@ -24,8 +23,6 @@ import java.time.LocalDateTime; import java.util.*; -import java.util.stream.Collectors; - @Service @RequiredArgsConstructor @Transactional @@ -38,13 +35,15 @@ public class ChatbotServiceImpl implements ChatbotService { // openai 관련 service private final TextSummaryCreateService textSummaryCreateService; private final SimpleAnswerCreateService simpleAnswerCreateService; - private final RAGAnswerCreateService ragAnswerCreateService; private final ImageCreateService imageCreateService; + // graphRAG + private final GraphRAGService graphRAGService; // 원문 반환 @Override public CreateChatbotResponseDto getOriginalText(String graphId) { - Graph graph = graphRepository.getByGraph(Long.valueOf(graphId)); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphId)); + Graph graph = graphRepository.getByGraph(dbId); return CreateChatbotResponseDto.builder() .chatContent(graph.getContent()) // 원문 텍스트 @@ -56,7 +55,8 @@ public CreateChatbotResponseDto getOriginalText(String graphId) { // 요약본 생성 @Override public CreateChatbotResponseDto getSummaryText(String graphId) { - Graph graph = graphRepository.getByGraph(Long.valueOf(graphId)); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphId)); + Graph graph = graphRepository.getByGraph(dbId); String context = Optional.ofNullable(graph.getContent()) .filter(s -> !s.trim().isEmpty()) @@ -71,120 +71,53 @@ public CreateChatbotResponseDto getSummaryText(String graphId) { .build(); } - - // RAG 챗봇 응답 생성 + // GraphRAG 챗봇 응답 생성 @Override - public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateChatbotRequestDto createChatbotRequestDto) { - Long graphId = Long.valueOf(graphStrId); - - // 404 : 지식그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); - - // RAG: 사용자 질문 - String userQuestion = createChatbotRequestDto.getChatContent(); - - // RAG : 키워드 추출 - List keywords = extractKeywords(userQuestion); - System.out.println("[RAG] 추출된 키워드: " + keywords); - - // RAG: 유사 노드 검색 및 문장 추출 - List matchedNodes = graphNodeRepository.findByGraphIdAndKeywords(graphId, keywords); -// List matchedNodes = graphNodeRepository.findByGraphIdAndKeywordsWithEdges(graphId, keywords); - System.out.println("[RAG] matchedNodes: " + matchedNodes); - - List candidateSentences = matchedNodes.stream() - .map(GraphNode::getIncludeSentence) - .filter(Objects::nonNull) - .distinct() - .collect(Collectors.toList()); - - // RAG: 유사 문장 필터링 - List filteredChunks = similarityFilterService.filterRelevantSentences(userQuestion, candidateSentences); - - // RAG: 최종 프롬프트 구성 - String finalPrompt = promptBuilder.buildPrompt(filteredChunks, userQuestion); - System.out.println("finalPrompt: " + finalPrompt); - - // RAG: 메타정보 수집 - List retrievedChunks = new ArrayList<>(filteredChunks); - List sourceNodes = new ArrayList<>( - matchedNodes.stream().map(GraphNode::getLabel).distinct().toList() - ); - Map ragMeta = Map.of( - "chunkCount", String.valueOf(filteredChunks.size()) - ); + public CreateChatbotResponseDto createAnswerWithRAG(String graphStrId, CreateChatbotRequestDto requestDto) { + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphStrId)); + Graph graph = graphRepository.getByGraph(dbId); - // 새로운 대화인 경우 기존 채팅 삭제 - if (createChatbotRequestDto.isNewChat()) { - deletePreviousChat(graphId); + if (requestDto.isNewChat()) { + deletePreviousChat(dbId); } - // 기존 채팅 내역 조회 - List chatHistory = chattingRepository.findAllByGraphId(graphId); - - // 새로운 채팅 - String newChat = createChatbotRequestDto.getChatContent(); - - // 새로운 채팅 repository에 저장 - Chatting chatting = Chatting.builder() + Chatting userChat = Chatting.builder() .graph(graph) - .content(newChat) + .content(requestDto.getChatContent()) .sender(Sender.USER) .createdAt(LocalDateTime.now()) .build(); - chattingRepository.save(chatting); - - // 응답 생성 - String chatContent; - - // RAG: 유사 문장이 있을 경우 컨텍스트 활용 - if (retrievedChunks.isEmpty()) { - System.out.println("[INFO] RAG 미적용 - 일반 채팅 기반 응답"); - System.out.println("[INFO] RAG 미적용 - 유사 문장 없음"); - System.out.println("[DEBUG] matchedNodes.size(): " + matchedNodes.size()); - System.out.println("[DEBUG] candidateSentences.size(): " + candidateSentences.size()); - System.out.println("[DEBUG] filteredChunks.size(): " + filteredChunks.size()); - chatContent = ragAnswerCreateService.chat(chatHistory, newChat); - } else { - System.out.println("[INFO] RAG 적용됨 - 유사 문장 " + retrievedChunks.size() + "개 포함"); - chatContent = ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt); - } + chattingRepository.save(userChat); - // 응답 repository에 저장 - Chatting answer = Chatting.builder() - .graph(graph) - .content(chatContent) - .sender(Sender.GPT) - .createdAt(LocalDateTime.now()) - .build(); - chattingRepository.save(answer); + List chatHistory = chattingRepository.findAllByGraphId(dbId); - // 반환 - return CreateChatbotResponseDto.builder() - .chatContent(chatContent) - .graphId(graphStrId) - .createdAt(answer.getCreatedAt()) - .retrievedChunks(retrievedChunks) - .sourceNodes(sourceNodes) - .ragMeta(ragMeta) - .build(); + // RAG 응답 생성 (응답 + 메타 포함) + CreateChatbotResponseDto responseDto = graphRAGService.createAnswerWithGraphRAG( + dbId, + requestDto.getChatContent(), + chatHistory + ); + + // 응답 채팅 저장 + Chatting gptChat = Chatting.ofGPT(graph, responseDto.getChatContent()); + chattingRepository.save(gptChat); + + return responseDto; } - // RAG 사용하지 않는 응답 생성 + // 기본 응답 생성 @Override public CreateChatbotResponseDto createSimpleAnswer(String graphStrId, CreateChatbotRequestDto createChatbotRequestDto) { - Long graphId = Long.valueOf(graphStrId); - - // 404 : 지식그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphStrId)); + Graph graph = graphRepository.getByGraph(dbId); // 새로운 대화인 경우 기존 채팅 삭제 if (createChatbotRequestDto.isNewChat()) { - deletePreviousChat(graphId); + deletePreviousChat(dbId); } // 기존 채팅 내역 조회 - List chatHistory = chattingRepository.findAllByGraphId(graphId); + List chatHistory = chattingRepository.findAllByGraphId(dbId); // 사용자 입력 채팅 String newChat = createChatbotRequestDto.getChatContent(); @@ -293,17 +226,4 @@ private void deletePreviousChat(Long graphId) { chattingRepository.deleteByGraphId(graphId); } - - // RAG : 키워드 추출 - private List extractKeywords(String text) { - List stopwords = List.of("은", "는", "이", "가", "을", "를", "에", "의", "와", "과", "에서", "하다"); - - return Arrays.stream(text.split("[\\s,.!?]+")) - .map(word -> word.replaceAll("(은|는|이|가|을|를|에|의|와|과|에서)$", "")) // ✅ 조사 제거 - .map(String::toLowerCase) - .filter(word -> word.length() > 1 && !stopwords.contains(word)) - .distinct() - .limit(5) - .collect(Collectors.toList()); - } } diff --git a/src/main/java/com/going/server/domain/graph/entity/Graph.java b/src/main/java/com/going/server/domain/graph/entity/Graph.java index b2c9fcb..0a8f974 100644 --- a/src/main/java/com/going/server/domain/graph/entity/Graph.java +++ b/src/main/java/com/going/server/domain/graph/entity/Graph.java @@ -6,6 +6,7 @@ import lombok.Setter; import org.springframework.data.neo4j.core.schema.*; +import java.util.ArrayList; import java.util.List; @Node("Graph") @@ -15,7 +16,10 @@ public class Graph extends BaseEntity { @Id @GeneratedValue - private Long id; //그래프 id -> 프론트와 통신에서는 String 값으로 사용 + private Long dbId; // 내부 관리용 elementId와 연결됨 + + @Property("id") + private Long id; // 우리가 직접 사용하는 명시적 ID private String title; @@ -26,11 +30,8 @@ public class Graph extends BaseEntity { private boolean connectPerfect; //connect 퀴즈 만접 여부 private boolean picturePerfect; //picture 퀴즈 만접 여부 + @Builder.Default @Relationship(type = "HAS_NODE", direction = Relationship.Direction.OUTGOING) - private List nodes; + private List nodes = new ArrayList<>(); - // Long → String 변환 (프론트 전송 시) - public String getIdAsString() { - return id != null ? String.valueOf(id) : null; - } } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java index 22f01e9..4dec151 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphEdge.java @@ -1,11 +1,15 @@ package com.going.server.domain.graph.entity; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; -import org.springframework.data.neo4j.core.schema.*; +import org.springframework.data.neo4j.core.schema.GeneratedValue; +import org.springframework.data.neo4j.core.schema.Id; +import org.springframework.data.neo4j.core.schema.RelationshipProperties; +import org.springframework.data.neo4j.core.schema.TargetNode; -import java.util.ArrayList; +import java.util.Objects; @RelationshipProperties @Getter @@ -15,32 +19,30 @@ public class GraphEdge { @Id @GeneratedValue - private Long id; // Neo4j 내부 ID + private Long id; + @EqualsAndHashCode.Include private String source; - private String label; // 관계 라벨 + @EqualsAndHashCode.Include + private String label; + @EqualsAndHashCode.Include @TargetNode - private GraphNode target; // 연결 대상 노드 - -// private GraphEdge createEdge(Long edgeId, String label, GraphNode source, GraphNode target) { -// GraphEdge edge = new GraphEdge(); -// edge.setEdgeId(edgeId); -// edge.setLabel(label); -// edge.setTarget(target); -// -// // outbound edge를 source에 연결 -// if (source.getEdges() == null) { -// source.setEdges(new ArrayList<>()); -// } -// source.getEdges().add(edge); -// return edge; -// } - - // Long → String 변환 (프론트 전송 시) - public String getIdAsString() { - return id != null ? String.valueOf(id) : null; + private GraphNode target; // 여기에 방향 붙이지 마세요 + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof GraphEdge edge)) return false; + return Objects.equals(source, edge.source) && + Objects.equals(label, edge.label) && + target != null && edge.target != null && + Objects.equals(target.getNodeId(), edge.target.getNodeId()); } + @Override + public int hashCode() { + return Objects.hash(source, label, target != null ? target.getNodeId() : null); + } } diff --git a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java index 33e19eb..add5550 100644 --- a/src/main/java/com/going/server/domain/graph/entity/GraphNode.java +++ b/src/main/java/com/going/server/domain/graph/entity/GraphNode.java @@ -1,12 +1,12 @@ package com.going.server.domain.graph.entity; import com.fasterxml.jackson.annotation.JsonIgnore; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import lombok.*; +import org.springframework.data.annotation.Transient; import org.springframework.data.neo4j.core.schema.*; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Set; @@ -14,6 +14,8 @@ @Getter @Setter @Builder +@AllArgsConstructor +@NoArgsConstructor // Neo4j가 생성자 주입 대신 setter 기반으로 생성할 수 있도록 해줌 public class GraphNode { @Id @GeneratedValue @@ -28,13 +30,9 @@ public class GraphNode { private String includeSentence; //해당 노드(단어)가 포함된 문장 private String image; -// @Relationship(type = "HAS_GRAPH", direction = Relationship.Direction.INCOMING) -// private Graph graph; - + @Builder.Default + @ToString.Exclude + @JsonIgnore @Relationship(type = "RELATED", direction = Relationship.Direction.OUTGOING) - private Set edges; - - public String getIdAsString() { - return id != null ? String.valueOf(id) : null; - } + private Set edges = new HashSet<>(); } \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java b/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java index 4df135b..e6395d7 100644 --- a/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java +++ b/src/main/java/com/going/server/domain/graph/repository/GraphNodeRepository.java @@ -9,6 +9,7 @@ import java.util.Optional; public interface GraphNodeRepository extends Neo4jRepository { + Optional findByNodeId(Long nodeId); default GraphNode getByNode(Long nodeId) { return findByNodeId(nodeId).orElseThrow(NodeNotFoundException::new); diff --git a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java index fa59140..f9a2199 100644 --- a/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java +++ b/src/main/java/com/going/server/domain/graph/repository/GraphRepository.java @@ -1,13 +1,36 @@ package com.going.server.domain.graph.repository; import com.going.server.domain.graph.entity.Graph; +import com.going.server.domain.graph.entity.GraphNode; import com.going.server.domain.graph.exception.GraphNotFoundException; import jdk.jfr.Registered; import org.springframework.data.neo4j.repository.Neo4jRepository; +import org.springframework.data.neo4j.repository.query.Query; +import org.springframework.data.repository.query.Param; + +import java.util.List; +import java.util.Optional; @Registered public interface GraphRepository extends Neo4jRepository { default Graph getByGraph(Long graphId) { return findById(graphId).orElseThrow(GraphNotFoundException::new); } +// default Graph getByGraph(Long graphId) { +// return findGraphWithEdgesByGraphId(graphId).orElseThrow(GraphNotFoundException::new); +// } + +// @Query(""" +//MATCH (g:Graph {id: $graphId})-[:HAS_NODE]->(n:GraphNode) +//OPTIONAL MATCH (n)-[r:RELATED]->(m:GraphNode) +//RETURN g, collect(DISTINCT n) as nodes, collect(DISTINCT r) as rels +//""") +// Optional findGraphWithEdgesByGraphId(@Param("graphId") Long graphId); + + + @Query("MATCH (g:Graph {id: $graphId}) RETURN id(g)") + Long findDbIdByGraphId(@Param("graphId") Long graphId); + + @Query("MATCH (g:Graph) RETURN max(g.id)") + Long findMaxGraphId(); } diff --git a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java index 730e9cd..1d54cc0 100644 --- a/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java +++ b/src/main/java/com/going/server/domain/graph/service/GraphServiceImpl.java @@ -44,6 +44,7 @@ public class GraphServiceImpl implements GraphService { @Override public GraphListDto getGraphList() { List graphs = graphRepository.findAll(); + List graphDtos = new ArrayList<>(); for (Graph graph : graphs) { @@ -62,7 +63,9 @@ public GraphListDto getGraphList() { @Override public void deleteGraph(Long graphId) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphId)); + Graph graph = graphRepository.getByGraph(dbId); + //그래프에 연결된 노드 삭제 if (graph.getNodes() != null) { graph.getNodes().forEach(node -> graphNodeRepository.deleteById(node.getId())); @@ -70,31 +73,52 @@ public void deleteGraph(Long graphId) { graphRepository.deleteById(graph.getId()); } + @Transactional(readOnly = true) @Override public KnowledgeGraphDto getGraph(Long graphId) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); + log.info("[getGraph] 조회된 Graph ID: {}, Title: {}", graph.getId(), graph.getTitle()); List nodeDtoList = new ArrayList<>(); List edgeDtoList = new ArrayList<>(); - for (GraphNode node : graph.getNodes()) { - NodeDto nodeDto = NodeDto.from(node); - nodeDtoList.add(nodeDto); - - if (node.getEdges() != null) { - for (GraphEdge edge : node.getEdges()) { - EdgeDto edgeDto = EdgeDto.from(edge.getSource(),edge.getTarget().getNodeId().toString(),edge.getLabel()); - edgeDtoList.add(edgeDto); - } + List nodes = graph.getNodes(); + if (nodes == null || nodes.isEmpty()) { + log.warn("[getGraph] 해당 그래프에는 노드가 없습니다."); + return KnowledgeGraphDto.of(nodeDtoList, edgeDtoList); + } + + for (GraphNode node : nodes) { + if (node == null) continue; + + log.debug("[getGraph] 노드 추가 - ID: {}, Label: {}", node.getNodeId(), node.getLabel()); + nodeDtoList.add(NodeDto.from(node)); + + Set edges = node.getEdges(); + if (edges == null || edges.isEmpty()) continue; + + for (GraphEdge edge : edges) { + if (edge == null || edge.getTarget() == null) continue; + + String sourceId = edge.getSource(); + String targetId = edge.getTarget().getNodeId().toString(); + String label = edge.getLabel(); + + log.debug("[getGraph] 엣지 추가 - Source: {}, Target: {}, Label: {}", sourceId, targetId, label); + edgeDtoList.add(EdgeDto.from(sourceId, targetId, label)); } } return KnowledgeGraphDto.of(nodeDtoList, edgeDtoList); } + @Override public NodeDto getNode(Long graphId, Long nodeId) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); + //노드 찾기 GraphNode node = null; for (GraphNode n : graph.getNodes()) { @@ -114,7 +138,9 @@ public NodeDto getNode(Long graphId, Long nodeId) { @Override @Transactional public void addNode(Long graphId, NodeAddDto nodeAddDto) { - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); + GraphNode parentNode = null; for (GraphNode node : graph.getNodes()) { if (node.getNodeId().equals(Long.parseLong(nodeAddDto.getParentId()))) { @@ -150,7 +176,7 @@ public void addNode(Long graphId, NodeAddDto nodeAddDto) { parentNode.getEdges().add(newEdge); graphNodeRepository.save(parentNode); - graph = graphRepository.getByGraph(graphId); + graph = graphRepository.getByGraph(dbId); if (graph.getNodes() == null) { graph.setNodes(new ArrayList<>()); @@ -163,7 +189,8 @@ public void addNode(Long graphId, NodeAddDto nodeAddDto) { @Override public void deleteNode(Long graphId, Long nodeId) { //그래프 검증 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); GraphNode node = null; for (GraphNode n : graph.getNodes()) { if (n.getNodeId().equals(nodeId)) { @@ -181,7 +208,8 @@ public void deleteNode(Long graphId, Long nodeId) { @Override public void modifyNode(Long graphId, Long nodeId, NodeModifyDto nodeModifyDto) { //그래프 검증 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(graphId); + Graph graph = graphRepository.getByGraph(dbId); //노드 찾기 GraphNode node = null; for (GraphNode n : graph.getNodes()) { diff --git a/src/main/java/com/going/server/domain/openai/service/OpenAIService.java b/src/main/java/com/going/server/domain/openai/service/OpenAIService.java new file mode 100644 index 0000000..ada3955 --- /dev/null +++ b/src/main/java/com/going/server/domain/openai/service/OpenAIService.java @@ -0,0 +1,31 @@ +package com.going.server.domain.openai.service; + +import com.going.server.domain.openai.dto.ChatCompletionRequestDto; +import com.theokanning.openai.OpenAiService; +import com.theokanning.openai.completion.chat.ChatMessage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +@RequiredArgsConstructor +public class OpenAIService { + + private final OpenAiService openAiService; + + public String getCompletionResponse(List messages, String model, double temperature, int maxTokens) { + ChatCompletionRequestDto request = ChatCompletionRequestDto.builder() + .model(model) + .temperature(temperature) + .maxTokens(maxTokens) + .messages(messages) + .build(); + + return openAiService.createChatCompletion(request.toRequest()) + .getChoices() + .get(0) + .getMessage() + .getContent(); + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java b/src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java deleted file mode 100644 index 8afd16d..0000000 --- a/src/main/java/com/going/server/domain/openai/service/RAGAnswerCreateService.java +++ /dev/null @@ -1,99 +0,0 @@ -package com.going.server.domain.openai.service; - -import com.going.server.domain.chatbot.entity.Chatting; -import com.going.server.domain.chatbot.entity.Sender; -import com.going.server.domain.openai.dto.ChatCompletionRequestDto; -import com.theokanning.openai.OpenAiService; -import com.theokanning.openai.completion.chat.ChatMessage; -import lombok.RequiredArgsConstructor; -import org.springframework.stereotype.Service; - -import java.util.ArrayList; -import java.util.List; - -@Service -@RequiredArgsConstructor -public class RAGAnswerCreateService { - private final OpenAiService openAiService; - - // 시스템 역할 설정 - private static final String SYSTEM_PROMPT = """ - 당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다. - - 아래 제공된 데이터를 기반으로 질문에 대해 매우 길고 정확하게 설명해주세요. - - 만약 참고 데이터가 없다면, 교육 도메인의 일반적인 지식을 바탕으로 충실하게 답변해주세요. - - 반드시 한글로만 응답하고, 인사말이나 불필요한 문장은 생략한 대답만 반환하세요. - """; - - // 모델 스펙 정의 - private static final String MODEL_NAME = "gpt-4-0125-preview"; -// private static final String MODEL_NAME = "gpt-4o"; - - private static final double TEMPERATURE = 0.3 ; - private static final int MAX_TOKENS = 3000; - - // 기존 채팅 이력을 기반으로 GPT 응답 생성 - public String chat(List chatHistory, String question) { - - // 메세지 구성 - List messages = new ArrayList<>(); - messages.add(new ChatMessage("system", SYSTEM_PROMPT)); // 프롬프트 설정 -// messages.addAll(convertHistoryToMessages(chatHistory)); // 기존 채팅 - messages.add(new ChatMessage("user", question)); // 새로운 질문 - - // DTO 기반 요청 생성 - ChatCompletionRequestDto request = buildRequest(messages); - - // OpenAI 모델에게 질문 및 응답 생성 - return getResponseText(request); - } - - // RAG 컨텍스트 기반 + 기존 채팅 이력을 함께 사용하는 GPT 응답 생성 - public String chatWithContext(List chatHistory, String finalPrompt) { - - List messages = new ArrayList<>(); - messages.add(new ChatMessage("system", SYSTEM_PROMPT)); - - // 기존 대화 이력 추가 -// messages.addAll(convertHistoryToMessages(chatHistory)); - - // 마지막 질문을 RAG 컨텍스트 기반으로 전달 - messages.add(new ChatMessage("user", finalPrompt)); - - ChatCompletionRequestDto request = buildRequest(messages); - return getResponseText(request); - } - - // 요청 생성 메서드 - private ChatCompletionRequestDto buildRequest(List messages) { - return ChatCompletionRequestDto.builder() - .model(MODEL_NAME) - .temperature(TEMPERATURE) - .maxTokens(MAX_TOKENS) - .messages(messages) - .build(); - } - - // 응답 추출 메서드 - private String getResponseText(ChatCompletionRequestDto request) { - return openAiService.createChatCompletion(request.toRequest()) - .getChoices() - .get(0) - .getMessage() - .getContent(); - } - - // Chatting 엔티티를 OpenAI ChatMessage로 변환 - private List convertHistoryToMessages(List chatHistory) { - return chatHistory.stream() - .map(chat -> new ChatMessage( - convertSenderToRole(chat.getSender()), - chat.getContent() - )) - .toList(); - } - - // Chatting 엔티티의 Sender(Enum type) -> OpenAI 역할 문자열 변환 - private String convertSenderToRole(Sender sender) { - return sender == Sender.USER ? "user" : "assistant"; - } -} diff --git a/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java b/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java index 02dee5d..aa3e754 100644 --- a/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java +++ b/src/main/java/com/going/server/domain/quiz/service/QuizServiceImpl.java @@ -23,10 +23,9 @@ public class QuizServiceImpl implements QuizService{ // 모드 별 퀴즈 생성 @Override public QuizCreateResponseDto quizCreate(String graphIdStr, String mode) { - Long graphId = Long.valueOf(graphIdStr); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphIdStr)); + Graph graph = graphRepository.getByGraph(dbId); - // 404 : 지식그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); Object quizDto = switch (mode) { case "listenUp" -> listenUpQuizGenerator.generate(graph); @@ -41,10 +40,8 @@ public QuizCreateResponseDto quizCreate(String graphIdStr, String mode) { // 만점일 경우 Graph Quiz 정보 업데이트 @Override public void updateIfPerfect(String graphIdStr, String mode) { - Long graphId = Long.valueOf(graphIdStr); - - // 404 : 지식 그래프 찾을 수 없음 - Graph graph = graphRepository.getByGraph(graphId); + Long dbId = graphRepository.findDbIdByGraphId(Long.valueOf(graphIdStr)); + Graph graph = graphRepository.getByGraph(dbId); switch (mode){ case "listenUp": diff --git a/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java b/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java new file mode 100644 index 0000000..ef185d5 --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java @@ -0,0 +1,11 @@ +package com.going.server.domain.rag.dto; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class GraphQueryResult { + private String sentence; + private String nodeLabel; +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java b/src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java deleted file mode 100644 index 0f32fce..0000000 --- a/src/main/java/com/going/server/domain/rag/dto/SimilarityRequestDto.java +++ /dev/null @@ -1,4 +0,0 @@ -package com.going.server.domain.rag.dto; - -public class SimilarityRequestDto { -} diff --git a/src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java b/src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java deleted file mode 100644 index 46a6974..0000000 --- a/src/main/java/com/going/server/domain/rag/dto/SimilarityResponseDto.java +++ /dev/null @@ -1,4 +0,0 @@ -package com.going.server.domain.rag.dto; - -public class SimilarityResponseDto { -} diff --git a/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java new file mode 100644 index 0000000..e42a735 --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/CypherQueryGenerator.java @@ -0,0 +1,38 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.openai.service.OpenAIService; +import com.theokanning.openai.OpenAiService; +import com.theokanning.openai.completion.chat.ChatMessage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; + +import java.util.List; + +// 1. 질문 → Cypher 쿼리 생성 (LLM) +@Component +@RequiredArgsConstructor +public class CypherQueryGenerator { + private final OpenAIService openAIService; + + public String generate(String userQuestion) { + String prompt = """ + 당신은 Neo4j용 Cypher 쿼리를 생성하는 AI입니다. + 주어진 질문에 대해 Cypher 쿼리만 반환하세요. 코드블록, 설명 없이 오직 쿼리만 출력해야 합니다. + + 예: + 질문: "고래와 관련된 개념들을 알려줘" + → MATCH (n:GraphNode)-[r]->(m:GraphNode)\s + WHERE n.label = '고래'\s + RETURN m.label AS nodeLabel, m.includeSentence AS sentence\s + LIMIT 10 + + 질문: "${userQuestion}" + → + """.formatted(userQuestion); + + return openAIService.getCompletionResponse( + List.of(new ChatMessage("user", prompt)), + "gpt-4-0125-preview", 0.2, 1000 + ); + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java new file mode 100644 index 0000000..8b506d1 --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java @@ -0,0 +1,39 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.rag.dto.GraphQueryResult; +import lombok.RequiredArgsConstructor; +import org.neo4j.driver.*; +import org.neo4j.driver.Record; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; + +// 2. 쿼리 실행 → 결과 추출 +@Component +@RequiredArgsConstructor +public class GraphQueryExecutor { + + private final Driver neo4jDriver; // Neo4j Java Driver + + public List runQuery(Long graphId, String cypherQuery) { + List results = new ArrayList<>(); + + try (Session session = neo4jDriver.session()) { + Result result = session.run(cypherQuery); + while (result.hasNext()) { + Record record = result.next(); + + // 필드 이름은 Cypher 쿼리 결과와 일치해야 함 + String sentence = record.get("sentence").asString(""); + String nodeLabel = record.get("nodeLabel").asString(""); + + results.add(new GraphQueryResult(sentence, nodeLabel)); + } + } catch (Exception e) { + e.printStackTrace(); + } + + return results; + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java new file mode 100644 index 0000000..61c685a --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/GraphRAGService.java @@ -0,0 +1,85 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.chatbot.dto.CreateChatbotResponseDto; +import com.going.server.domain.chatbot.entity.Chatting; +import com.going.server.domain.chatbot.repository.ChattingRepository; +import com.going.server.domain.graph.entity.Graph; +import com.going.server.domain.graph.repository.GraphNodeRepository; +import com.going.server.domain.graph.repository.GraphRepository; +import com.going.server.domain.rag.dto.GraphQueryResult; +import com.going.server.domain.rag.util.PromptBuilder; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +@RequiredArgsConstructor +@Slf4j +public class GraphRAGService { + + private final GraphRepository graphRepository; + private final GraphNodeRepository graphNodeRepository; + private final SimilarityFilterService similarityFilterService; + private final PromptBuilder promptBuilder; + private final ChattingRepository chattingRepository; + private final CypherQueryGenerator cypherQueryGenerator; + private final GraphQueryExecutor graphQueryExecutor; + private final RagAnswerCreateService ragAnswerCreateService; + + /** + * 사용자 질문에 대해 Cypher 쿼리 → 그래프 정보 검색 → 프롬프트 생성 → LLM 응답 생성 + * 본 메서드는 LangChain 없이 구현한 Spring 기반 GraphRAG의 핵심 흐름입니다. + */ + public CreateChatbotResponseDto createAnswerWithGraphRAG( + Long dbId, + String userQuestion, + List chatHistory + ) { + Graph graph = graphRepository.getByGraph(dbId); + log.info("[GraphRAG] dbId: {}, question: {}", dbId, userQuestion); + + // 1. 질문 → Cypher 쿼리 생성 + String cypherQuery = cypherQueryGenerator.generate(userQuestion).trim() + .replaceAll("(?s)```cypher.*?```", "") // 마크다운 제거 + .replaceAll("```", "") // 남은 ``` 제거 + .trim(); + log.info("[GraphRAG] Generated Cypher Query:\n{}", cypherQuery); + + // 2. 쿼리 실행 → 문맥(context) 및 노드 라벨 추출 + List queryResults = graphQueryExecutor.runQuery(dbId, cypherQuery); + List contextChunks = queryResults.stream() + .map(GraphQueryResult::getSentence) + .toList(); + + List sourceNodes = queryResults.stream() + .map(GraphQueryResult::getNodeLabel) + .distinct() + .toList(); + log.info("[GraphRAG] Retrieved {} context chunks", contextChunks.size()); + + // 3. 프롬프트 구성 + String finalPrompt = promptBuilder.buildPrompt(contextChunks, userQuestion); + log.info("[GraphRAG] Final Prompt constructed"); + + // 4. RAG 응답 생성 + String response = contextChunks.isEmpty() + ? ragAnswerCreateService.chat(chatHistory, userQuestion) + : ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt); + log.info("[GraphRAG] Response generated by LLM"); + + // 5. 응답 저장 + Chatting answer = Chatting.ofGPT(graph, response); + chattingRepository.save(answer); + log.info("[GraphRAG] Response saved to DB"); + + return CreateChatbotResponseDto.of( + response, + dbId.toString(), + answer.getCreatedAt(), + contextChunks, + sourceNodes + ); + } +} diff --git a/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java b/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java new file mode 100644 index 0000000..2da4c4b --- /dev/null +++ b/src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java @@ -0,0 +1,55 @@ +package com.going.server.domain.rag.service; + +import com.going.server.domain.chatbot.entity.Chatting; +import com.going.server.domain.chatbot.entity.Sender; +import com.going.server.domain.openai.service.OpenAIService; +import com.theokanning.openai.completion.chat.ChatMessage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; + +// 4. GraphRAG 응답 생성 +@Service +@RequiredArgsConstructor +public class RagAnswerCreateService { + + private final OpenAIService openAIService; + + private static final String SYSTEM_PROMPT = """ + 당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다. + - 아래 제공된 데이터를 기반으로 질문에 대해 매우 길고 정확하게 설명해주세요. + - 만약 참고 데이터가 없다면, 관련정보 없다고 하세요. + - 반드시 한글로만 응답하고, 인사말이나 불필요한 문장은 생략한 대답만 반환하세요. + """; + + private static final String MODEL_NAME = "gpt-4o"; + private static final double TEMPERATURE = 0.3; + private static final int MAX_TOKENS = 1500; + + public String chat(List chatHistory, String question) { + List messages = new ArrayList<>(); + messages.add(new ChatMessage("system", SYSTEM_PROMPT)); + messages.addAll(convertHistoryToMessages(chatHistory)); + messages.add(new ChatMessage("user", question)); + return openAIService.getCompletionResponse(messages, MODEL_NAME, TEMPERATURE, MAX_TOKENS); + } + + public String chatWithContext(List chatHistory, String finalPrompt) { + List messages = new ArrayList<>(); + messages.add(new ChatMessage("system", SYSTEM_PROMPT)); + messages.addAll(convertHistoryToMessages(chatHistory)); + messages.add(new ChatMessage("user", finalPrompt)); + return openAIService.getCompletionResponse(messages, MODEL_NAME, TEMPERATURE, MAX_TOKENS); + } + + private List convertHistoryToMessages(List chatHistory) { + return chatHistory.stream() + .map(chat -> new ChatMessage( + chat.getSender() == Sender.USER ? "user" : "assistant", + chat.getContent() + )) + .toList(); + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java index cefc705..b0d237d 100644 --- a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java +++ b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterService.java @@ -1,7 +1,15 @@ package com.going.server.domain.rag.service; + +import org.springframework.stereotype.Service; + import java.util.List; -public interface SimilarityFilterService { - List filterRelevantSentences(String question, List candidateSentences); -} +@Service +public class SimilarityFilterService { + + // 간단히 모든 문장을 통과시키는 기본 구현 (추후 유사도 필터링 적용 가능) + public List filterRelevantSentences(String userQuestion, List sentences) { + return sentences; + } +} \ No newline at end of file diff --git a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java b/src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java deleted file mode 100644 index b4b038b..0000000 --- a/src/main/java/com/going/server/domain/rag/service/SimilarityFilterServiceImpl.java +++ /dev/null @@ -1,17 +0,0 @@ -package com.going.server.domain.rag.service; - -import org.springframework.stereotype.Service; - -import java.util.List; - -@Service -// 유사도 검사 -public class SimilarityFilterServiceImpl implements SimilarityFilterService { - // TODO : 정확한 문맥 유사도 필터로 개선 필요 - @Override - public List filterRelevantSentences(String query, List candidates) { - return candidates.stream() - .filter(sentence -> sentence.toLowerCase().contains(query.toLowerCase())) - .toList(); - } -} diff --git a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java index a11c6f4..c52521e 100644 --- a/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java +++ b/src/main/java/com/going/server/domain/upload/service/UploadServiceImpl.java @@ -13,6 +13,9 @@ import com.going.server.domain.upload.dto.UploadResponseDto; import lombok.RequiredArgsConstructor; +import org.neo4j.driver.Session; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; @@ -33,6 +36,7 @@ public class UploadServiceImpl implements UploadService { private final PdfOcrService pdfOcrService; private final GraphNodeRepository graphNodeRepository; private final GraphRepository graphRepository; + private final Driver neo4jDriver; // Neo4j Java Driver @Value("${ocr.api.url}") private String apiUrl; @@ -41,19 +45,20 @@ public class UploadServiceImpl implements UploadService { @Value("${fastapi.base-url}") private String fastApiUrl; + private final Map translationCache = new HashMap<>(); private final Map imageCache = new HashMap<>(); @Override public UploadResponseDto uploadFile(UploadRequestDto dto) { try { + String jsonResponse = ocrService.processOcr(dto.getFile(), apiUrl, secretKey); log.info("jsonResponse log={}",jsonResponse); Map paresData = pdfOcrService.parse(jsonResponse); String text = paresData.get("읽기자료"); log.info("text log={}",text); - - //모델에 돌린 값을 받아옴 + //모델에 돌린 값을 받아옴 String response = setModelData(text); ObjectMapper mapper = new ObjectMapper(); @@ -95,7 +100,7 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { } String sourceId = edgeNode.get("source").asText(); String targetId = edgeNode.get("target").asText(); - String label = edgeNode.get("label").asText(); + String relationType = edgeNode.get("label").asText(); GraphNode sourceNode = nodeIdToNode.get(sourceId); GraphNode targetNode = nodeIdToNode.get(targetId); @@ -105,10 +110,20 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { continue; } + String dynamicCypher = String.format( + "MATCH (a:GraphNode {nodeId: '%s'}), (b:GraphNode {nodeId: '%s'}) MERGE (a)-[:`%s`]->(b)", + sourceId, targetId, relationType // 한글도 가능: "기능", "포함" 등 + ); + + // session 통해 직접 실행 + try (Session session = neo4jDriver.session()) { + session.run(dynamicCypher); + } + //edge 엔티티 생성 GraphEdge edge = GraphEdge.builder() .source(sourceId) - .label(label) + .label(relationType) .target(targetNode) .build(); @@ -116,6 +131,10 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { if (sourceNode.getEdges() == null) { sourceNode.setEdges(new HashSet<>()); } + + if (!sourceNode.getEdges().contains(edge)) { + sourceNode.getEdges().add(edge); + } sourceNode.getEdges().add(edge); } @@ -124,7 +143,9 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { //그래프 생성 String title = dto.getTitle(); + Long nextGraphId = graphRepository.findMaxGraphId(); Graph graphEntity = Graph.builder() + .id(nextGraphId == null ? 1L : nextGraphId + 1) //id 직접 세팅 .title(title) .content(text) .listenUpPerfect(false) @@ -153,6 +174,7 @@ public UploadResponseDto uploadFile(UploadRequestDto dto) { } } + // 모델 코드 호출 public String setModelData(String text) { WebClient webClient = WebClient.builder().baseUrl(fastApiUrl).build(); Map requestBody = new HashMap<>(); @@ -166,4 +188,5 @@ public String setModelData(String text) { .bodyToMono(String.class) .block(); } + }