Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 53 additions & 29 deletions src/agents/sql_agent_graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,34 @@
# src/agents/sql_agent_graph.py

import os
import sys
from typing import List, TypedDict, Optional
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.prompts import load_prompt
from schemas.sql_schemas import SqlQuery
from core.db_manager import db_instance
from core.llm_provider import llm_instance

# --- PyInstaller 경둜 ν•΄κ²° ν•¨μˆ˜ ---
def resource_path(relative_path):
try:
# PyInstaller creates a temp folder and stores path in _MEIPASS
base_path = sys._MEIPASS
except Exception:
# 개발 ν™˜κ²½μ—μ„œλŠ” src 폴더λ₯Ό κΈ°μ€€μœΌλ‘œ 경둜 μ„€μ •
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
return os.path.join(base_path, relative_path)

# --- μƒμˆ˜ μ •μ˜ ---
MAX_ERROR_COUNT = 3
PROMPT_VERSION = "v1"
PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent")

# --- ν”„λ‘¬ν”„νŠΈ λ‘œλ“œ ---
SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")))
RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")))

# Agent μƒνƒœ μ •μ˜
class SqlAgentState(TypedDict):
Expand All @@ -27,7 +47,7 @@ def sql_generator_node(state: SqlAgentState):
print("--- 1. SQL 생성 쀑 ---")
parser = PydanticOutputParser(pydantic_object=SqlQuery)

# --- μ—λŸ¬ ν”Όλ“œλ°± μ»¨ν…μŠ€νŠΈ 생성 ---
# --- μ—λŸ¬ ν”Όλ“œλ°± μ»¨ν…μŠ€νŠΈ 생성 ---
error_feedback = ""
# 1. 검증 였λ₯˜κ°€ μžˆμ—ˆμ„ 경우
if state.get("validation_error") and state.get("validation_error_count", 0) > 0:
Expand All @@ -44,18 +64,13 @@ def sql_generator_node(state: SqlAgentState):
Please correct the SQL query based on the error.
"""

prompt = f"""
You are a powerful text-to-SQL model. Your role is to generate a SQL query based on the provided database schema and user question.

{parser.get_format_instructions()}

Schema: {state['db_schema']}
History: {state['chat_history']}

{error_feedback}

Question: {state['question']}
"""
prompt = SQL_GENERATOR_PROMPT.format(
format_instructions=parser.get_format_instructions(),
db_schema=state['db_schema'],
chat_history=state['chat_history'],
question=state['question'],
error_feedback=error_feedback
)

response = llm_instance.invoke(prompt)
parsed_query = parser.invoke(response)
Expand Down Expand Up @@ -98,29 +113,39 @@ def sql_executor_node(state: SqlAgentState):
def response_synthesizer_node(state: SqlAgentState):
print("--- 4. μ΅œμ’… λ‹΅λ³€ 생성 쀑 ---")

if state.get('validation_error_count', 0) >= MAX_ERROR_COUNT:
message = f"SQL 검증에 {MAX_ERROR_COUNT}회 이상 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. λ§ˆμ§€λ§‰ 였λ₯˜: {state.get('validation_error')}"
elif state.get('execution_error_count', 0) >= MAX_ERROR_COUNT:
message = f"SQL 싀행에 {MAX_ERROR_COUNT}회 이상 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. λ§ˆμ§€λ§‰ 였λ₯˜: {state.get('execution_result')}"
is_failure = state.get('validation_error_count', 0) >= MAX_ERROR_COUNT or \
state.get('execution_error_count', 0) >= MAX_ERROR_COUNT

if is_failure:
if state.get('validation_error_count', 0) >= MAX_ERROR_COUNT:
error_type = "SQL 검증"
error_details = state.get('validation_error')
else:
error_type = "SQL μ‹€ν–‰"
error_details = state.get('execution_result')

context_message = f"""
An attempt to answer the user's question failed after multiple retries.
Failure Type: {error_type}
Last Error: {error_details}
"""
else:
message = f"SQL Result: {state['execution_result']}"
context_message = f"""
Successfully executed the SQL query to answer the user's question.
SQL Query: {state['sql_query']}
SQL Result: {state['execution_result']}
"""

prompt = f"""
Question: {state['question']}
SQL: {state['sql_query']}
{message}

Based on the information above, provide a final answer to the user in Korean.
If there was an error, explain the problem to the user in a friendly way.
μ‚¬μš©μž 질문과 쿼리가 μ–΄λ–€ 관계가 μžˆλŠ”μ§€ 같이 μ„€λͺ…ν•΄
"""
prompt = RESPONSE_SYNTHESIZER_PROMPT.format(
question=state['question'],
context_message=context_message
)
response = llm_instance.invoke(prompt)
state['final_response'] = response.content
return state

# --- μ—£μ§€ ν•¨μˆ˜ μ •μ˜ ---
def should_execute_sql(state: SqlAgentState):
"""SQL 검증 ν›„, μ‹€ν–‰ν• μ§€/μž¬μƒμ„±ν• μ§€/포기할지 κ²°μ •ν•©λ‹ˆλ‹€."""
if state.get("validation_error_count", 0) >= MAX_ERROR_COUNT:
print(f"--- 검증 μ‹€νŒ¨ {MAX_ERROR_COUNT}회 초과: λ‹΅λ³€ μƒμ„±μœΌλ‘œ 이동 ---")
return "synthesize_failure"
Expand All @@ -131,7 +156,6 @@ def should_execute_sql(state: SqlAgentState):
return "execute"

def should_retry_or_respond(state: SqlAgentState):
"""SQL μ‹€ν–‰ ν›„, 성곡/μž¬μ‹œλ„/포기 μ—¬λΆ€λ₯Ό κ²°μ •ν•©λ‹ˆλ‹€."""
if state.get("execution_error_count", 0) >= MAX_ERROR_COUNT:
print(f"--- μ‹€ν–‰ μ‹€νŒ¨ {MAX_ERROR_COUNT}회 초과: λ‹΅λ³€ μƒμ„±μœΌλ‘œ 이동 ---")
return "synthesize_failure"
Expand Down
25 changes: 25 additions & 0 deletions src/prompts/v1/sql_agent/response_synthesizer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
_type: prompt
input_variables:
- question
- context_message
template: |
You are a friendly and helpful database assistant chatbot.
Your goal is to provide a clear and easy-to-understand final answer to the user in Korean.
Please carefully analyze the user's question and the provided context below.

User's Question: {question}

Context:
{context_message}

Instructions:
- If the process was successful:
- Do not just show the raw data from the SQL result.
- Explain what the data means in relation to the user's question.
- Present the answer in a natural, conversational, and polite Korean.
- If the process failed:
- Apologize for the inconvenience.
- Explain the reason for the failure in simple, non-technical terms.
- Gently suggest trying a different or simpler question.

Final Answer (in Korean):
19 changes: 19 additions & 0 deletions src/prompts/v1/sql_agent/sql_generator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
_type: prompt
input_variables:
- format_instructions
- db_schema
- chat_history
- question
- error_feedback
template: |
You are a powerful text-to-SQL model.
Your role is to generate a SQL query based on the provided database schema and user question.

Schema: {db_schema}
History: {chat_history}

{error_feedback}

Question: {question}

{format_instructions}