Skip to content

Commit 1f57e9b

Browse files
authored
Merge pull request #14 from Queryus/feature/intent-classifier
Feature/intent classifier
2 parents abeaee7 + 1fe69ba commit 1f57e9b

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ pyzmq==27.0.0
8080
regex==2024.11.6
8181
requests==2.32.4
8282
requests-toolbelt==1.0.0
83+
setuptools==80.9.0
8384
six==1.17.0
8485
sniffio==1.3.1
8586
SQLAlchemy==2.0.41

src/agents/sql_agent_graph.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from langchain_core.messages import BaseMessage
77
from langgraph.graph import StateGraph, END
88
from langchain.output_parsers.pydantic import PydanticOutputParser
9+
from langchain_core.output_parsers import StrOutputParser
910
from langchain.prompts import load_prompt
1011
from schemas.sql_schemas import SqlQuery
1112
from core.db_manager import db_instance
@@ -27,6 +28,7 @@ def resource_path(relative_path):
2728
PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent")
2829

2930
# --- 프롬프트 로드 ---
31+
INTENT_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml")))
3032
SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")))
3133
RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")))
3234

@@ -35,6 +37,7 @@ class SqlAgentState(TypedDict):
3537
question: str
3638
chat_history: List[BaseMessage]
3739
db_schema: str
40+
intent: str
3841
sql_query: str
3942
validation_error: Optional[str]
4043
validation_error_count: int
@@ -43,6 +46,18 @@ class SqlAgentState(TypedDict):
4346
final_response: str
4447

4548
# --- 노드 함수 정의 ---
49+
def intent_classifier_node(state: SqlAgentState):
50+
print("--- 0. 의도 분류 중 ---")
51+
chain = INTENT_CLASSIFIER_PROMPT | llm_instance | StrOutputParser()
52+
intent = chain.invoke({"question": state['question']})
53+
state['intent'] = intent
54+
return state
55+
56+
def unsupported_question_node(state: SqlAgentState):
57+
print("--- SQL 관련 없는 질문 ---")
58+
state['final_response'] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다."
59+
return state
60+
4661
def sql_generator_node(state: SqlAgentState):
4762
print("--- 1. SQL 생성 중 ---")
4863
parser = PydanticOutputParser(pydantic_object=SqlQuery)
@@ -73,7 +88,7 @@ def sql_generator_node(state: SqlAgentState):
7388
)
7489

7590
response = llm_instance.invoke(prompt)
76-
parsed_query = parser.invoke(response)
91+
parsed_query = parser.invoke(response.content)
7792
state['sql_query'] = parsed_query.query
7893
state['validation_error'] = None
7994
state['execution_result'] = None
@@ -145,6 +160,13 @@ def response_synthesizer_node(state: SqlAgentState):
145160
return state
146161

147162
# --- 엣지 함수 정의 ---
163+
def route_after_intent_classification(state: SqlAgentState):
164+
if state['intent'] == "SQL":
165+
print("--- 의도: SQL 관련 질문 ---")
166+
return "sql_generator"
167+
print("--- 의도: SQL과 관련 없는 질문 ---")
168+
return "unsupported_question"
169+
148170
def should_execute_sql(state: SqlAgentState):
149171
if state.get("validation_error_count", 0) >= MAX_ERROR_COUNT:
150172
print(f"--- 검증 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---")
@@ -168,12 +190,26 @@ def should_retry_or_respond(state: SqlAgentState):
168190
# --- 그래프 구성 ---
169191
def create_sql_agent_graph() -> StateGraph:
170192
graph = StateGraph(SqlAgentState)
193+
194+
graph.add_node("intent_classifier", intent_classifier_node)
195+
graph.add_node("unsupported_question", unsupported_question_node)
171196
graph.add_node("sql_generator", sql_generator_node)
172197
graph.add_node("sql_validator", sql_validator_node)
173198
graph.add_node("sql_executor", sql_executor_node)
174199
graph.add_node("response_synthesizer", response_synthesizer_node)
175200

176-
graph.set_entry_point("sql_generator")
201+
graph.set_entry_point("intent_classifier")
202+
203+
graph.add_conditional_edges(
204+
"intent_classifier",
205+
route_after_intent_classification,
206+
{
207+
"sql_generator": "sql_generator",
208+
"unsupported_question": "unsupported_question"
209+
}
210+
)
211+
graph.add_edge("unsupported_question", END)
212+
177213
graph.add_edge("sql_generator", "sql_validator")
178214

179215
graph.add_conditional_edges("sql_validator", should_execute_sql, {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
_type: prompt
3+
input_variables:
4+
- question
5+
template: |
6+
You are an intelligent assistant responsible for classifying user questions.
7+
Your task is to determine whether a user's question is related to retrieving information from a database using SQL.
8+
9+
- If the question can be answered with a SQL query, respond with "SQL".
10+
- If the question is a simple greeting, a question about your identity, or anything that does not require database access, respond with "non-SQL".
11+
12+
Example 1:
13+
Question: "Show me the list of users who signed up last month."
14+
Classification: SQL
15+
16+
Example 2:
17+
Question: "What is the total revenue for the last quarter?"
18+
Classification: SQL
19+
20+
Example 3:
21+
Question: "Hello, who are you?"
22+
Classification: non-SQL
23+
24+
Example 4:
25+
Question: "What is the weather like today?"
26+
Classification: non-SQL
27+
28+
Now, classify the following question:
29+
Question: {question}
30+
Classification:

0 commit comments

Comments
 (0)