Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pyzmq==27.0.0
regex==2024.11.6
requests==2.32.4
requests-toolbelt==1.0.0
setuptools==80.9.0
six==1.17.0
sniffio==1.3.1
SQLAlchemy==2.0.41
Expand Down
40 changes: 38 additions & 2 deletions src/agents/sql_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import load_prompt
from schemas.sql_schemas import SqlQuery
from core.db_manager import db_instance
Expand All @@ -27,6 +28,7 @@ def resource_path(relative_path):
PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent")

# --- 프롬프트 로드 ---
INTENT_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml")))
SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")))
RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")))

Expand All @@ -35,6 +37,7 @@ class SqlAgentState(TypedDict):
question: str
chat_history: List[BaseMessage]
db_schema: str
intent: str
sql_query: str
validation_error: Optional[str]
validation_error_count: int
Expand All @@ -43,6 +46,18 @@ class SqlAgentState(TypedDict):
final_response: str

# --- 노드 함수 정의 ---
def intent_classifier_node(state: SqlAgentState):
print("--- 0. 의도 분류 중 ---")
chain = INTENT_CLASSIFIER_PROMPT | llm_instance | StrOutputParser()
intent = chain.invoke({"question": state['question']})
state['intent'] = intent
return state

def unsupported_question_node(state: SqlAgentState):
print("--- SQL 관련 없는 질문 ---")
state['final_response'] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다."
return state

def sql_generator_node(state: SqlAgentState):
print("--- 1. SQL 생성 중 ---")
parser = PydanticOutputParser(pydantic_object=SqlQuery)
Expand Down Expand Up @@ -73,7 +88,7 @@ def sql_generator_node(state: SqlAgentState):
)

response = llm_instance.invoke(prompt)
parsed_query = parser.invoke(response)
parsed_query = parser.invoke(response.content)
state['sql_query'] = parsed_query.query
state['validation_error'] = None
state['execution_result'] = None
Expand Down Expand Up @@ -145,6 +160,13 @@ def response_synthesizer_node(state: SqlAgentState):
return state

# --- 엣지 함수 정의 ---
def route_after_intent_classification(state: SqlAgentState):
if state['intent'] == "SQL":
print("--- 의도: SQL 관련 질문 ---")
return "sql_generator"
print("--- 의도: SQL과 관련 없는 질문 ---")
return "unsupported_question"

def should_execute_sql(state: SqlAgentState):
if state.get("validation_error_count", 0) >= MAX_ERROR_COUNT:
print(f"--- 검증 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---")
Expand All @@ -168,12 +190,26 @@ def should_retry_or_respond(state: SqlAgentState):
# --- 그래프 구성 ---
def create_sql_agent_graph() -> StateGraph:
graph = StateGraph(SqlAgentState)

graph.add_node("intent_classifier", intent_classifier_node)
graph.add_node("unsupported_question", unsupported_question_node)
graph.add_node("sql_generator", sql_generator_node)
graph.add_node("sql_validator", sql_validator_node)
graph.add_node("sql_executor", sql_executor_node)
graph.add_node("response_synthesizer", response_synthesizer_node)

graph.set_entry_point("sql_generator")
graph.set_entry_point("intent_classifier")

graph.add_conditional_edges(
"intent_classifier",
route_after_intent_classification,
{
"sql_generator": "sql_generator",
"unsupported_question": "unsupported_question"
}
)
graph.add_edge("unsupported_question", END)

graph.add_edge("sql_generator", "sql_validator")

graph.add_conditional_edges("sql_validator", should_execute_sql, {
Expand Down
30 changes: 30 additions & 0 deletions src/prompts/v1/sql_agent/intent_classifier.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

_type: prompt
input_variables:
- question
template: |
You are an intelligent assistant responsible for classifying user questions.
Your task is to determine whether a user's question is related to retrieving information from a database using SQL.

- If the question can be answered with a SQL query, respond with "SQL".
- If the question is a simple greeting, a question about your identity, or anything that does not require database access, respond with "non-SQL".

Example 1:
Question: "Show me the list of users who signed up last month."
Classification: SQL

Example 2:
Question: "What is the total revenue for the last quarter?"
Classification: SQL

Example 3:
Question: "Hello, who are you?"
Classification: non-SQL

Example 4:
Question: "What is the weather like today?"
Classification: non-SQL

Now, classify the following question:
Question: {question}
Classification: