diff --git a/.github/workflows/deploy_executables.yml b/.github/workflows/deploy_executables.yml index a9339ce..2317b2e 100644 --- a/.github/workflows/deploy_executables.yml +++ b/.github/workflows/deploy_executables.yml @@ -4,8 +4,27 @@ name: Build and Deploy Executables # 워크플로우의 전체 이름 on: release: types: [published] + workflow_dispatch: jobs: + # ================================== + # 1. 파이프라인 시작 알림 + # ================================== + start: + if: github.event_name == 'release' + runs-on: ubuntu-latest + steps: + - name: Send Pipeline Start Notification + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "AI 배포 봇", + "embeds": [{ + "description": "**${{ github.ref_name }}** AI 배포를 시작합니다.", + "color": 2243312 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} # 1단계: 각 OS에서 실행 파일을 빌드하는 잡 build: strategy: @@ -44,14 +63,14 @@ jobs: shell: bash run: | if [ "${{ runner.os }}" == "macOS" ]; then - echo "EXE_NAME=askql-ai" >> $GITHUB_ENV + echo "EXE_NAME=qgenie-ai" >> $GITHUB_ENV elif [ "${{ runner.os }}" == "Windows" ]; then - echo "EXE_NAME=askql-ai.exe" >> $GITHUB_ENV + echo "EXE_NAME=qgenie-ai.exe" >> $GITHUB_ENV fi # 6. PyInstaller를 사용해 파이썬 코드를 실행 파일로 만듭니다. - name: Build executable with PyInstaller - run: pyinstaller src/main.py --name ${{ env.EXE_NAME }} --onefile --noconsole + run: pyinstaller --clean --onefile --name ${{ env.EXE_NAME }} src/main.py # 7. 빌드된 실행 파일을 다음 단계(deploy)에서 사용할 수 있도록 아티팩트로 업로드합니다. - name: Upload artifact @@ -59,17 +78,20 @@ jobs: with: name: executable-${{ runner.os }} path: dist/${{ env.EXE_NAME }} + retention-days: 1 + # 2단계: 빌드된 실행 파일들을 Front 레포지토리에 배포하는 잡 deploy: # build 잡이 성공해야 실행됨 needs: build + if: github.event_name == 'release' runs-on: ubuntu-latest steps: # 1. 배포 대상인 Front 리포지토리의 코드를 가져옵니다. - - name: Checkout Front Repository + - name: Checkout App Repository uses: actions/checkout@v4 with: - repository: AskQL/Front + repository: Queryus/QGenie_app token: ${{ secrets.PAT_FOR_FRONT_REPO }} # 배포할 브랜치를 develop으로 변경 ref: develop @@ -78,15 +100,14 @@ jobs: - name: Download all artifacts uses: actions/download-artifact@v4 with: - # artifacts 폴더에 모든 아티팩트를 다운로드 path: artifacts # 3. 다운로드한 실행 파일들을 정해진 폴더(resources/mac, resources/win)로 이동시킵니다. - name: Organize files run: | mkdir -p resources/mac resources/win - mv artifacts/executable-macOS/askql-ai resources/mac/ - mv artifacts/executable-Windows/askql-ai.exe resources/win/ + mv artifacts/executable-macOS/qgenie-ai resources/mac/ + mv artifacts/executable-Windows/qgenie-ai.exe resources/win/ # 4. 변경된 파일들을 Front 리포지토리에 커밋하고 푸시합니다. - name: Commit and push changes @@ -98,6 +119,60 @@ jobs: if git diff-index --quiet HEAD; then echo "No changes to commit." else - git commit -m "feat: AI 실행 파일 업데이트 (${{ github.ref_name }})" + git commit -m "feat: Update AI executable (${{ github.ref_name }})" git push fi + + # ================================== + # 파이프라인 최종 결과 알림 + # ================================== + finish: + needs: deploy + runs-on: ubuntu-latest + if: always() && github.event_name == 'release' + + steps: + - name: Send Success Notification + if: needs.deploy.result == 'success' + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "AI 배포 봇", + "embeds": [{ + "title": "New AI Release: ${{ github.ref_name }}", + "url": "${{ github.event.release.html_url }}", + "description": "**${{ github.ref_name }}** AI 배포가 성공적으로 완료되었습니다!", + "color": 5167473 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} + + - name: Send Failure Notification + if: contains(needs.*.result, 'failure') + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "AI 배포 봇", + "embeds": [{ + "title": "AI 배포 실패", + "url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "**${{ github.ref_name }}** AI 배포 중 오류가 발생했습니다.", + "color": 15219495 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} + + - name: Send Skipped or Cancelled Notification + if: contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "AI 배포 봇", + "embeds": [{ + "title": "AI 배포 미완료", + "url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "**${{ github.ref_name }}** AI 배포가 완료되지 않았습니다. (상태: 취소 또는 건너뜀)\n이전 단계에서 문제가 발생했을 수 있습니다.", + "color": 16577629 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} diff --git a/.github/workflows/pr_bot.yml b/.github/workflows/pr_bot.yml new file mode 100644 index 0000000..22daee4 --- /dev/null +++ b/.github/workflows/pr_bot.yml @@ -0,0 +1,86 @@ +# .github/workflows/pr_bot.yml +name: Pull Request Bot + +on: + # Pull Request 관련 이벤트 발생 시 + pull_request: + types: [opened, closed, reopened, synchronize] + issue_comment: + types: [created] + +jobs: + notify: + runs-on: ubuntu-latest + steps: + # ------------------------- + # 생성/동기화 알림 + # ------------------------- + - name: Send PR Created Notification + if: github.event_name == 'pull_request' && (github.event.action == 'opened' || github.event.action == 'synchronize') + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "GitHub PR 봇", + "embeds": [{ + "title": "Pull Request #${{ github.event.pull_request.number }}: ${{ github.event.pull_request.title }}", + "description": "**${{ github.actor }}**님이 Pull Request를 생성하거나 업데이트했습니다.", + "url": "${{ github.event.pull_request.html_url }}", + "color": 2243312 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} + + # ------------------------- + # 댓글 알림 + # ------------------------- + - name: Send PR Comment Notification + if: github.event_name == 'issue_comment' && github.event.issue.pull_request + run: | + COMMENT_BODY=$(echo "${{ github.event.comment.body }}" | sed 's/"/\\"/g' | sed ':a;N;$!ba;s/\n/\\n/g') + curl -X POST -H "Content-Type: application/json" \ + -d "{ + \"username\": \"GitHub 댓글 봇\", + \"embeds\": [{ + \"title\": \"New Comment on PR #${{ github.event.issue.number }}\", + \"description\": \"**${{ github.actor }}**님의 새 댓글: \\n${COMMENT_BODY}\", + \"url\": \"${{ github.event.comment.html_url }}\", + \"color\": 15105570 + }] + }" \ + ${{ secrets.DISCORD_WEBHOOK_URL }} + + # ------------------------- + # 머지(Merge) 알림 + # ------------------------- + - name: Send PR Merged Notification + if: github.event.action == 'closed' && github.event.pull_request.merged == true + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "GitHub Merge 봇", + "embeds": [{ + "title": "Pull Request #${{ github.event.pull_request.number }} Merged!", + "description": "**${{ github.actor }}**님이 **${{ github.event.pull_request.title }}** PR을 머지했습니다.", + "url": "${{ github.event.pull_request.html_url }}", + "color": 5167473 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} + + # ------------------------- + # 닫힘(Close) 알림 + # ------------------------- + - name: Send PR Closed Notification + if: github.event.action == 'closed' && github.event.pull_request.merged == false + run: | + curl -X POST -H "Content-Type: application/json" \ + -d '{ + "username": "GitHub PR 봇", + "embeds": [{ + "title": "Pull Request #${{ github.event.pull_request.number }} Closed", + "description": "**${{ github.actor }}**님이 **${{ github.event.pull_request.title }}** PR을 닫았습니다.", + "url": "${{ github.event.pull_request.html_url }}", + "color": 15219495 + }] + }' \ + ${{ secrets.DISCORD_WEBHOOK_URL }} diff --git a/.gitignore b/.gitignore index 2fa36ba..6d83b6d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ # 2. 가상환경 파일 *venv* +.env # 3. 빌드 결과물 (Build output) *.spec @@ -11,4 +12,7 @@ dist # 4. 기타 .vscode -.idea \ No newline at end of file +.idea +test.ipynb +__pycache__ +workflow_graph.png diff --git a/README.md b/README.md index ab17ebd..75f36c7 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # AI 모델 프로젝트 -이 프로젝트는 ...을 위한 AI 모델을 개발합니다. +이 프로젝트는 DB 어노테이션과 SQL 챗봇을 위한 AI 서버를 개발합니다. --- @@ -12,7 +12,7 @@ 1. **저장소 복제** ```bash - git clone https://github.com/AskQL/AI.git + git clone https://github.com/Queryus/QGenie_ai.git ``` 2. **가상 환경 생성 및 활성화** @@ -20,8 +20,11 @@ # 가상 환경 생성 (최초 한 번) python3 -m venv .venv - # 가상 환경 활성화 + # 가상 환경 활성화 (macOS/Linux) source .venv/bin/activate + + # 가상 환경 활성화 (Windows) + .venv\Scripts\activate ``` 3. **라이브러리 설치** @@ -36,10 +39,7 @@ rm -rf build dist # 실행 파일 빌드 - pyinstaller src/main.py --name ai --onefile --noconsole - - # 실행 파일 실행 - ./dist/ai + pyinstaller --clean --onefile --add-data "src/prompts:prompts" --name ai src/main.py ``` --- @@ -59,21 +59,17 @@ GitHub에서 새로운 태그를 발행하면 파이프라인이 자동으로 ~~~markdown 1. 모든 기능 개발과 테스트가 완료된 코드를 main 브랜치에 병합(Merge)합니다. -2. AskQL/AI 저장소의 Releases 탭으로 이동하여 Draft a new release 버튼을 클릭합니다. -3. Choose a tag 항목에서 v1.0.0과 같이 새로운 버전 태그를 입력하고 생성합니다. -4. (가장 중요 ⭐) Target 드롭다운 메뉴에서 반드시 main 브랜치를 선택합니다. -5. 릴리즈 노트를 작성하고 Publish release 버튼을 클릭합니다. -6. Target 드롭다운 메뉴에서 반드시 main 브랜치를 선택합니다. -7. 릴리즈 노트를 작성하고 Publish release 버튼을 클릭합니다. +2. 레포지토리에서 Releases 탭으로 이동하여 Create a new release 버튼을 클릭합니다. +3. Choose a tag 항목을 클릭한후 Find or create a new tag 부분에 버전(v1.0.0)과 같이 새로운 버전 태그를 입력하고 아래 Create new tag를 클릭하여 태그를 생성합니다. +4. ⭐중요) Target 드롭다운 메뉴에서 반드시 main 브랜치를 선택합니다. +5. 제목에 버전을 입력하고 릴리즈 노트를 작성합니다 +6. 🚨주의) Publish release 버튼을 클릭합니다. + 릴리즈 발행은 되돌릴 수 없습니다. + 잘못된 릴리즈는 서비스에 직접적인 영향을 줄 수 있으니, 반드시 팀의 승인을 받고 신중하게 진행해 주십시오. +7. Actions 탭에 들어가 파이프라인을 확인 후 정상 배포되었다면 App 레포에 develop 브랜치에서 실행 파일을 확인합니다. +8. 만약 실패하였다면 인프라 담당자에게 말해주세요. ~~~ -버튼을 클릭하는 즉시, 아래의 경고 사항에 설명된 자동 배포 프로세스가 시작됩니다. - -🛑 경고: 릴리즈 발행은 되돌릴 수 없습니다 -GitHub에서 새로운 릴리즈를 발행하면, 자동으로 프로덕션 배포가 시작됩니다. - -이 과정은 중간에 멈출 수 없으며, main 브랜치의 현재 상태가 즉시 Front 리포지토리의 develop 브랜치에 반영됩니다. -잘못된 릴리즈는 서비스에 직접적인 영향을 줄 수 있으니, 반드시 팀의 승인을 받고 신중하게 진행해 주십시오. --- @@ -82,11 +78,9 @@ GitHub에서 새로운 릴리즈를 발행하면, 자동으로 프로덕션 배 이 레포지토리에는 간단한 헬스체크 서버가 포함되어 있습니다. 아래 명령어로 실행 파일을 빌드하여 서버 상태를 확인할 수 있습니다. ```bash -# 실행 파일 빌드 -pyinstaller src/main.py --name ai --onefile --noconsole - # 빌드된 파일 실행 (dist 폴더에 생성됨) ./dist/ai -curl http://localhost:33332/health -``` \ No newline at end of file +# 다른 터미널에서 헬스체크 요청 +curl http://localhost:<할당된 포트>/api/v1/health +``` diff --git a/requirements.txt b/requirements.txt index e6efb79..e484c4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,32 +1,106 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiosignal==1.4.0 +altgraph==0.17.4 annotated-types==0.7.0 anyio==4.9.0 +appnope==0.1.4 +asttokens==3.0.0 +attrs==25.3.0 +blinker==1.9.0 certifi==2025.6.15 +cffi==1.17.1 charset-normalizer==3.4.2 +click==8.2.1 +comm==0.2.2 +cryptography==45.0.5 +dataclasses-json==0.6.7 +debugpy==1.8.14 +decorator==5.2.1 +distro==1.9.0 +executing==2.2.0 +fastapi==0.116.0 +Flask==3.1.1 +frozenlist==1.7.0 h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 +httpx-sse==0.4.1 idna==3.10 +ipykernel==6.29.5 +ipython==9.4.0 +ipython_pygments_lexers==1.1.1 +itsdangerous==2.2.0 +jedi==0.19.2 +Jinja2==3.1.6 +jiter==0.10.0 jsonpatch==1.33 jsonpointer==3.0.0 +jupyter_client==8.6.3 +jupyter_core==5.8.1 langchain==0.3.26 +langchain-community==0.3.27 langchain-core==0.3.67 +langchain-openai==0.3.27 langchain-text-splitters==0.3.8 +langgraph==0.5.1 +langgraph-checkpoint==2.1.0 +langgraph-prebuilt==0.5.2 +langgraph-sdk==0.1.72 langsmith==0.4.4 +macholib==1.16.3 +MarkupSafe==3.0.2 +marshmallow==3.26.1 +matplotlib-inline==0.1.7 +multidict==6.6.3 +mypy_extensions==1.1.0 +nest-asyncio==1.6.0 +numpy==2.3.1 +openai==1.93.0 orjson==3.10.18 +ormsgpack==1.10.0 packaging==24.2 +parso==0.8.4 +pexpect==4.9.0 +platformdirs==4.3.8 +prompt_toolkit==3.0.51 +propcache==0.3.2 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pycparser==2.22 pydantic==2.11.7 +pydantic-settings==2.10.1 pydantic_core==2.33.2 +Pygments==2.19.2 +pyinstaller==6.14.2 +pyinstaller-hooks-contrib==2025.5 +PyMySQL==1.1.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 PyYAML==6.0.2 +pyzmq==27.0.0 +regex==2024.11.6 requests==2.32.4 requests-toolbelt==1.0.0 +setuptools==80.9.0 +six==1.17.0 sniffio==1.3.1 SQLAlchemy==2.0.41 +stack-data==0.6.3 +starlette==0.46.2 tenacity==9.1.2 +tiktoken==0.9.0 +tornado==6.5.1 +tqdm==4.67.1 +traitlets==5.14.3 +typing-inspect==0.9.0 typing-inspection==0.4.1 typing_extensions==4.14.0 urllib3==2.5.0 +uvicorn==0.35.0 +wcwidth==0.2.13 +Werkzeug==3.1.3 +xxhash==3.5.0 +yarl==1.20.1 zstandard==0.23.0 - -# infra -Flask -pyinstaller diff --git a/sql_agent_workflow.png b/sql_agent_workflow.png new file mode 100644 index 0000000..cecf785 Binary files /dev/null and b/sql_agent_workflow.png differ diff --git a/src/agents/__init__.py b/src/agents/__init__.py new file mode 100644 index 0000000..f677fbe --- /dev/null +++ b/src/agents/__init__.py @@ -0,0 +1,6 @@ +""" +에이전트 루트 패키지 +""" + + + diff --git a/src/agents/sql_agent/__init__.py b/src/agents/sql_agent/__init__.py new file mode 100644 index 0000000..cd71571 --- /dev/null +++ b/src/agents/sql_agent/__init__.py @@ -0,0 +1,27 @@ +# src/agents/sql_agent/__init__.py + +from .state import SqlAgentState +from .nodes import SqlAgentNodes +from .edges import SqlAgentEdges +from .graph import SqlAgentGraph +from .exceptions import ( + SqlAgentException, + ValidationException, + ExecutionException, + DatabaseConnectionException, + LLMProviderException, + MaxRetryExceededException +) + +__all__ = [ + 'SqlAgentState', + 'SqlAgentNodes', + 'SqlAgentEdges', + 'SqlAgentGraph', + 'SqlAgentException', + 'ValidationException', + 'ExecutionException', + 'DatabaseConnectionException', + 'LLMProviderException', + 'MaxRetryExceededException' +] diff --git a/src/agents/sql_agent/edges.py b/src/agents/sql_agent/edges.py new file mode 100644 index 0000000..9a24fe3 --- /dev/null +++ b/src/agents/sql_agent/edges.py @@ -0,0 +1,51 @@ +# src/agents/sql_agent/edges.py + +from .state import SqlAgentState + +# 상수 정의 +MAX_ERROR_COUNT = 3 + +class SqlAgentEdges: + """SQL Agent의 모든 엣지 로직을 담당하는 클래스""" + + @staticmethod + def route_after_intent_classification(state: SqlAgentState) -> str: + """의도 분류 결과에 따라 라우팅을 결정합니다.""" + if state['intent'] == "SQL": + print("--- 의도: SQL 관련 질문 ---") + return "db_classifier" + print("--- 의도: SQL과 관련 없는 질문 ---") + return "unsupported_question" + + @staticmethod + def should_execute_sql(state: SqlAgentState) -> str: + """SQL 검증 결과에 따라 다음 단계를 결정합니다.""" + validation_error_count = state.get("validation_error_count", 0) + + if validation_error_count >= MAX_ERROR_COUNT: + print(f"--- 검증 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---") + return "synthesize_failure" + + if state.get("validation_error"): + print(f"--- 검증 실패 {validation_error_count}회: SQL 재생성 ---") + return "regenerate" + + print("--- 검증 성공: SQL 실행 ---") + return "execute" + + @staticmethod + def should_retry_or_respond(state: SqlAgentState) -> str: + """SQL 실행 결과에 따라 다음 단계를 결정합니다.""" + execution_error_count = state.get("execution_error_count", 0) + execution_result = state.get("execution_result", "") + + if execution_error_count >= MAX_ERROR_COUNT: + print(f"--- 실행 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---") + return "synthesize_failure" + + if "오류" in execution_result: + print(f"--- 실행 실패 {execution_error_count}회: SQL 재생성 ---") + return "regenerate" + + print("--- 실행 성공: 최종 답변 생성 ---") + return "synthesize_success" diff --git a/src/agents/sql_agent/exceptions.py b/src/agents/sql_agent/exceptions.py new file mode 100644 index 0000000..0e0171a --- /dev/null +++ b/src/agents/sql_agent/exceptions.py @@ -0,0 +1,31 @@ +# src/agents/sql_agent/exceptions.py + +class SqlAgentException(Exception): + """SQL Agent 관련 기본 예외 클래스""" + pass + +class ValidationException(SqlAgentException): + """SQL 검증 실패 예외""" + def __init__(self, message: str, error_count: int = 0): + super().__init__(message) + self.error_count = error_count + +class ExecutionException(SqlAgentException): + """SQL 실행 실패 예외""" + def __init__(self, message: str, error_count: int = 0): + super().__init__(message) + self.error_count = error_count + +class DatabaseConnectionException(SqlAgentException): + """데이터베이스 연결 실패 예외""" + pass + +class LLMProviderException(SqlAgentException): + """LLM 제공자 관련 예외""" + pass + +class MaxRetryExceededException(SqlAgentException): + """최대 재시도 횟수 초과 예외""" + def __init__(self, message: str, max_retries: int): + super().__init__(f"{message} (최대 재시도 {max_retries}회 초과)") + self.max_retries = max_retries diff --git a/src/agents/sql_agent/graph.py b/src/agents/sql_agent/graph.py new file mode 100644 index 0000000..cf48421 --- /dev/null +++ b/src/agents/sql_agent/graph.py @@ -0,0 +1,129 @@ +# src/agents/sql_agent/graph.py + +from langgraph.graph import StateGraph, END +from core.providers.llm_provider import LLMProvider +from services.database.database_service import DatabaseService +from .state import SqlAgentState +from .nodes import SqlAgentNodes +from .edges import SqlAgentEdges + +class SqlAgentGraph: + """SQL Agent 그래프를 구성하고 관리하는 클래스""" + + def __init__(self, llm_provider: LLMProvider, database_service: DatabaseService): + self.llm_provider = llm_provider + self.database_service = database_service + self.nodes = SqlAgentNodes(llm_provider, database_service) + self.edges = SqlAgentEdges() + self._graph = None + + def create_graph(self) -> StateGraph: + """SQL Agent 그래프를 생성하고 구성합니다.""" + if self._graph is not None: + return self._graph + + graph = StateGraph(SqlAgentState) + + # 노드 추가 + self._add_nodes(graph) + + # 엣지 추가 + self._add_edges(graph) + + # 진입점 설정 + graph.set_entry_point("intent_classifier") + + # 그래프 컴파일 + self._graph = graph.compile() + return self._graph + + def _add_nodes(self, graph: StateGraph): + """그래프에 모든 노드를 추가합니다.""" + graph.add_node("intent_classifier", self.nodes.intent_classifier_node) + graph.add_node("db_classifier", self.nodes.db_classifier_node) + graph.add_node("unsupported_question", self.nodes.unsupported_question_node) + graph.add_node("sql_generator", self.nodes.sql_generator_node) + graph.add_node("sql_validator", self.nodes.sql_validator_node) + graph.add_node("sql_executor", self.nodes.sql_executor_node) + graph.add_node("response_synthesizer", self.nodes.response_synthesizer_node) + + def _add_edges(self, graph: StateGraph): + """그래프에 모든 엣지를 추가합니다.""" + # 의도 분류 후 조건부 라우팅 + graph.add_conditional_edges( + "intent_classifier", + self.edges.route_after_intent_classification, + { + "db_classifier": "db_classifier", + "unsupported_question": "unsupported_question" + } + ) + + # 지원되지 않는 질문 처리 후 종료 + graph.add_edge("unsupported_question", END) + + # DB 분류 후 SQL 생성으로 이동 + graph.add_edge("db_classifier", "sql_generator") + + # SQL 생성 후 검증으로 이동 + graph.add_edge("sql_generator", "sql_validator") + + # SQL 검증 후 조건부 라우팅 + graph.add_conditional_edges( + "sql_validator", + self.edges.should_execute_sql, + { + "regenerate": "sql_generator", + "execute": "sql_executor", + "synthesize_failure": "response_synthesizer" + } + ) + + # SQL 실행 후 조건부 라우팅 + graph.add_conditional_edges( + "sql_executor", + self.edges.should_retry_or_respond, + { + "regenerate": "sql_generator", + "synthesize_success": "response_synthesizer", + "synthesize_failure": "response_synthesizer" + } + ) + + # 응답 생성 후 종료 + graph.add_edge("response_synthesizer", END) + + async def run(self, initial_state: dict) -> dict: + """그래프를 실행하고 결과를 반환합니다.""" + try: + if self._graph is None: + self.create_graph() + + result = await self._graph.ainvoke(initial_state) + return result + + except Exception as e: + print(f"그래프 실행 중 오류 발생: {e}") + # 에러 발생 시 예외를 다시 발생시켜 상위 레벨에서 HTTP 에러로 처리되도록 함 + raise e + + def save_graph_visualization(self, file_path: str = "sql_agent_graph.png") -> bool: + """그래프 시각화를 파일로 저장합니다.""" + try: + if self._graph is None: + self.create_graph() + + # PNG 이미지 생성 + png_data = self._graph.get_graph(xray=True).draw_mermaid_png() + + # 파일로 저장 + with open(file_path, "wb") as f: + f.write(png_data) + + print(f"그래프 시각화가 {file_path}에 저장되었습니다.") + return True + + except Exception as e: + print(f"그래프 시각화 저장 실패: {e}") + return False + \ No newline at end of file diff --git a/src/agents/sql_agent/nodes.py b/src/agents/sql_agent/nodes.py new file mode 100644 index 0000000..dd2e916 --- /dev/null +++ b/src/agents/sql_agent/nodes.py @@ -0,0 +1,380 @@ +# src/agents/sql_agent/nodes.py + +import os +import sys +import asyncio +from typing import List, Optional +from langchain.output_parsers.pydantic import PydanticOutputParser +from langchain_core.output_parsers import StrOutputParser +from langchain.prompts import load_prompt + +from schemas.agent.sql_schemas import SqlQuery +from services.database.database_service import DatabaseService +from core.providers.llm_provider import LLMProvider +from .state import SqlAgentState +from .exceptions import ( + ValidationException, + ExecutionException, + DatabaseConnectionException, + MaxRetryExceededException +) + +# 상수 정의 +MAX_ERROR_COUNT = 3 +PROMPT_VERSION = "v1" +PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent") + +def resource_path(relative_path): + """PyInstaller 경로 해결 함수""" + try: + base_path = sys._MEIPASS + except Exception: + base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + return os.path.join(base_path, relative_path) + +class SqlAgentNodes: + """SQL Agent의 모든 노드 로직을 담당하는 클래스""" + + def __init__(self, llm_provider: LLMProvider, database_service: DatabaseService): + self.llm_provider = llm_provider + self.database_service = database_service + + # 프롬프트 로드 + self._load_prompts() + + def _load_prompts(self): + """프롬프트 파일들을 로드합니다.""" + try: + self.intent_classifier_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml")) + ) + self.db_classifier_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "db_classifier.yaml")) + ) + self.sql_generator_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")) + ) + self.response_synthesizer_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")) + ) + except Exception as e: + raise FileNotFoundError(f"프롬프트 파일 로드 실패: {e}") + + async def intent_classifier_node(self, state: SqlAgentState) -> SqlAgentState: + """사용자 질문의 의도를 분류하는 노드""" + print("--- 0. 의도 분류 중 ---") + + try: + llm = await self.llm_provider.get_llm() + + # 채팅 내역을 활용하여 의도 분류 + input_data = { + "question": state['question'], + "chat_history": state.get('chat_history', []) + } + + chain = self.intent_classifier_prompt | llm | StrOutputParser() + intent = await chain.ainvoke(input_data) + state['intent'] = intent.strip() + + print(f"의도 분류 결과: {state['intent']}") + return state + + except Exception as e: + print(f"의도 분류 실패: {e}") + # 기본값으로 SQL 처리 + state['intent'] = "SQL" + return state + + async def unsupported_question_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL과 관련 없는 질문을 처리하는 노드""" + print("--- SQL 관련 없는 질문 ---") + + state['final_response'] = """죄송합니다, 해당 질문에는 답변할 수 없습니다. +저는 데이터베이스 관련 질문만 처리할 수 있습니다. +SQL 쿼리나 데이터 분석과 관련된 질문을 해주세요.""" + + return state + + async def db_classifier_node(self, state: SqlAgentState) -> SqlAgentState: + """데이터베이스를 분류하고 스키마를 가져오는 노드""" + print("--- 0.5. DB 분류 중 ---") + + try: + # DBMS 프로필과 어노테이션을 함께 조회 + available_dbs_with_annotations = await self.database_service.get_databases_with_annotations() + + if not available_dbs_with_annotations: + raise DatabaseConnectionException("사용 가능한 DBMS가 없습니다.") + + print(f"--- {len(available_dbs_with_annotations)}개의 DBMS 발견 ---") + + # 어노테이션 정보를 포함한 DBMS 옵션 생성 + db_options = "\n".join([ + f"- {db['display_name']}: {db['description']}" + for db in available_dbs_with_annotations + ]) + + # LLM을 사용하여 적절한 DBMS 선택 + llm = await self.llm_provider.get_llm() + chain = self.db_classifier_prompt | llm | StrOutputParser() + selected_db_display_name = await chain.ainvoke({ + "db_options": db_options, + "chat_history": state['chat_history'], + "question": state['question'] + }) + + selected_db_display_name = selected_db_display_name.strip() + + # 선택된 display_name으로 실제 DBMS 정보 찾기 + selected_db_info = None + for db in available_dbs_with_annotations: + if db['display_name'] == selected_db_display_name: + selected_db_info = db + break + + if not selected_db_info: + # 부분 매칭 시도 + for db in available_dbs_with_annotations: + if selected_db_display_name in db['display_name'] or db['display_name'] in selected_db_display_name: + selected_db_info = db + break + + if not selected_db_info: + print(f"--- 선택된 DBMS를 찾을 수 없음: {selected_db_display_name}, 첫 번째 DBMS 사용 ---") + selected_db_info = available_dbs_with_annotations[0] + + state['selected_db'] = selected_db_info['display_name'] + state['selected_db_profile'] = selected_db_info['profile'] + state['selected_db_annotations'] = selected_db_info['annotations'] + + print(f'--- 선택된 DBMS: {selected_db_info["display_name"]} ---') + print(f'--- DBMS 프로필 ID: {selected_db_info["profile"]["id"]} ---') + + # 어노테이션 정보를 스키마로 사용 + if selected_db_info['annotations'] and 'data' in selected_db_info['annotations']: + schema_info = self._convert_annotations_to_schema(selected_db_info['annotations']) + state['db_schema'] = schema_info + print(f"--- 어노테이션 기반 스키마 사용 ---") + else: + # 어노테이션이 없는 경우 기본 정보로 대체 + schema_info = f"DBMS 유형: {selected_db_info['profile']['type']}\n" + schema_info += f"호스트: {selected_db_info['profile']['host']}\n" + schema_info += f"포트: {selected_db_info['profile']['port']}\n" + schema_info += "상세 스키마 정보가 없습니다. 기본 SQL 구문을 사용하세요." + state['db_schema'] = schema_info + print(f"--- 기본 DBMS 정보 사용 ---") + + return state + + except Exception as e: + print(f"데이터베이스 분류 실패: {e}") + print(f"에러 타입: {type(e).__name__}") + print(f"에러 상세: {str(e)}") + + # 폴백 없이 에러를 다시 발생시킴 + raise e + + def _convert_annotations_to_schema(self, annotations: dict) -> str: + """어노테이션 데이터를 스키마 문자열로 변환합니다.""" + try: + if not annotations or 'data' not in annotations: + return "어노테이션 스키마 정보가 없습니다." + + # 어노테이션 구조에 따라 스키마 정보 추출 + # 실제 어노테이션 응답 구조를 확인 후 구현 필요 + schema_parts = [] + schema_parts.append("=== 어노테이션 기반 스키마 정보 ===") + + annotation_data = annotations.get('data', {}) + + # 어노테이션 데이터가 데이터베이스 정보를 포함하는 경우 + if isinstance(annotation_data, dict): + for key, value in annotation_data.items(): + schema_parts.append(f"{key}: {str(value)[:200]}...") + elif isinstance(annotation_data, list): + for i, item in enumerate(annotation_data): + schema_parts.append(f"항목 {i+1}: {str(item)[:200]}...") + else: + schema_parts.append(f"어노테이션 데이터: {str(annotation_data)[:500]}...") + + return "\n".join(schema_parts) + + except Exception as e: + print(f"어노테이션 변환 중 오류: {e}") + return f"어노테이션 변환 실패: {e}" + + async def sql_generator_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL 쿼리를 생성하는 노드""" + print("--- 1. SQL 생성 중 ---") + + try: + parser = PydanticOutputParser(pydantic_object=SqlQuery) + + # 에러 피드백 컨텍스트 생성 + error_feedback = self._build_error_feedback(state) + + prompt = self.sql_generator_prompt.format( + format_instructions=parser.get_format_instructions(), + db_schema=state['db_schema'], + chat_history=state['chat_history'], + question=state['question'], + error_feedback=error_feedback + ) + + llm = await self.llm_provider.get_llm() + response = await llm.ainvoke(prompt) + parsed_query = parser.invoke(response.content) + + state['sql_query'] = parsed_query.query + state['validation_error'] = None + state['execution_result'] = None + + return state + + except Exception as e: + raise ExecutionException(f"SQL 생성 실패: {e}") + + def _build_error_feedback(self, state: SqlAgentState) -> str: + """에러 피드백 컨텍스트를 생성합니다.""" + error_feedback = "" + + # 검증 오류가 있었을 경우 + if state.get("validation_error") and state.get("validation_error_count", 0) > 0: + error_feedback = f""" + Your previous query was rejected for the following reason: {state['validation_error']} + Please generate a new, safe query that does not contain forbidden keywords. + """ + # 실행 오류가 있었을 경우 + elif (state.get("execution_result") and + "오류" in state.get("execution_result", "") and + state.get("execution_error_count", 0) > 0): + error_feedback = f""" + Your previously generated SQL query failed with the following database error: + FAILED SQL: {state['sql_query']} + DATABASE ERROR: {state['execution_result']} + Please correct the SQL query based on the error. + """ + + return error_feedback + + async def sql_validator_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL 쿼리의 안전성을 검증하는 노드""" + print("--- 2. SQL 검증 중 ---") + + try: + query_words = state['sql_query'].lower().split() + dangerous_keywords = [ + "drop", "delete", "update", "insert", "truncate", + "alter", "create", "grant", "revoke" + ] + found_keywords = [keyword for keyword in dangerous_keywords if keyword in query_words] + + if found_keywords: + keyword_str = ', '.join(f"'{k}'" for k in found_keywords) + error_msg = f'위험한 키워드 {keyword_str}가 포함되어 있습니다.' + state['validation_error'] = error_msg + state['validation_error_count'] = state.get('validation_error_count', 0) + 1 + + if state['validation_error_count'] >= MAX_ERROR_COUNT: + raise MaxRetryExceededException( + f"SQL 검증 실패가 {MAX_ERROR_COUNT}회 반복됨", MAX_ERROR_COUNT + ) + else: + state['validation_error'] = None + state['validation_error_count'] = 0 + + return state + + except MaxRetryExceededException: + raise + except Exception as e: + raise ValidationException(f"SQL 검증 중 오류 발생: {e}") + + async def sql_executor_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL 쿼리를 실행하는 노드""" + print("--- 3. SQL 실행 중 ---") + + try: + selected_db = state.get('selected_db', 'default') + + # 선택된 DB 프로필에서 실제 DB ID 가져오기 + db_profile = state.get('selected_db_profile') + if db_profile and 'id' in db_profile: + user_db_id = db_profile['id'] + print(f"--- 실행용 DB 프로필 ID: {user_db_id} ---") + else: + user_db_id = 'TEST-USER-DB-12345' # 폴백 + print(f"--- DB 프로필이 없어 테스트 ID 사용: {user_db_id} ---") + + result = await self.database_service.execute_query( + state['sql_query'], + database_name=selected_db, + user_db_id=user_db_id + ) + + state['execution_result'] = result + state['validation_error_count'] = 0 + state['execution_error_count'] = 0 + + return state + + except Exception as e: + error_msg = f"실행 오류: {e}" + state['execution_result'] = error_msg + state['validation_error_count'] = 0 + state['execution_error_count'] = state.get('execution_error_count', 0) + 1 + + print(f"⚠️ SQL 실행 실패 ({state['execution_error_count']}/{MAX_ERROR_COUNT}): {error_msg}") + + # 실행 실패 시에도 상태를 반환하여 엣지에서 판단하도록 함 + return state + + async def response_synthesizer_node(self, state: SqlAgentState) -> SqlAgentState: + """최종 답변을 생성하는 노드""" + print("--- 4. 최종 답변 생성 중 ---") + + try: + is_failure = (state.get('validation_error_count', 0) >= MAX_ERROR_COUNT or + state.get('execution_error_count', 0) >= MAX_ERROR_COUNT) + + if is_failure: + context_message = self._build_failure_context(state) + else: + context_message = f""" + Successfully executed the SQL query to answer the user's question. + SQL Query: {state['sql_query']} + SQL Result: {state['execution_result']} + """ + + prompt = self.response_synthesizer_prompt.format( + question=state['question'], + chat_history=state['chat_history'], + context_message=context_message + ) + + llm = await self.llm_provider.get_llm() + response = await llm.ainvoke(prompt) + state['final_response'] = response.content + + return state + + except Exception as e: + # 최종 답변 생성 실패 시 기본 메시지 제공 + state['final_response'] = f"죄송합니다. 답변 생성 중 오류가 발생했습니다: {e}" + return state + + def _build_failure_context(self, state: SqlAgentState) -> str: + """실패 상황에 대한 컨텍스트 메시지를 생성합니다.""" + if state.get('validation_error_count', 0) >= MAX_ERROR_COUNT: + error_type = "SQL 검증" + error_details = state.get('validation_error') + else: + error_type = "SQL 실행" + error_details = state.get('execution_result') + + return f""" + An attempt to answer the user's question failed after multiple retries. + Failure Type: {error_type} + Last Error: {error_details} + """ diff --git a/src/agents/sql_agent/state.py b/src/agents/sql_agent/state.py new file mode 100644 index 0000000..fdf909e --- /dev/null +++ b/src/agents/sql_agent/state.py @@ -0,0 +1,32 @@ +# src/agents/sql_agent/state.py + +from typing import List, TypedDict, Optional, Dict, Any +from langchain_core.messages import BaseMessage + +class SqlAgentState(TypedDict): + """SQL Agent의 상태를 정의하는 TypedDict""" + + # 입력 정보 + question: str + chat_history: List[BaseMessage] + + # 데이터베이스 관련 + selected_db: Optional[str] + db_schema: str + selected_db_profile: Optional[Dict[str, Any]] # DB 프로필 정보 + selected_db_annotations: Optional[Dict[str, Any]] # DB 어노테이션 정보 + + # 의도 분류 결과 + intent: str + + # SQL 생성 및 검증 + sql_query: str + validation_error: Optional[str] + validation_error_count: int + + # SQL 실행 결과 + execution_result: Optional[str] + execution_error_count: int + + # 최종 응답 + final_response: str diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..365cd08 --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,6 @@ +""" +API 패키지 루트 +""" + + + diff --git a/src/api/v1/__init__.py b/src/api/v1/__init__.py new file mode 100644 index 0000000..cd0ff49 --- /dev/null +++ b/src/api/v1/__init__.py @@ -0,0 +1,6 @@ +""" +API v1 패키지 +""" + + + diff --git a/src/api/v1/routers/annotator.py b/src/api/v1/routers/annotator.py new file mode 100644 index 0000000..e918ce8 --- /dev/null +++ b/src/api/v1/routers/annotator.py @@ -0,0 +1,67 @@ +# src/api/v1/routers/annotator.py + +from fastapi import APIRouter, HTTPException, Depends +from typing import Dict, Any + +from schemas.api.annotator_schemas import AnnotationRequest, AnnotationResponse +from services.annotation.annotation_service import AnnotationService, get_annotation_service +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.post("/annotator", response_model=AnnotationResponse) +async def create_annotations( + request: AnnotationRequest, + service: AnnotationService = Depends(get_annotation_service) +) -> AnnotationResponse: + """ + DB 스키마 정보를 받아 각 요소에 대한 설명을 비동기적으로 생성하여 반환합니다. + + Args: + request: 어노테이션 요청 (DB 스키마 정보) + service: 어노테이션 서비스 로직 + + Returns: + AnnotationResponse: 어노테이션이 추가된 스키마 정보 + + Raises: + HTTPException: 요청 처리 실패 시 + """ + try: + logger.info(f"Received annotation request for {len(request.databases)} databases") + + response = await service.generate_for_schema(request) + + logger.info("Annotation request processed successfully") + + return response + + except Exception as e: + logger.error(f"Annotation request failed: {e}") + raise HTTPException( + status_code=500, + detail=f"어노테이션 생성 중 오류가 발생했습니다: {e}" + ) + +@router.get("/annotator/health") +async def annotator_health_check( + service: AnnotationService = Depends(get_annotation_service) +) -> Dict[str, Any]: + """ + 어노테이션 서비스의 상태를 확인합니다. + + Returns: + Dict: 서비스 상태 정보 + """ + try: + health_status = await service.health_check() + return health_status + + except Exception as e: + logger.error(f"Annotator health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } diff --git a/src/api/v1/routers/chat.py b/src/api/v1/routers/chat.py new file mode 100644 index 0000000..113196b --- /dev/null +++ b/src/api/v1/routers/chat.py @@ -0,0 +1,91 @@ +# src/api/v1/routers/chat.py + +from fastapi import APIRouter, HTTPException, Depends +from typing import Dict, Any, List + +from schemas.api.chat_schemas import ChatRequest, ChatResponse +from services.chat.chatbot_service import ChatbotService, get_chatbot_service +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.post("/chat", response_model=ChatResponse) +async def handle_chat_request( + request: ChatRequest, + service: ChatbotService = Depends(get_chatbot_service) +) -> ChatResponse: + """ + 사용자의 채팅 요청을 받아 챗봇의 답변을 반환합니다. + + Args: + request: 챗봇 요청 (질문과 채팅 히스토리) + service: 챗봇 서비스 로직 + + Returns: + ChatResponse: 챗봇 응답 + + Raises: + HTTPException: 요청 처리 실패 시 + """ + try: + logger.info(f"Received chat request: {request.question[:100]}...") + + final_answer = await service.handle_request( + user_question=request.question, + chat_history=request.chat_history + ) + + logger.info("Chat request processed successfully") + + return ChatResponse(answer=final_answer) + + except Exception as e: + logger.error(f"Chat request failed: {e}") + raise HTTPException( + status_code=500, + detail=f"채팅 요청 처리 중 오류가 발생했습니다: {e}" + ) + +@router.get("/chat/health") +async def chat_health_check( + service: ChatbotService = Depends(get_chatbot_service) +) -> Dict[str, Any]: + """ + 챗봇 서비스의 상태를 확인합니다. + + Returns: + Dict: 서비스 상태 정보 + """ + try: + health_status = await service.health_check() + return health_status + + except Exception as e: + logger.error(f"Chat health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + +@router.get("/chat/databases") +async def get_available_databases( + service: ChatbotService = Depends(get_chatbot_service) +) -> Dict[str, List[Dict[str, str]]]: + """ + 사용 가능한 데이터베이스 목록을 반환합니다. + + Returns: + Dict: 데이터베이스 목록 + """ + try: + databases = await service.get_available_databases() + return {"databases": databases} + + except Exception as e: + logger.error(f"Failed to get databases: {e}") + raise HTTPException( + status_code=500, + detail=f"데이터베이스 목록 조회 중 오류가 발생했습니다: {e}" + ) diff --git a/src/api/v1/routers/health.py b/src/api/v1/routers/health.py new file mode 100644 index 0000000..1ea2ce0 --- /dev/null +++ b/src/api/v1/routers/health.py @@ -0,0 +1,77 @@ +# src/api/v1/routers/health.py + +from fastapi import APIRouter, Depends +from typing import Dict, Any + +from services.chat.chatbot_service import ChatbotService, get_chatbot_service +from services.annotation.annotation_service import AnnotationService, get_annotation_service +from services.database.database_service import DatabaseService, get_database_service +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.get("/health") +async def root_health_check() -> Dict[str, str]: + """ + 루트 헬스체크 엔드포인트, 서버 상태가 정상이면 'ok' 반환합니다. + + Returns: + Dict: 기본 상태 정보 + """ + return { + "status": "ok", + "message": "Welcome to the QGenie Chatbot AI!", + "version": "2.0.0" + } + +@router.get("/health/detailed") +async def detailed_health_check( + chatbot_service: ChatbotService = Depends(get_chatbot_service), + annotation_service: AnnotationService = Depends(get_annotation_service), + database_service: DatabaseService = Depends(get_database_service) +) -> Dict[str, Any]: + """ + 전체 시스템의 상세 헬스체크를 수행합니다. + + Returns: + Dict: 상세 상태 정보 + """ + try: + # 모든 서비스의 헬스체크를 병렬로 실행 + import asyncio + + chatbot_health, annotation_health, database_health = await asyncio.gather( + chatbot_service.health_check(), + annotation_service.health_check(), + database_service.health_check(), + return_exceptions=True + ) + + # 각 서비스 상태 처리 + services_status = { + "chatbot": chatbot_health if not isinstance(chatbot_health, Exception) else {"status": "unhealthy", "error": str(chatbot_health)}, + "annotation": annotation_health if not isinstance(annotation_health, Exception) else {"status": "unhealthy", "error": str(annotation_health)}, + "database": {"status": "healthy" if database_health and not isinstance(database_health, Exception) else "unhealthy"} + } + + # 전체 상태 결정 + all_healthy = all( + service.get("status") == "healthy" + for service in services_status.values() + ) + + return { + "status": "healthy" if all_healthy else "partial", + "services": services_status, + "timestamp": __import__("datetime").datetime.now().isoformat() + } + + except Exception as e: + logger.error(f"Detailed health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e), + "timestamp": __import__("datetime").datetime.now().isoformat() + } diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..9c22b0b --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,15 @@ +# src/core/__init__.py + +""" +코어 모듈 - 기본 인프라스트럭처 구성 요소들 +""" + +from .providers.llm_provider import LLMProvider, get_llm_provider +from .clients.api_client import APIClient, get_api_client + +__all__ = [ + 'LLMProvider', + 'get_llm_provider', + 'APIClient', + 'get_api_client' +] diff --git a/src/core/clients/api_client.py b/src/core/clients/api_client.py new file mode 100644 index 0000000..48b0f3b --- /dev/null +++ b/src/core/clients/api_client.py @@ -0,0 +1,339 @@ +# src/core/clients/api_client.py + +import httpx +import asyncio +from typing import List, Dict, Any, Optional, Union +from pydantic import BaseModel +import logging + +# 로깅 설정 +logger = logging.getLogger(__name__) + +class DatabaseInfo(BaseModel): + """데이터베이스 정보 모델""" + connection_name: str + database_name: str + description: str + +class DBProfileInfo(BaseModel): + """DB 프로필 정보 모델""" + id: str + type: str + host: Optional[str] = None + port: Optional[int] = None + name: Optional[str] = None + username: Optional[str] = None + view_name: Optional[str] = None + created_at: str + updated_at: str + +class DBProfileResponse(BaseModel): + """DB 프로필 조회 응답 모델""" + code: str + message: str + data: List[DBProfileInfo] + +class QueryExecutionRequest(BaseModel): + """쿼리 실행 요청 모델""" + user_db_id: str + database: str + query_text: str + +class QueryResultData(BaseModel): + """쿼리 실행 결과 데이터 모델""" + columns: List[str] + data: List[Dict[str, Any]] + +class QueryExecutionResponse(BaseModel): + """쿼리 실행 응답 모델""" + code: str + message: str + data: Union[QueryResultData, str, bool] # 결과 데이터, 에러 메시지 + +class APIClient: + """백엔드 API와 통신하는 클라이언트 클래스""" + + def __init__(self, base_url: str = "http://localhost:39722"): + self.base_url = base_url + self.timeout = httpx.Timeout(30.0) + self.headers = { + "Content-Type": "application/json" + } + self._client: Optional[httpx.AsyncClient] = None + + async def _get_client(self) -> httpx.AsyncClient: + """재사용 가능한 HTTP 클라이언트를 반환합니다.""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=self.timeout) + return self._client + + async def close(self): + """HTTP 클라이언트 연결을 닫습니다.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + async def get_db_profiles(self) -> List[DBProfileInfo]: + """모든 DBMS 프로필 정보를 가져옵니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/api/user/db/find/all", + headers=self.headers + ) + response.raise_for_status() + + data = response.json() + + # 응답 구조 검증 + if data.get("code") != "2102": + logger.warning(f"Unexpected response code: {data.get('code')}") + + profiles = [DBProfileInfo(**profile) for profile in data.get("data", [])] + logger.info(f"Successfully fetched {len(profiles)} DB profiles") + return profiles + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise + + async def get_db_annotations(self, db_profile_id: str) -> Dict[str, Any]: + """특정 DBMS의 어노테이션을 조회합니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/api/annotations/find/db/{db_profile_id}", + headers=self.headers + ) + response.raise_for_status() + + data = response.json() + logger.info(f"Successfully fetched annotations for DB profile: {db_profile_id}") + return data + + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + # 404는 어노테이션이 없는 정상적인 상황 + logger.info(f"No annotations found for DB profile {db_profile_id}: {e.response.text}") + return {"code": "4401", "message": "어노테이션이 없습니다", "data": []} + else: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise + + async def get_available_databases(self) -> List[DatabaseInfo]: + """ + [DEPRECATED] 사용 가능한 데이터베이스 목록을 가져옵니다. + 대신 get_db_profiles()를 사용하세요. + """ + logger.warning("get_available_databases()는 deprecated입니다. get_db_profiles()를 사용하세요.") + + # DBMS 프로필 기반으로 DatabaseInfo 형태로 변환하여 호환성 유지 + try: + profiles = await self.get_db_profiles() + databases = [] + + for profile in profiles: + db_info = DatabaseInfo( + connection_name=f"{profile.type}_{profile.host}_{profile.port}", + database_name=profile.view_name or f"{profile.type}_db", + description=f"{profile.type} 데이터베이스 ({profile.host}:{profile.port})" + ) + databases.append(db_info) + + logger.info(f"Successfully converted {len(databases)} DB profiles to DatabaseInfo") + return databases + + except Exception as e: + logger.error(f"Failed to convert DB profiles: {e}") + raise + # TODO: DB 스키마 조회 API 필요 + async def get_database_schema(self, database_name: str) -> str: + """특정 데이터베이스의 스키마 정보를 가져옵니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/api/v1/databases/{database_name}/schema", + headers=self.headers + ) + response.raise_for_status() + + data = response.json() + schema = data.get("schema", "") + logger.info(f"Successfully fetched schema for database: {database_name}") + return schema + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise + + async def execute_query( + self, + sql_query: str, + database_name: str, + user_db_id: str = None + ) -> QueryExecutionResponse: + """SQL 쿼리를 Backend 서버에 전송하여 실행하고 결과를 받아옵니다.""" + try: + logger.info(f"Sending SQL query to backend: {sql_query}") + + request_data = QueryExecutionRequest( + user_db_id=user_db_id, + database=database_name, + query_text=sql_query + ) + + client = await self._get_client() + response = await client.post( + f"{self.base_url}/api/query/execute/test", + json=request_data.model_dump(), + headers=self.headers, + timeout=httpx.Timeout(35.0) # 고정 타임아웃 + ) + + response.raise_for_status() # HTTP 에러 시 예외 발생 + + response_data = response.json() + + # data 필드 타입에 따라 처리 + raw_data = response_data.get("data") + parsed_data = raw_data + + # data가 객체 형태(쿼리 결과)인지 확인 + if isinstance(raw_data, dict) and "columns" in raw_data and "data" in raw_data: + try: + parsed_data = QueryResultData(**raw_data) + except Exception as e: + logger.warning(f"Failed to parse query result data: {e}, using raw data") + parsed_data = raw_data + + result = QueryExecutionResponse( + code=response_data.get("code"), + message=response_data.get("message"), + data=parsed_data + ) + + if result.code == "2400": + logger.info(f"Query executed successfully: {result.message}") + else: + logger.warning(f"Query execution returned non-success code: {result.code} - {result.message}") + + return result + + except httpx.TimeoutException: + logger.error("Backend API 요청 시간 초과") + raise + except httpx.ConnectError: + logger.error("Backend 서버 연결 실패") + raise + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except Exception as e: + logger.error(f"Unexpected error during query execution: {e}") + raise + + async def health_check(self) -> bool: + """API 서버 상태를 확인합니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/health", + timeout=httpx.Timeout(5.0) + ) + return response.status_code == 200 + except Exception as e: + logger.error(f"Health check failed: {e}") + return False + + async def get_openai_api_key(self) -> str: + """백엔드에서 OpenAI API 키를 가져옵니다.""" + try: + client = await self._get_client() + + # 1단계: 암호화된 API 키 조회 + response = await client.get( + f"{self.base_url}/api/keys/find", + headers=self.headers, + timeout=httpx.Timeout(10.0) + ) + response.raise_for_status() + + data = response.json() + + # data 배열에서 OpenAI 서비스 찾기 + api_keys = data.get("data", []) + openai_key = None + + # 가장 첫번째 OpenAI 키 사용 + for key_info in api_keys: + if key_info.get("service_name") == "OpenAI": + openai_key = key_info.get("id") + break + + if not openai_key: + raise ValueError("백엔드에서 OpenAI API 키를 찾을 수 없습니다.") + + # 2단계: 복호화된 실제 API 키 조회 + decrypt_response = await client.get( + f"{self.base_url}/api/keys/find/decrypted/OpenAI", + headers=self.headers, + timeout=httpx.Timeout(10.0) + ) + decrypt_response.raise_for_status() + + decrypt_data = decrypt_response.json() + + # 복호화된 키 데이터에서 실제 API 키 추출 + data_field = decrypt_data.get("data", {}) + + if isinstance(data_field, dict) and "api_key" in data_field: + actual_api_key = data_field["api_key"] + else: + raise ValueError("백엔드 응답에서 API 키를 찾을 수 없습니다.") + + if not actual_api_key: + raise ValueError("백엔드에서 복호화된 OpenAI API 키를 가져올 수 없습니다.") + + logger.info("Successfully fetched decrypted OpenAI API key from backend") + return actual_api_key + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred while fetching API key: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred while fetching API key: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error while fetching API key: {e}") + raise + + async def __aenter__(self): + """비동기 컨텍스트 매니저 진입""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """비동기 컨텍스트 매니저 종료""" + await self.close() + +# 싱글톤 인스턴스 +_api_client = APIClient() + +async def get_api_client() -> APIClient: + """API Client 인스턴스를 반환합니다.""" + return _api_client diff --git a/src/core/providers/llm_provider.py b/src/core/providers/llm_provider.py new file mode 100644 index 0000000..0980a34 --- /dev/null +++ b/src/core/providers/llm_provider.py @@ -0,0 +1,114 @@ +# src/core/providers/llm_provider.py + +import os +import asyncio +import logging +from typing import Optional +from langchain_openai import ChatOpenAI +from core.clients.api_client import get_api_client + +logger = logging.getLogger(__name__) + +class LLMProvider: + """ + LLM 제공자를 관리하는 클래스 + 지연 초기화를 지원하여 BE 서버가 늦게 시작되어도 작동합니다. + """ + + def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0): + self.model_name = model_name + self.temperature = temperature + self._llm: Optional[ChatOpenAI] = None + self._api_key: Optional[str] = None + self._api_client = None + self._initialization_attempted: bool = False + self._initialization_failed: bool = False + + async def _load_api_key(self) -> str: + """백엔드에서 OpenAI API 키를 로드합니다.""" + try: + if self._api_key is None: + if self._api_client is None: + self._api_client = await get_api_client() + + self._api_key = await self._api_client.get_openai_api_key() + return self._api_key + + except Exception as e: + logger.error(f"Failed to fetch API key from backend: {e}") + raise ValueError("백엔드에서 OpenAI API 키를 가져올 수 없습니다. 백엔드 서버를 확인해주세요.") + + async def get_llm(self) -> ChatOpenAI: + """ + ChatOpenAI 인스턴스를 반환합니다. + 지연 초기화를 통해 BE 서버 연결이 실패해도 재시도합니다. + """ + if self._llm is None: + # 이전에 초기화를 시도했고 실패했다면 재시도 + if self._initialization_failed: + logger.info("🔄 LLM 초기화 재시도 중...") + self._initialization_failed = False + self._initialization_attempted = False + + try: + self._initialization_attempted = True + self._llm = await self._create_llm() + self._initialization_failed = False + logger.info("✅ LLM 초기화 성공") + + except Exception as e: + self._initialization_failed = True + logger.error(f"❌ LLM 초기화 실패: {e}") + raise RuntimeError(f"LLM을 초기화할 수 없습니다. 백엔드 서버가 실행 중인지 확인해주세요: {e}") + + return self._llm + + async def _create_llm(self) -> ChatOpenAI: + """ChatOpenAI 인스턴스를 생성합니다.""" + try: + # API 키를 비동기적으로 로드 + api_key = await self._load_api_key() + logger.info("✅ 백엔드에서 OpenAI API 키를 성공적으로 가져왔습니다") + + llm = ChatOpenAI( + model=self.model_name, + temperature=self.temperature, + api_key=api_key + ) + return llm + + except Exception as e: + raise RuntimeError(f"LLM 인스턴스 생성 실패: {e}") + + def update_model(self, model_name: str, temperature: float = None): + """모델 설정을 업데이트하고 인스턴스를 재생성합니다.""" + self.model_name = model_name + if temperature is not None: + self.temperature = temperature + self._llm = None # 다음 호출 시 재생성되도록 함 + + async def refresh_api_key(self): + """API 키를 새로고침합니다.""" + self._api_key = None + self._llm = None # LLM 인스턴스도 재생성 + self._initialization_attempted = False + self._initialization_failed = False + logger.info("API key refreshed") + + async def test_connection(self) -> bool: + """LLM 연결을 테스트합니다.""" + try: + llm = await self.get_llm() + test_response = await llm.ainvoke("테스트") + return test_response is not None + + except Exception as e: + print(f"LLM 연결 테스트 실패: {e}") + return False + +# 싱글톤 인스턴스 +_llm_provider = LLMProvider() + +async def get_llm_provider() -> LLMProvider: + """LLM Provider 인스턴스를 반환합니다.""" + return _llm_provider diff --git a/src/health_check/__init__.py b/src/health_check/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/health_check/router.py b/src/health_check/router.py deleted file mode 100644 index 8fbbb2e..0000000 --- a/src/health_check/router.py +++ /dev/null @@ -1,9 +0,0 @@ -# src/health_check/router.py -from flask import Flask, jsonify - -app = Flask(__name__) - -@app.route("/health") -def health_check(): - """헬스체크 엔드포인트, 서버 상태가 정상이면 'ok'를 반환합니다.""" - return jsonify(status="ok"), 200 \ No newline at end of file diff --git a/src/main.py b/src/main.py index 2b7529d..672f01d 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,113 @@ # src/main.py -from health_check.router import app + +import logging +from contextlib import asynccontextmanager +from fastapi import FastAPI + +from api.v1.routers import chat, annotator, health + +# 로깅 설정 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """애플리케이션 라이프사이클 관리""" + logger.info("QGenie AI Chatbot 시작 중...") + + # 시작 시 BE 서버 연결 체크 + try: + from core.clients.api_client import get_api_client + api_client = await get_api_client() + + # BE 서버 상태 확인 + if await api_client.health_check(): + logger.info("✅ 백엔드 서버 연결 성공") + else: + logger.warning("⚠️ 백엔드 서버에 연결할 수 없습니다. 첫 요청 시 연결을 재시도합니다.") + + except Exception as e: + logger.warning(f"⚠️ 백엔드 서버 초기 연결 실패: {e}") + logger.info("🔄 서비스는 지연 초기화 모드로 시작됩니다.") + + try: + logger.info("애플리케이션 초기화 완료") + yield + finally: + # 종료 시 정리 작업 + logger.info("애플리케이션 종료 중...") + + # API 클라이언트 정리 + try: + from core.clients.api_client import get_api_client + api_client = await get_api_client() + await api_client.close() + logger.info("API 클라이언트 정리 완료") + except Exception as e: + logger.error(f"API 클라이언트 정리 실패: {e}") + + logger.info("애플리케이션 종료 완료") + +# FastAPI 앱 인스턴스 생성 +app = FastAPI( + title="QGenie - Agentic SQL Chatbot", + description="LangGraph로 구현된 사전 스키마를 지원하는 SQL 챗봇 (리팩터링 버전)", + version="2.0.0", + lifespan=lifespan +) + +# 라우터 등록 +app.include_router( + health.router, + prefix="/api/v1", + tags=["Health"] +) + +app.include_router( + chat.router, + prefix="/api/v1", + tags=["Chatbot"] +) + +app.include_router( + annotator.router, + prefix="/api/v1", + tags=["Annotator"] +) + +# 루트 엔드포인트 +@app.get("/") +async def root(): + """루트 엔드포인트 - 기본 상태 확인""" + return { + "status": "ok", + "message": "Welcome to the QGenie Chatbot AI! (Refactored)", + "version": "2.0.0", + "endpoints": { + "chat": "/api/v1/chat", + "annotator": "/api/v1/annotator", + "health": "/api/v1/health", + "detailed_health": "/api/v1/health/detailed" + } + } if __name__ == "__main__": - # 8080 포트에서 헬스체크 앱 실행 - app.run(host="0.0.0.0", port=33332) \ No newline at end of file + import uvicorn + + # 포트 번호 고정 (기존 설정 유지) + free_port = 35816 + + # 할당된 포트 번호를 콘솔에 특정 형식으로 출력 (Electron 연동을 위해) + print(f"PYTHON_SERVER_PORT:{free_port}") + + # FastAPI 서버 실행 + uvicorn.run( + app, + host="127.0.0.1", + port=free_port, + reload=False, + log_level="info" + ) \ No newline at end of file diff --git a/src/prompts/v1/sql_agent/db_classifier.yaml b/src/prompts/v1/sql_agent/db_classifier.yaml new file mode 100644 index 0000000..aa89aaa --- /dev/null +++ b/src/prompts/v1/sql_agent/db_classifier.yaml @@ -0,0 +1,19 @@ +_type: "prompt" +input_variables: + - db_options + - chat_history + - question +template: | + Based on the user's question, which of the following databases is most likely to contain the answer? + Please respond with only the database name. + + Available databases: + {db_options} + + conversation History: + {chat_history} + + User Question: + {question} + + Selected Database: \ No newline at end of file diff --git a/src/prompts/v1/sql_agent/intent_classifier.yaml b/src/prompts/v1/sql_agent/intent_classifier.yaml new file mode 100644 index 0000000..9aa3a34 --- /dev/null +++ b/src/prompts/v1/sql_agent/intent_classifier.yaml @@ -0,0 +1,39 @@ + +_type: prompt +input_variables: + - question + - chat_history +template: | + You are an intelligent assistant responsible for classifying user questions. + Your task is to determine whether a user's question is related to retrieving information from a database using SQL. + + - If the question can be answered with a SQL query, respond with "SQL". + - If the question is a simple greeting, a question about your identity, or anything that does not require database access, respond with "non-SQL". + + Consider the chat history context when classifying the current question. + If the current question is a follow-up or continuation of a previous SQL-related conversation, classify it as "SQL". + + Example 1: + Question: "Show me the list of users who signed up last month." + Classification: SQL + + Example 2: + Question: "What is the total revenue for the last quarter?" + Classification: SQL + + Example 3: + Question: "Hello, who are you?" + Classification: non-SQL + + Example 4: + Question: "What is the weather like today?" + Classification: non-SQL + + Example 5 (Follow-up): + Previous: "Show me sales data for January" + Current: "How about February?" + Classification: SQL (continuation of data query) + + Chat History: {chat_history} + Current Question: {question} + Classification: diff --git a/src/prompts/v1/sql_agent/response_synthesizer.yaml b/src/prompts/v1/sql_agent/response_synthesizer.yaml new file mode 100644 index 0000000..2fd970c --- /dev/null +++ b/src/prompts/v1/sql_agent/response_synthesizer.yaml @@ -0,0 +1,29 @@ +_type: prompt +input_variables: + - question + - chat_history + - context_message +template: | + You are a friendly and helpful database assistant chatbot. + Your goal is to provide a clear and easy-to-understand final answer to the user in Korean. + Please carefully analyze the user's question and the provided context below. + + User's Question: {question} + + Context: + {context_message} + + conversation History: + {chat_history} + + Instructions: + - If the process was successful: + - Do not just show the raw data from the SQL result. + - Explain what the data means in relation to the user's question. + - Present the answer in a natural, conversational, and polite Korean. + - If the process failed: + - Apologize for the inconvenience. + - Explain the reason for the failure in simple, non-technical terms. + - Gently suggest trying a different or simpler question. + + Final Answer (in Korean): \ No newline at end of file diff --git a/src/prompts/v1/sql_agent/sql_generator.yaml b/src/prompts/v1/sql_agent/sql_generator.yaml new file mode 100644 index 0000000..3b9ca84 --- /dev/null +++ b/src/prompts/v1/sql_agent/sql_generator.yaml @@ -0,0 +1,19 @@ +_type: prompt +input_variables: + - format_instructions + - db_schema + - chat_history + - question + - error_feedback +template: | + You are a powerful text-to-SQL model. + Your role is to generate a SQL query based on the provided database schema and user question. + + Schema: {db_schema} + History: {chat_history} + + {error_feedback} + + Question: {question} + + {format_instructions} \ No newline at end of file diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000..2aec2f4 --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1,6 @@ +""" +스키마 루트 패키지 +""" + + + diff --git a/src/schemas/agent/sql_schemas.py b/src/schemas/agent/sql_schemas.py new file mode 100644 index 0000000..0fcbc2b --- /dev/null +++ b/src/schemas/agent/sql_schemas.py @@ -0,0 +1,7 @@ +# src/schemas/agent/sql_schemas.py + +from pydantic import BaseModel, Field + +class SqlQuery(BaseModel): + """SQL 쿼리를 나타내는 Pydantic 모델""" + query: str = Field(description="생성된 SQL 쿼리") diff --git a/src/schemas/api/__init__.py b/src/schemas/api/__init__.py new file mode 100644 index 0000000..93577a6 --- /dev/null +++ b/src/schemas/api/__init__.py @@ -0,0 +1,6 @@ +""" +API 스키마 패키지 +""" + + + diff --git a/src/schemas/api/annotator_schemas.py b/src/schemas/api/annotator_schemas.py new file mode 100644 index 0000000..6dda6d7 --- /dev/null +++ b/src/schemas/api/annotator_schemas.py @@ -0,0 +1,60 @@ +# src/schemas/api/annotator_schemas.py + +from pydantic import BaseModel, Field +from typing import List, Dict, Any + +class Column(BaseModel): + """데이터베이스 컬럼 모델""" + column_name: str + data_type: str + +class Table(BaseModel): + """데이터베이스 테이블 모델""" + table_name: str + columns: List[Column] + sample_rows: List[Dict[str, Any]] + +class Relationship(BaseModel): + """테이블 관계 모델""" + from_table: str + from_columns: List[str] + to_table: str + to_columns: List[str] + +class Database(BaseModel): + """데이터베이스 모델""" + database_name: str + tables: List[Table] + relationships: List[Relationship] + +class AnnotationRequest(BaseModel): + """어노테이션 요청 모델""" + dbms_type: str + databases: List[Database] + +class AnnotatedColumn(BaseModel): + """어노테이션이 추가된 컬럼 모델""" + column_name: str + description: str = Field(..., description="AI가 생성한 컬럼 설명") + +class AnnotatedTable(BaseModel): + """어노테이션이 추가된 테이블 모델""" + table_name: str + description: str = Field(..., description="AI가 생성한 테이블 설명") + columns: List[AnnotatedColumn] + +class AnnotatedRelationship(Relationship): + """어노테이션이 추가된 관계 모델""" + description: str = Field(..., description="AI가 생성한 관계 설명") + +class AnnotatedDatabase(BaseModel): + """어노테이션이 추가된 데이터베이스 모델""" + database_name: str + description: str = Field(..., description="AI가 생성한 데이터베이스 설명") + tables: List[AnnotatedTable] + relationships: List[AnnotatedRelationship] + +class AnnotationResponse(BaseModel): + """어노테이션 응답 모델""" + dbms_type: str + databases: List[AnnotatedDatabase] diff --git a/src/schemas/api/chat_schemas.py b/src/schemas/api/chat_schemas.py new file mode 100644 index 0000000..c18d956 --- /dev/null +++ b/src/schemas/api/chat_schemas.py @@ -0,0 +1,18 @@ +# src/schemas/api/chat_schemas.py + +from pydantic import BaseModel +from typing import List, Optional + +class ChatMessage(BaseModel): + """대화 기록의 단일 메시지를 나타내는 모델""" + role: str # "user" 또는 "assistant" + content: str + +class ChatRequest(BaseModel): + """채팅 요청 모델""" + question: str + chat_history: Optional[List[ChatMessage]] = None + +class ChatResponse(BaseModel): + """채팅 응답 모델""" + answer: str diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..5d88207 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,18 @@ +# src/services/__init__.py + +""" +서비스 계층 - 비즈니스 로직 구성 요소들 +""" + +from .chat.chatbot_service import ChatbotService, get_chatbot_service +from .annotation.annotation_service import AnnotationService, get_annotation_service +from .database.database_service import DatabaseService, get_database_service + +__all__ = [ + 'ChatbotService', + 'get_chatbot_service', + 'AnnotationService', + 'get_annotation_service', + 'DatabaseService', + 'get_database_service' +] diff --git a/src/services/annotation/__init__.py b/src/services/annotation/__init__.py new file mode 100644 index 0000000..85f7091 --- /dev/null +++ b/src/services/annotation/__init__.py @@ -0,0 +1,6 @@ +""" +어노테이션 서비스 패키지 +""" + + + diff --git a/src/services/annotation/annotation_service.py b/src/services/annotation/annotation_service.py new file mode 100644 index 0000000..356e3f9 --- /dev/null +++ b/src/services/annotation/annotation_service.py @@ -0,0 +1,260 @@ +# src/services/annotation/annotation_service.py + +import asyncio +from typing import List, Dict, Any +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser + +from schemas.api.annotator_schemas import ( + AnnotationRequest, AnnotationResponse, + Database, Table, Column, Relationship, + AnnotatedDatabase, AnnotatedTable, AnnotatedColumn, AnnotatedRelationship +) +from core.providers.llm_provider import LLMProvider, get_llm_provider +import logging + +logger = logging.getLogger(__name__) + +class AnnotationService: + """어노테이션 생성과 관련된 모든 비즈니스 로직을 담당하는 서비스 클래스""" + + def __init__(self, llm_provider: LLMProvider = None): + self.llm_provider = llm_provider + + async def _initialize_dependencies(self): + """필요한 의존성들을 초기화합니다.""" + if self.llm_provider is None: + self.llm_provider = await get_llm_provider() + + async def _generate_description(self, template: str, **kwargs) -> str: + """LLM을 비동기적으로 호출하여 설명을 생성하는 헬퍼 함수""" + try: + await self._initialize_dependencies() + + prompt = ChatPromptTemplate.from_template(template) + llm = await self.llm_provider.get_llm() + chain = prompt | llm | StrOutputParser() + + result = await chain.ainvoke(kwargs) + return result.strip() + + except Exception as e: + logger.error(f"Failed to generate description: {e}") + return f"설명 생성 실패: {e}" + + async def _annotate_column( + self, + table_name: str, + sample_rows: str, + column: Column + ) -> AnnotatedColumn: + """단일 컬럼을 비동기적으로 어노테이트합니다.""" + try: + column_desc = await self._generate_description( + """ + 테이블 '{table_name}'의 컬럼 '{column_name}'(타입: {data_type})의 역할을 한국어로 간결하게 설명해줘. + 샘플 데이터: {sample_rows} + """, + table_name=table_name, + column_name=column.column_name, + data_type=column.data_type, + sample_rows=sample_rows + ) + + return AnnotatedColumn( + **column.model_dump(), + description=column_desc + ) + + except Exception as e: + logger.error(f"Failed to annotate column {column.column_name}: {e}") + return AnnotatedColumn( + **column.model_dump(), + description=f"설명 생성 실패: {e}" + ) + + async def _annotate_table(self, db_name: str, table: Table) -> AnnotatedTable: + """단일 테이블과 그 컬럼들을 비동기적으로 어노테이트합니다.""" + try: + sample_rows_str = str(table.sample_rows[:3]) + + # 테이블 설명 생성과 모든 컬럼 설명을 동시에 병렬로 처리 + table_desc_task = self._generate_description( + "데이터베이스 '{db_name}'에 속한 테이블 '{table_name}'의 역할을 한국어로 간결하게 설명해줘.", + db_name=db_name, + table_name=table.table_name + ) + + column_tasks = [ + self._annotate_column(table.table_name, sample_rows_str, col) + for col in table.columns + ] + + # 모든 작업을 병렬 실행 + results = await asyncio.gather( + table_desc_task, + *column_tasks, + return_exceptions=True + ) + + # 결과 처리 + table_desc = results[0] if not isinstance(results[0], Exception) else "테이블 설명 생성 실패" + annotated_columns = [ + result for result in results[1:] + if not isinstance(result, Exception) + ] + + return AnnotatedTable( + **table.model_dump(exclude={'columns'}), + description=table_desc, + columns=annotated_columns + ) + + except Exception as e: + logger.error(f"Failed to annotate table {table.table_name}: {e}") + # 실패 시 기본 어노테이션 반환 + annotated_columns = [ + AnnotatedColumn(**col.model_dump(), description="설명 생성 실패") + for col in table.columns + ] + return AnnotatedTable( + **table.model_dump(exclude={'columns'}), + description=f"테이블 설명 생성 실패: {e}", + columns=annotated_columns + ) + + async def _annotate_relationship(self, relationship: Relationship) -> AnnotatedRelationship: + """단일 관계를 비동기적으로 어노테이트합니다.""" + try: + rel_desc = await self._generate_description( + """ + 테이블 '{from_table}'이(가) 테이블 '{to_table}'을(를) 참조하고 있습니다. + 이 관계를 한국어 문장으로 설명해줘. + """, + from_table=relationship.from_table, + to_table=relationship.to_table + ) + + return AnnotatedRelationship( + **relationship.model_dump(), + description=rel_desc + ) + + except Exception as e: + logger.error(f"Failed to annotate relationship: {e}") + return AnnotatedRelationship( + **relationship.model_dump(), + description=f"관계 설명 생성 실패: {e}" + ) + + async def generate_for_schema(self, request: AnnotationRequest) -> AnnotationResponse: + """입력된 스키마 전체에 대한 어노테이션을 비동기적으로 생성합니다.""" + try: + logger.info(f"Starting annotation generation for {len(request.databases)} databases") + + annotated_databases = [] + + for db in request.databases: + try: + # DB 설명, 모든 테이블, 모든 관계 설명을 동시에 병렬로 처리 + db_desc_task = self._generate_description( + "데이터베이스 '{db_name}'의 역할을 한국어로 간결하게 설명해줘.", + db_name=db.database_name + ) + + table_tasks = [ + self._annotate_table(db.database_name, table) + for table in db.tables + ] + + relationship_tasks = [ + self._annotate_relationship(rel) + for rel in db.relationships + ] + + # 모든 작업을 병렬 실행 + all_results = await asyncio.gather( + db_desc_task, + *table_tasks, + *relationship_tasks, + return_exceptions=True + ) + + # 결과 분리 + db_desc = all_results[0] if not isinstance(all_results[0], Exception) else "DB 설명 생성 실패" + + num_tables = len(table_tasks) + annotated_tables = [ + result for result in all_results[1:1+num_tables] + if not isinstance(result, Exception) + ] + + annotated_relationships = [ + result for result in all_results[1+num_tables:] + if not isinstance(result, Exception) + ] + + annotated_databases.append( + AnnotatedDatabase( + database_name=db.database_name, + description=db_desc, + tables=annotated_tables, + relationships=annotated_relationships + ) + ) + + logger.info(f"Completed annotation for database: {db.database_name}") + + except Exception as e: + logger.error(f"Failed to annotate database {db.database_name}: {e}") + # 실패한 데이터베이스도 기본값으로 포함 + annotated_databases.append( + AnnotatedDatabase( + database_name=db.database_name, + description=f"데이터베이스 어노테이션 생성 실패: {e}", + tables=[], + relationships=[] + ) + ) + + logger.info("Annotation generation completed successfully") + + return AnnotationResponse( + dbms_type=request.dbms_type, + databases=annotated_databases + ) + + except Exception as e: + logger.error(f"Failed to generate annotations: {e}") + # 전체 실패 시 기본 응답 반환 + return AnnotationResponse( + dbms_type=request.dbms_type, + databases=[] + ) + + async def health_check(self) -> Dict[str, Any]: + """어노테이션 서비스의 상태를 확인합니다.""" + try: + await self._initialize_dependencies() + + # LLM 연결 테스트 + llm_status = await self.llm_provider.test_connection() + + return { + "status": "healthy" if llm_status else "unhealthy", + "llm_provider": "connected" if llm_status else "disconnected" + } + + except Exception as e: + logger.error(f"Annotation service health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + +# 싱글톤 인스턴스 +_annotation_service = AnnotationService() + +async def get_annotation_service() -> AnnotationService: + """Annotation Service 인스턴스를 반환합니다.""" + return _annotation_service diff --git a/src/services/chat/__init__.py b/src/services/chat/__init__.py new file mode 100644 index 0000000..3a90e18 --- /dev/null +++ b/src/services/chat/__init__.py @@ -0,0 +1,6 @@ +""" +챗 서비스 패키지 +""" + + + diff --git a/src/services/chat/chatbot_service.py b/src/services/chat/chatbot_service.py new file mode 100644 index 0000000..5f2d3ff --- /dev/null +++ b/src/services/chat/chatbot_service.py @@ -0,0 +1,142 @@ +# src/services/chat/chatbot_service.py + +import asyncio +from typing import List, Optional, Dict, Any +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage + +from schemas.api.chat_schemas import ChatMessage +from agents.sql_agent.graph import SqlAgentGraph +from core.providers.llm_provider import LLMProvider, get_llm_provider +from services.database.database_service import DatabaseService, get_database_service +import logging + +logger = logging.getLogger(__name__) + +class ChatbotService: + """챗봇 관련 비즈니스 로직을 담당하는 서비스 클래스""" + + def __init__( + self, + llm_provider: LLMProvider = None, + database_service: DatabaseService = None + ): + self.llm_provider = llm_provider + self.database_service = database_service + self._sql_agent_graph: Optional[SqlAgentGraph] = None + + async def _initialize_dependencies(self): + """필요한 의존성들을 초기화합니다.""" + if self.llm_provider is None: + self.llm_provider = await get_llm_provider() + + if self.database_service is None: + self.database_service = await get_database_service() + + if self._sql_agent_graph is None: + self._sql_agent_graph = SqlAgentGraph( + self.llm_provider, + self.database_service + ) + + async def handle_request( + self, + user_question: str, + chat_history: Optional[List[ChatMessage]] = None + ) -> str: + """채팅 요청을 처리하고 응답을 반환합니다.""" + try: + # 의존성 초기화 + await self._initialize_dependencies() + + # 채팅 히스토리를 LangChain 메시지로 변환 + langchain_messages = await self._convert_chat_history(chat_history) + + # 초기 상태 구성 + initial_state = { + "question": user_question, + "chat_history": langchain_messages, + "validation_error_count": 0, + "execution_error_count": 0 + } + + # SQL Agent 그래프 실행 + final_state = await self._sql_agent_graph.run(initial_state) + + return final_state.get('final_response', "죄송합니다. 응답을 생성할 수 없습니다.") + + except Exception as e: + logger.error(f"Chat request handling failed: {e}") + # 에러 상황에서는 예외를 다시 발생시켜 라우터에서 HTTP 에러로 처리되도록 함 + raise e + + async def _convert_chat_history( + self, + chat_history: Optional[List[ChatMessage]] + ) -> List[BaseMessage]: + """채팅 히스토리를 LangChain 메시지 형식으로 변환합니다.""" + langchain_messages: List[BaseMessage] = [] + + if chat_history: + for message in chat_history: + try: + if message.role == 'u': + langchain_messages.append(HumanMessage(content=message.content)) + elif message.role == 'a': + langchain_messages.append(AIMessage(content=message.content)) + except Exception as e: + logger.warning(f"Failed to convert message: {e}") + continue + + return langchain_messages + + async def health_check(self) -> Dict[str, Any]: + """챗봇 서비스의 상태를 확인합니다.""" + try: + await self._initialize_dependencies() + + # LLM 연결 테스트 + llm_status = await self.llm_provider.test_connection() + + # 데이터베이스 서비스 상태 확인 + db_status = await self.database_service.health_check() + + overall_status = llm_status and db_status + + return { + "status": "healthy" if overall_status else "unhealthy", + "llm_provider": "connected" if llm_status else "disconnected", + "database_service": "connected" if db_status else "disconnected" + } + + except Exception as e: + logger.error(f"Health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + + async def get_available_databases(self) -> List[Dict[str, str]]: + """사용 가능한 데이터베이스 목록을 반환합니다.""" + try: + await self._initialize_dependencies() + databases = await self.database_service.get_available_databases() + + return [ + { + "name": db.database_name, + "description": db.description, + "connection": db.connection_name + } + for db in databases + ] + + except Exception as e: + logger.error(f"Failed to get available databases: {e}") + return [] + +# 싱글톤 인스턴스 +_chatbot_service = ChatbotService() + +async def get_chatbot_service() -> ChatbotService: + """Chatbot Service 인스턴스를 반환합니다.""" + return _chatbot_service diff --git a/src/services/database/__init__.py b/src/services/database/__init__.py new file mode 100644 index 0000000..435b095 --- /dev/null +++ b/src/services/database/__init__.py @@ -0,0 +1,6 @@ +""" +데이터베이스 서비스 패키지 +""" + + + diff --git a/src/services/database/database_service.py b/src/services/database/database_service.py new file mode 100644 index 0000000..752f8e3 --- /dev/null +++ b/src/services/database/database_service.py @@ -0,0 +1,247 @@ +# src/services/database/database_service.py + +import asyncio +from typing import List, Optional, Dict, Any +from core.clients.api_client import APIClient, DatabaseInfo, DBProfileInfo, get_api_client +import logging + +logger = logging.getLogger(__name__) + +class DatabaseService: + """ + 데이터베이스 관련 비즈니스 로직을 담당하는 서비스 클래스 + 지연 초기화를 지원하여 BE 서버가 늦게 시작되어도 작동합니다. + """ + + def __init__(self, api_client: APIClient = None): + self.api_client = api_client + self._cached_db_profiles: Optional[List[DBProfileInfo]] = None + self._cached_annotations: Dict[str, Dict[str, Any]] = {} + # 호환성을 위해 유지하지만 더 이상 사용하지 않음 + self._cached_databases: Optional[List[DatabaseInfo]] = None + self._cached_schemas: Dict[str, str] = {} + # 지연 초기화 관련 플래그 + self._connection_attempted: bool = False + self._connection_failed: bool = False + + async def _get_api_client(self) -> APIClient: + """API 클라이언트를 가져옵니다.""" + if self.api_client is None: + self.api_client = await get_api_client() + return self.api_client + + async def get_available_databases(self) -> List[DatabaseInfo]: + """ + [DEPRECATED] 사용 가능한 데이터베이스 목록을 가져옵니다. + 대신 get_databases_with_annotations()를 사용하세요. + """ + logger.warning("get_available_databases()는 deprecated입니다. get_databases_with_annotations()를 사용하세요.") + + # DBMS 프로필 기반으로 DatabaseInfo 형태로 변환하여 호환성 유지 + try: + profiles = await self.get_db_profiles() + databases = [] + + for profile in profiles: + db_info = DatabaseInfo( + connection_name=f"{profile.type}_{profile.host}_{profile.port}", + database_name=profile.view_name or f"{profile.type}_db", + description=f"{profile.type} 데이터베이스 ({profile.host}:{profile.port})" + ) + databases.append(db_info) + + return databases + + except Exception as e: + logger.error(f"Failed to fetch databases: {e}") + raise RuntimeError(f"데이터베이스 목록을 가져올 수 없습니다. 백엔드 서버를 확인해주세요: {e}") + + async def get_schema_for_db(self, db_name: str) -> str: + """특정 데이터베이스의 스키마를 가져옵니다.""" + try: + if db_name not in self._cached_schemas: + api_client = await self._get_api_client() + schema = await api_client.get_database_schema(db_name) + self._cached_schemas[db_name] = schema + logger.info(f"Cached schema for database: {db_name}") + + return self._cached_schemas[db_name] + + except Exception as e: + logger.error(f"Failed to fetch schema for {db_name}: {e}") + raise RuntimeError(f"데이터베이스 '{db_name}' 스키마를 가져올 수 없습니다. 백엔드 서버를 확인해주세요: {e}") + + async def execute_query(self, sql_query: str, database_name: str = None, user_db_id: str = None) -> str: + """SQL 쿼리를 실행하고 결과를 반환합니다.""" + try: + if not database_name: + logger.warning("Database name not provided, using default") + database_name = "default" + + logger.info(f"Executing SQL query on database '{database_name}': {sql_query}") + + api_client = await self._get_api_client() + response = await api_client.execute_query( + sql_query=sql_query, + database_name=database_name, + user_db_id=user_db_id + ) + + # 백엔드 응답 코드 확인 + if response.code == "2400": + logger.info(f"Query executed successfully: {response.message}") + + # 응답 데이터 형태에 따라 다른 메시지 반환 + if hasattr(response.data, 'columns') and hasattr(response.data, 'data'): + # 쿼리 결과 데이터가 있는 경우 + row_count = len(response.data.data) + col_count = len(response.data.columns) + return f"쿼리가 성공적으로 실행되었습니다. {row_count}개 행, {col_count}개 컬럼의 결과를 반환했습니다." + else: + # 일반적인 성공 메시지 + return "쿼리가 성공적으로 실행되었습니다." + else: + # data에 에러 메시지가 있는지 확인 + error_detail = "" + if isinstance(response.data, str): + error_detail = f" 상세: {response.data}" + + error_msg = f"쿼리 실행 실패: {response.message} (코드: {response.code}){error_detail}" + logger.error(error_msg) + return error_msg + + except Exception as e: + logger.error(f"Error during query execution: {e}") + return f"쿼리 실행 중 오류 발생: {e}" + + + + async def get_db_profiles(self) -> List[DBProfileInfo]: + """ + 모든 DBMS 프로필 정보를 가져옵니다. + 지연 초기화를 통해 BE 서버 연결이 실패해도 재시도합니다. + """ + if self._cached_db_profiles is None: + # 이전에 연결을 시도했고 실패했다면 재시도 + if self._connection_failed: + logger.info("🔄 DB 프로필 조회 재시도 중...") + self._connection_failed = False + self._connection_attempted = False + + try: + self._connection_attempted = True + api_client = await self._get_api_client() + self._cached_db_profiles = await api_client.get_db_profiles() + self._connection_failed = False + logger.info(f"✅ DB 프로필 조회 성공: {len(self._cached_db_profiles)}개") + + except Exception as e: + self._connection_failed = True + logger.error(f"❌ DB 프로필 조회 실패: {e}") + raise RuntimeError(f"DB 프로필 목록을 가져올 수 없습니다. 백엔드 서버가 실행 중인지 확인해주세요: {e}") + + return self._cached_db_profiles + + async def get_db_annotations(self, db_profile_id: str) -> Dict[str, Any]: + """특정 DBMS의 어노테이션을 조회합니다.""" + try: + if db_profile_id not in self._cached_annotations: + api_client = await self._get_api_client() + annotations = await api_client.get_db_annotations(db_profile_id) + self._cached_annotations[db_profile_id] = annotations + + if annotations.get("code") == "4401": + logger.info(f"No annotations available for DB profile: {db_profile_id}") + else: + logger.info(f"Cached annotations for DB profile: {db_profile_id}") + + return self._cached_annotations[db_profile_id] + + except Exception as e: + logger.error(f"Failed to fetch annotations for {db_profile_id}: {e}") + # 어노테이션이 없어도 기본 정보는 반환하도록 변경 + return {"code": "4401", "message": "어노테이션이 없습니다", "data": []} + + async def get_databases_with_annotations(self) -> List[Dict[str, Any]]: + """DB 프로필과 어노테이션을 함께 조회합니다.""" + try: + profiles = await self.get_db_profiles() + result = [] + + for profile in profiles: + annotations = await self.get_db_annotations(profile.id) + db_info = { + "profile": profile.model_dump(), + "annotations": annotations, + "display_name": profile.view_name or f"{profile.type}_{profile.host}_{profile.port}", + "description": self._generate_db_description(profile, annotations) + } + result.append(db_info) + + return result + + except Exception as e: + logger.error(f"Failed to get databases with annotations: {e}") + raise RuntimeError(f"어노테이션이 포함된 데이터베이스 목록을 가져올 수 없습니다: {e}") + + def _generate_db_description(self, profile: DBProfileInfo, annotations: Dict[str, Any]) -> str: + """DB 프로필과 어노테이션을 기반으로 설명을 생성합니다.""" + try: + # 기본 설명 + base_desc = f"{profile.type} 데이터베이스" + + if profile.view_name: + base_desc += f" ({profile.view_name})" + else: + base_desc += f" ({profile.host}:{profile.port})" + + # 어노테이션 정보 확인 + if annotations and annotations.get("code") != "4401" and "data" in annotations: + # 실제 어노테이션이 있는 경우 + base_desc += " - 어노테이션 정보 포함" + + return base_desc + + except Exception as e: + logger.warning(f"Failed to generate description: {e}") + return f"{profile.type} 데이터베이스" + + async def refresh_cache(self): + """캐시를 새로고침합니다.""" + self._cached_db_profiles = None + self._cached_annotations.clear() + # 호환성을 위해 유지 + self._cached_databases = None + self._cached_schemas.clear() + # 지연 초기화 플래그 리셋 + self._connection_attempted = False + self._connection_failed = False + logger.info("Database cache refreshed") + + async def clear_cache(self): + """캐시를 클리어합니다.""" + self._cached_db_profiles = None + self._cached_annotations.clear() + # 호환성을 위해 유지 + self._cached_databases = None + self._cached_schemas.clear() + # 지연 초기화 플래그 리셋 + self._connection_attempted = False + self._connection_failed = False + logger.info("Database cache cleared") + + async def health_check(self) -> bool: + """데이터베이스 서비스 상태를 확인합니다.""" + try: + api_client = await self._get_api_client() + return await api_client.health_check() + except Exception as e: + logger.error(f"Database service health check failed: {e}") + return False + +# 싱글톤 인스턴스 +_database_service = DatabaseService() + +async def get_database_service() -> DatabaseService: + """Database Service 인스턴스를 반환합니다.""" + return _database_service diff --git a/test_services.py b/test_services.py new file mode 100644 index 0000000..2fb2b4a --- /dev/null +++ b/test_services.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +""" +서비스 테스트 스크립트 +""" + +import asyncio +import sys +import os + +# src 디렉토리를 Python 경로에 추가 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) + +async def test_llm_provider(): + """LLM Provider 테스트""" + print("🔍 LLM Provider 테스트 중...") + try: + from core.providers.llm_provider import get_llm_provider + + provider = await get_llm_provider() + print(f"✅ LLM Provider 생성 성공: {provider.model_name}") + + # 연결 테스트 + is_connected = await provider.test_connection() + print(f"🔗 LLM 연결 상태: {'성공' if is_connected else '실패'}") + + # API 키 소스 확인 (로그에서 확인 가능) + print("💡 백엔드에서 API 키를 가져옵니다") + + except Exception as e: + print(f"❌ LLM Provider 테스트 실패: {e}") + +async def test_api_client(): + """API Client 테스트""" + print("\n🔍 API Client 테스트 중...") + try: + from core.clients.api_client import get_api_client + + client = await get_api_client() + print("✅ API Client 생성 성공") + + # OpenAI API 키 조회 테스트 + try: + api_key = await client.get_openai_api_key() + print(f"🔑 OpenAI API 키 조회 성공: {api_key[:20]}...") + except Exception as e: + print(f"⚠️ OpenAI API 키 조회 실패: {e}") + + # 헬스체크 테스트 + try: + is_healthy = await client.health_check() + print(f"🏥 백엔드 서버 상태: {'정상' if is_healthy else '비정상'}") + except Exception as e: + print(f"⚠️ 백엔드 서버 연결 실패: {e}") + + except Exception as e: + print(f"❌ API Client 테스트 실패: {e}") + +async def test_db_annotation_api(): + """DB 어노테이션 API 테스트""" + print("\n🔍 DB 어노테이션 API 테스트 중...") + try: + from services.database.database_service import get_database_service + + service = await get_database_service() + + # DB 프로필 조회 테스트 + try: + profiles = await service.get_db_profiles() + print(f"✅ DB 프로필 조회 성공: {len(profiles)}개") + + if profiles: + print(f"📝 첫 번째 프로필: {profiles[0].type} - {profiles[0].view_name or 'No view name'}") + + # 첫 번째 프로필의 어노테이션 조회 테스트 + try: + annotations = await service.get_db_annotations(profiles[0].id) + print(f"✅ 어노테이션 조회 성공: {profiles[0].id}") + except Exception as e: + print(f"⚠️ 어노테이션 조회 실패: {e}") + + # 통합 조회 테스트 + try: + dbs_with_annotations = await service.get_databases_with_annotations() + print(f"✅ 통합 조회 성공: {len(dbs_with_annotations)}개") + + if dbs_with_annotations: + first_db = dbs_with_annotations[0] + print(f"📝 첫 번째 DB 정보:") + print(f" - Display Name: {first_db['display_name']}") + print(f" - Description: {first_db['description']}") + print(f" - Has Annotations: {'data' in first_db['annotations']}") + + except Exception as e: + print(f"⚠️ 통합 조회 실패: {e}") + else: + print("⚠️ DB 프로필이 없습니다.") + + except Exception as e: + print(f"⚠️ DB 프로필 조회 실패: {e}") + + except Exception as e: + print(f"❌ DB 어노테이션 API 테스트 실패: {e}") + +async def test_database_service(): + """Database Service 테스트""" + print("\n🔍 Database Service 테스트 중...") + try: + from services.database.database_service import get_database_service + + service = await get_database_service() + print("✅ Database Service 생성 성공") + + # 사용 가능한 데이터베이스 목록 조회 + try: + databases = await service.get_available_databases() + print(f"🗄️ 사용 가능한 데이터베이스: {len(databases)}개") + print("✅ 백엔드 API에서 데이터베이스 목록을 성공적으로 가져왔습니다") + + for db in databases[:3]: # 처음 3개만 출력 + print(f" - {db.database_name}: {db.description}") + except Exception as e: + print(f"⚠️ 데이터베이스 목록 조회 실패: {e}") + + except Exception as e: + print(f"❌ Database Service 테스트 실패: {e}") + +async def test_annotation_service(): + """Annotation Service 테스트""" + print("\n🔍 Annotation Service 테스트 중...") + try: + from services.annotation.annotation_service import get_annotation_service + + service = await get_annotation_service() + print("✅ Annotation Service 생성 성공") + + # 헬스체크 테스트 + try: + health = await service.health_check() + print(f"🏥 어노테이션 서비스 상태: {health}") + except Exception as e: + print(f"⚠️ 어노테이션 서비스 헬스체크 실패: {e}") + + except Exception as e: + print(f"❌ Annotation Service 테스트 실패: {e}") + +async def test_chatbot_service(): + """Chatbot Service 테스트""" + print("\n🔍 Chatbot Service 테스트 중...") + try: + from services.chat.chatbot_service import get_chatbot_service + + service = await get_chatbot_service() + print("✅ Chatbot Service 생성 성공") + + # 헬스체크 테스트 + try: + health = await service.health_check() + print(f"🏥 챗봇 서비스 상태: {health}") + except Exception as e: + print(f"⚠️ 챗봇 서비스 헬스체크 실패: {e}") + + except Exception as e: + print(f"❌ Chatbot Service 테스트 실패: {e}") + +async def test_sql_agent(): + """SQL Agent 테스트""" + print("\n🔍 SQL Agent 테스트 중...") + try: + from agents.sql_agent.graph import SqlAgentGraph + from core.providers.llm_provider import get_llm_provider + from services.database.database_service import get_database_service + + llm_provider = await get_llm_provider() + db_service = await get_database_service() + + agent = SqlAgentGraph(llm_provider, db_service) + print("✅ SQL Agent 생성 성공") + + # 그래프 시각화 PNG 저장 + try: + success = agent.save_graph_visualization("sql_agent_workflow.png") + if success: + print("📊 그래프 시각화 PNG 저장 성공: sql_agent_workflow.png") + else: + print("⚠️ 그래프 시각화 PNG 저장 실패") + except Exception as e: + print(f"⚠️ 그래프 시각화 생성 실패: {e}") + + except Exception as e: + print(f"❌ SQL Agent 테스트 실패: {e}") + +async def test_end_to_end_chat(): + """실제 채팅 요청 End-to-End 테스트""" + print("\n🔍 End-to-End 채팅 테스트 중...") + try: + from services.chat.chatbot_service import get_chatbot_service + import time + + service = await get_chatbot_service() + + # SQL 관련 질문으로 테스트 + test_questions = [ + "사용자 테이블에서 모든 데이터를 조회해주세요", + "가장 많이 주문한 고객을 찾아주세요", + ] + + for i, question in enumerate(test_questions, 1): + print(f"🤖 테스트 질문 {i}: {question}") + start_time = time.time() + + try: + response = await service.handle_request(user_question=question) + end_time = time.time() + response_time = round(end_time - start_time, 2) + + print(f"✅ 응답 시간: {response_time}초") + print(f"📝 응답: {response[:100]}{'...' if len(response) > 100 else ''}") + except Exception as e: + print(f"❌ 질문 {i} 실패: {e}") + + print("---") + + except Exception as e: + print(f"❌ End-to-End 테스트 실패: {e}") + +async def test_annotation_functionality(): + """어노테이션 기능 실제 사용 테스트""" + print("\n🔍 어노테이션 기능 테스트 중...") + try: + from services.annotation.annotation_service import get_annotation_service + from schemas.api.annotator_schemas import Database, Table, Column + + service = await get_annotation_service() + + # 샘플 데이터로 어노테이션 테스트 + sample_database = Database( + database_name="test_db", + tables=[ + Table( + table_name="users", + columns=[ + Column(column_name="id", data_type="int"), + Column(column_name="name", data_type="varchar"), + Column(column_name="email", data_type="varchar") + ], + sample_rows=[{"id": 1, "name": "John Doe", "email": "john@example.com"}] + ) + ], + relationships=[] + ) + + try: + from schemas.api.annotator_schemas import AnnotationRequest + request = AnnotationRequest( + dbms_type="MySQL", + databases=[sample_database] + ) + result = await service.generate_for_schema(request) + print(f"✅ 어노테이션 생성 성공") + print(f"📝 생성된 데이터베이스 수: {len(result.databases)}") + if result.databases and result.databases[0].tables: + print(f"📝 첫 번째 테이블 설명: {result.databases[0].tables[0].description[:100]}...") + except Exception as e: + print(f"⚠️ 어노테이션 생성 실패: {e}") + + except Exception as e: + print(f"❌ 어노테이션 기능 테스트 실패: {e}") + +async def test_error_scenarios(): + """에러 시나리오 테스트""" + print("\n🔍 에러 시나리오 테스트 중...") + + # 잘못된 API 키로 LLM 테스트 + print("🧪 잘못된 API 키 시나리오...") + try: + from core.providers.llm_provider import LLMProvider + + # 일시적으로 잘못된 API 키 설정 테스트는 실제 환경에서는 위험하므로 스킵 + print("⚠️ 실제 환경에서는 API 키 에러 테스트 스킵") + + except Exception as e: + print(f"✅ 예상된 에러 발생: {e}") + + print("✅ 에러 시나리오 테스트 완료") + +async def main(): + """메인 테스트 함수""" + print("🚀 QGenie AI 서비스 테스트 시작\n") + + # 기본 서비스 테스트 + await test_llm_provider() + await test_api_client() + await test_annotation_service() + await test_db_annotation_api() # 새로운 DB 어노테이션 API 테스트 추가 + await test_database_service() + await test_chatbot_service() + await test_sql_agent() + + # 확장 테스트 (백엔드 연결이 가능한 경우에만) + try: + from core.clients.api_client import get_api_client + client = await get_api_client() + if await client.health_check(): + print("\n🧪 확장 테스트 시작 (백엔드 연결 확인됨)") + print("⚠️ 참고: 데이터베이스 API가 구현되지 않아 일부 테스트는 실패할 수 있습니다") + await test_end_to_end_chat() + await test_annotation_functionality() + await test_error_scenarios() + else: + print("\n⚠️ 백엔드 연결 불가 - 확장 테스트 스킵") + except Exception: + print("\n⚠️ 백엔드 연결 불가 - 확장 테스트 스킵") + + print("\n✨ 모든 테스트 완료!") + +if __name__ == "__main__": + asyncio.run(main())