66from langchain_core .messages import BaseMessage
77from langgraph .graph import StateGraph , END
88from langchain .output_parsers .pydantic import PydanticOutputParser
9+ from langchain_core .output_parsers import StrOutputParser
910from langchain .prompts import load_prompt
1011from schemas .sql_schemas import SqlQuery
1112from core .db_manager import db_instance
@@ -27,6 +28,7 @@ def resource_path(relative_path):
2728PROMPT_DIR = os .path .join ("prompts" , PROMPT_VERSION , "sql_agent" )
2829
2930# --- 프롬프트 로드 ---
31+ INTENT_CLASSIFIER_PROMPT = load_prompt (resource_path (os .path .join (PROMPT_DIR , "intent_classifier.yaml" )))
3032SQL_GENERATOR_PROMPT = load_prompt (resource_path (os .path .join (PROMPT_DIR , "sql_generator.yaml" )))
3133RESPONSE_SYNTHESIZER_PROMPT = load_prompt (resource_path (os .path .join (PROMPT_DIR , "response_synthesizer.yaml" )))
3234
@@ -35,6 +37,7 @@ class SqlAgentState(TypedDict):
3537 question : str
3638 chat_history : List [BaseMessage ]
3739 db_schema : str
40+ intent : str
3841 sql_query : str
3942 validation_error : Optional [str ]
4043 validation_error_count : int
@@ -43,6 +46,18 @@ class SqlAgentState(TypedDict):
4346 final_response : str
4447
4548# --- 노드 함수 정의 ---
49+ def intent_classifier_node (state : SqlAgentState ):
50+ print ("--- 0. 의도 분류 중 ---" )
51+ chain = INTENT_CLASSIFIER_PROMPT | llm_instance | StrOutputParser ()
52+ intent = chain .invoke ({"question" : state ['question' ]})
53+ state ['intent' ] = intent
54+ return state
55+
56+ def unsupported_question_node (state : SqlAgentState ):
57+ print ("--- SQL 관련 없는 질문 ---" )
58+ state ['final_response' ] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다."
59+ return state
60+
4661def sql_generator_node (state : SqlAgentState ):
4762 print ("--- 1. SQL 생성 중 ---" )
4863 parser = PydanticOutputParser (pydantic_object = SqlQuery )
@@ -73,7 +88,7 @@ def sql_generator_node(state: SqlAgentState):
7388 )
7489
7590 response = llm_instance .invoke (prompt )
76- parsed_query = parser .invoke (response )
91+ parsed_query = parser .invoke (response . content )
7792 state ['sql_query' ] = parsed_query .query
7893 state ['validation_error' ] = None
7994 state ['execution_result' ] = None
@@ -145,6 +160,13 @@ def response_synthesizer_node(state: SqlAgentState):
145160 return state
146161
147162# --- 엣지 함수 정의 ---
163+ def route_after_intent_classification (state : SqlAgentState ):
164+ if state ['intent' ] == "SQL" :
165+ print ("--- 의도: SQL 관련 질문 ---" )
166+ return "sql_generator"
167+ print ("--- 의도: SQL과 관련 없는 질문 ---" )
168+ return "unsupported_question"
169+
148170def should_execute_sql (state : SqlAgentState ):
149171 if state .get ("validation_error_count" , 0 ) >= MAX_ERROR_COUNT :
150172 print (f"--- 검증 실패 { MAX_ERROR_COUNT } 회 초과: 답변 생성으로 이동 ---" )
@@ -168,12 +190,26 @@ def should_retry_or_respond(state: SqlAgentState):
168190# --- 그래프 구성 ---
169191def create_sql_agent_graph () -> StateGraph :
170192 graph = StateGraph (SqlAgentState )
193+
194+ graph .add_node ("intent_classifier" , intent_classifier_node )
195+ graph .add_node ("unsupported_question" , unsupported_question_node )
171196 graph .add_node ("sql_generator" , sql_generator_node )
172197 graph .add_node ("sql_validator" , sql_validator_node )
173198 graph .add_node ("sql_executor" , sql_executor_node )
174199 graph .add_node ("response_synthesizer" , response_synthesizer_node )
175200
176- graph .set_entry_point ("sql_generator" )
201+ graph .set_entry_point ("intent_classifier" )
202+
203+ graph .add_conditional_edges (
204+ "intent_classifier" ,
205+ route_after_intent_classification ,
206+ {
207+ "sql_generator" : "sql_generator" ,
208+ "unsupported_question" : "unsupported_question"
209+ }
210+ )
211+ graph .add_edge ("unsupported_question" , END )
212+
177213 graph .add_edge ("sql_generator" , "sql_validator" )
178214
179215 graph .add_conditional_edges ("sql_validator" , should_execute_sql , {
0 commit comments