Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions src/agents/sql_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def resource_path(relative_path):

# --- 프롬프트 로드 ---
INTENT_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml")))
DB_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "db_classifier.yaml")))
SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")))
RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")))

Expand Down Expand Up @@ -58,6 +59,51 @@ def unsupported_question_node(state: SqlAgentState):
state['final_response'] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다."
return state

def db_classifier_node(state: SqlAgentState):
print("--- 0.5. DB 분류 중 ---")

# TODO: BE API 호출로 대체 필요
available_dbs = [
{
"connection_name": "local_mysql",
"database_name": "sakila",
"description": "DVD 대여점 비즈니스 모델을 다루는 샘플 데이터베이스로, 영화, 배우, 고객, 대여 기록 등의 정보를 포함합니다."
},
{
"connection_name": "local_mysql",
"database_name": "ecom_prod",
"description": "온라인 쇼핑몰의 운영 데이터베이스로, 상품 카탈로그, 고객 주문, 재고 및 배송 정보를 관리합니다."
},
{
"connection_name": "local_mysql",
"database_name": "hr_analytics",
"description": "회사의 인사 관리 데이터베이스로, 직원 정보, 급여, 부서, 성과 평가 기록을 포함합니다."
},
{
"connection_name": "local_mysql",
"database_name": "web_logs",
"description": "웹사이트 트래픽 분석을 위한 로그 데이터베이스로, 사용자 방문 기록, 페이지 뷰, 에러 로그 등을 저장합니다."
}
]

db_options = "\n".join([f"- {db['database_name']}: {db['description']}" for db in available_dbs])

chain = DB_CLASSIFIER_PROMPT | llm_instance | StrOutputParser()
selected_db_name = chain.invoke({
"db_options": db_options,
"question": state['question']
})

state['selected_db'] = selected_db_name.strip()

# 선택된 DB의 스키마 정보를 가져와서 상태에 업데이트합니다.
print(f'--- 선택된 DB: {selected_db_name} ---')

# TODO: get_schema_for_db
state['db_schema'] = db_instance.get_schema_for_db(db_name=selected_db_name)

return state

def sql_generator_node(state: SqlAgentState):
print("--- 1. SQL 생성 중 ---")
parser = PydanticOutputParser(pydantic_object=SqlQuery)
Expand Down Expand Up @@ -163,7 +209,7 @@ def response_synthesizer_node(state: SqlAgentState):
def route_after_intent_classification(state: SqlAgentState):
if state['intent'] == "SQL":
print("--- 의도: SQL 관련 질문 ---")
return "sql_generator"
return "db_classifier"
print("--- 의도: SQL과 관련 없는 질문 ---")
return "unsupported_question"

Expand Down Expand Up @@ -192,6 +238,7 @@ def create_sql_agent_graph() -> StateGraph:
graph = StateGraph(SqlAgentState)

graph.add_node("intent_classifier", intent_classifier_node)
graph.add_node("db_classifier", db_classifier_node)
graph.add_node("unsupported_question", unsupported_question_node)
graph.add_node("sql_generator", sql_generator_node)
graph.add_node("sql_validator", sql_validator_node)
Expand All @@ -204,11 +251,13 @@ def create_sql_agent_graph() -> StateGraph:
"intent_classifier",
route_after_intent_classification,
{
"sql_generator": "sql_generator",
"db_classifier": "db_classifier",
"unsupported_question": "unsupported_question"
}
)
graph.add_edge("unsupported_question", END)

graph.add_edge("db_classifier", "sql_generator")

graph.add_edge("sql_generator", "sql_validator")

Expand Down
15 changes: 15 additions & 0 deletions src/prompts/v1/sql_agent/db_classifier.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_type: "prompt"
input_variables:
- db_options
- question
template: |
Based on the user's question, which of the following databases is most likely to contain the answer?
Please respond with only the database name.

Available databases:
{db_options}

User Question:
{question}

Selected Database: