Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ dist
.vscode
.idea
test.ipynb
__pycache__
__pycache__
workflow_graph.png
100 changes: 88 additions & 12 deletions src/agents/sql_agent_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# app/agents/sql_agent_graph.py
# src/agents/sql_agent_graph.py

from typing import List, TypedDict, Optional
from langchain_core.messages import BaseMessage
Expand All @@ -8,72 +8,138 @@
from core.db_manager import db_instance
from core.llm_provider import llm_instance

MAX_ERROR_COUNT = 3

# Agent μƒνƒœ μ •μ˜
class SqlAgentState(TypedDict):
question: str
chat_history: List[BaseMessage]
db_schema: str
sql_query: str
validation_error: Optional[str]
validation_error_count: int
execution_result: Optional[str]
execution_error_count: int
final_response: str

# --- λ…Έλ“œ ν•¨μˆ˜ μ •μ˜ ---
def sql_generator_node(state: SqlAgentState):
print("--- 1. SQL 생성 쀑 ---")
parser = PydanticOutputParser(pydantic_object=SqlQuery)

# --- μ—λŸ¬ ν”Όλ“œλ°± μ»¨ν…μŠ€νŠΈ 생성 ---
error_feedback = ""
# 1. 검증 였λ₯˜κ°€ μžˆμ—ˆμ„ 경우
if state.get("validation_error") and state.get("validation_error_count", 0) > 0:
error_feedback = f"""
Your previous query was rejected for the following reason: {state['validation_error']}
Please generate a new, safe query that does not contain forbidden keywords.
"""
# 2. μ‹€ν–‰ 였λ₯˜κ°€ μžˆμ—ˆμ„ 경우
elif state.get("execution_result") and "였λ₯˜" in state.get("execution_result") and state.get("execution_error_count", 0) > 0:
error_feedback = f"""
Your previously generated SQL query failed with the following database error:
FAILED SQL: {state['sql_query']}
DATABASE ERROR: {state['execution_result']}
Please correct the SQL query based on the error.
"""

prompt = f"""
You are a powerful text-to-SQL model. Your role is to generate a SQL query based on the provided database schema and user question.

{parser.get_format_instructions()}

Schema: {state['db_schema']}
History: {state['chat_history']}

{error_feedback}

Question: {state['question']}
"""

response = llm_instance.invoke(prompt)
parsed_query = parser.invoke(response)
state['sql_query'] = parsed_query.query
state['validation_error'] = None
state['execution_result'] = None
return state

def sql_validator_node(state: SqlAgentState):
print("--- 2. SQL 검증 쀑 ---")
query = state['sql_query'].lower()
if "drop" in query or "delete" in query:
state['validation_error'] = "μœ„ν—˜ν•œ ν‚€μ›Œλ“œκ°€ ν¬ν•¨λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€."
query_words = state['sql_query'].lower().split()
dangerous_keywords = [
"drop", "delete", "update", "insert", "truncate",
"alter", "create", "grant", "revoke"
]
found_keywords = [keyword for keyword in dangerous_keywords if keyword in query_words]

if found_keywords:
keyword_str = ', '.join(f"'{k}'" for k in found_keywords)
state['validation_error'] = f'μœ„ν—˜ν•œ ν‚€μ›Œλ“œ {keyword_str}κ°€ ν¬ν•¨λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€.'
state['validation_error_count'] += 1 # sql 검증 횟수 μΆ”κ°€
else:
state['validation_error'] = None
state['validation_error_count'] = 0 # sql 검증 횟수 μ΄ˆκΈ°ν™”
return state

def sql_executor_node(state: SqlAgentState):
print("--- 3. SQL μ‹€ν–‰ 쀑 ---")
try:
result = db_instance.run(state['sql_query'])
state['execution_result'] = str(result)
state['validation_error_count'] = 0 # sql 검증 횟수 μ΄ˆκΈ°ν™”
state['execution_error_count'] = 0 # sql μ‹€ν–‰ 횟수 μ΄ˆκΈ°ν™”
except Exception as e:
state['execution_result'] = f"μ‹€ν–‰ 였λ₯˜: {e}"
state['validation_error_count'] = 0 # sql 검증 횟수 μ΄ˆκΈ°ν™”
state['execution_error_count'] += 1 # sql μ‹€ν–‰ 횟수 μΆ”κ°€
return state

def response_synthesizer_node(state: SqlAgentState):
print("--- 4. μ΅œμ’… λ‹΅λ³€ 생성 쀑 ---")

if state.get('validation_error_count', 0) >= MAX_ERROR_COUNT:
message = f"SQL 검증에 {MAX_ERROR_COUNT}회 이상 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. λ§ˆμ§€λ§‰ 였λ₯˜: {state.get('validation_error')}"
elif state.get('execution_error_count', 0) >= MAX_ERROR_COUNT:
message = f"SQL 싀행에 {MAX_ERROR_COUNT}회 이상 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. λ§ˆμ§€λ§‰ 였λ₯˜: {state.get('execution_result')}"
else:
message = f"SQL Result: {state['execution_result']}"

prompt = f"""
Question: {state['question']}
SQL Result: {state['execution_result']}
Final Answer:
SQL: {state['sql_query']}
{message}

Based on the information above, provide a final answer to the user in Korean.
If there was an error, explain the problem to the user in a friendly way.
μ‚¬μš©μž 질문과 쿼리가 μ–΄λ–€ 관계가 μžˆλŠ”μ§€ 같이 μ„€λͺ…ν•΄
"""
response = llm_instance.invoke(prompt)
state['final_response'] = response.content
return state

# --- μ—£μ§€ ν•¨μˆ˜ μ •μ˜ ---
def should_execute_sql(state: SqlAgentState):
return "regenerate" if state.get("validation_error") else "execute"
"""SQL 검증 ν›„, μ‹€ν–‰ν• μ§€/μž¬μƒμ„±ν• μ§€/포기할지 κ²°μ •ν•©λ‹ˆλ‹€."""
if state.get("validation_error_count", 0) >= MAX_ERROR_COUNT:
print(f"--- 검증 μ‹€νŒ¨ {MAX_ERROR_COUNT}회 초과: λ‹΅λ³€ μƒμ„±μœΌλ‘œ 이동 ---")
return "synthesize_failure"
if state.get("validation_error"):
print(f"--- 검증 μ‹€νŒ¨ {state['validation_error_count']}회: SQL μž¬μƒμ„± ---")
return "regenerate"
print("--- 검증 성곡: SQL μ‹€ν–‰ ---")
return "execute"

def should_retry_or_respond(state: SqlAgentState):
return "regenerate" if "였λ₯˜" in (state.get("execution_result") or "") else "synthesize"
"""SQL μ‹€ν–‰ ν›„, 성곡/μž¬μ‹œλ„/포기 μ—¬λΆ€λ₯Ό κ²°μ •ν•©λ‹ˆλ‹€."""
if state.get("execution_error_count", 0) >= MAX_ERROR_COUNT:
print(f"--- μ‹€ν–‰ μ‹€νŒ¨ {MAX_ERROR_COUNT}회 초과: λ‹΅λ³€ μƒμ„±μœΌλ‘œ 이동 ---")
return "synthesize_failure"
if "였λ₯˜" in (state.get("execution_result") or ""):
print(f"--- μ‹€ν–‰ μ‹€νŒ¨ {state['execution_error_count']}회: SQL μž¬μƒμ„± ---")
return "regenerate"
print("--- μ‹€ν–‰ 성곡: μ΅œμ’… λ‹΅λ³€ 생성 ---")
return "synthesize_success"

# --- κ·Έλž˜ν”„ ꡬ성 ---
def create_sql_agent_graph() -> StateGraph:
Expand All @@ -85,14 +151,24 @@ def create_sql_agent_graph() -> StateGraph:

graph.set_entry_point("sql_generator")
graph.add_edge("sql_generator", "sql_validator")

graph.add_conditional_edges("sql_validator", should_execute_sql, {
"regenerate": "sql_generator", "execute": "sql_executor"
"regenerate": "sql_generator",
"execute": "sql_executor",
"synthesize_failure": "response_synthesizer"
})
graph.add_conditional_edges("sql_executor", should_retry_or_respond, {
"regenerate": "sql_generator", "synthesize": "response_synthesizer"
"regenerate": "sql_generator",
"synthesize_success": "response_synthesizer",
"synthesize_failure": "response_synthesizer"
})
graph.add_edge("response_synthesizer", END)

return graph.compile()

sql_agent_app = create_sql_agent_graph()
sql_agent_app = create_sql_agent_graph()

# μ›Œν¬ ν”Œλ‘œμš° κ·Έλ¦Ό μž‘μ„±
# graph_image_bytes = sql_agent_app.get_graph(xray=True).draw_mermaid_png()
# with open("workflow_graph.png", "wb") as f:
# f.write(graph_image_bytes)
6 changes: 4 additions & 2 deletions src/services/chatbot_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from agents.sql_agent_graph import sql_agent_app
from core.db_manager import schema_instance

class ChatbotService:
class ChatbotService():
def __init__(self):
self.db_schema = schema_instance

Expand All @@ -12,7 +12,9 @@ def handle_request(self, user_question: str) -> str:
initial_state = {
"question": user_question,
"chat_history": [],
"db_schema": self.db_schema
"db_schema": self.db_schema,
"validation_error_count": 0,
"execution_error_count": 0
}

# 2. κ·Έλž˜ν”„ μ‹€ν–‰
Expand Down