Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 79 additions & 36 deletions app/services/annotation_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os
import sqlite3
from datetime import datetime
from typing import Any

import httpx
from fastapi import Depends

from app.core.enum.constraint_type import ConstraintTypeEnum
Expand Down Expand Up @@ -38,9 +40,6 @@

user_db_service_dependency = Depends(lambda: user_db_service)

# AI 서버의 주소 (임시)
AI_SERVER_URL = "http://localhost:8001/api/v1/annotate/database"


class AnnotationService:
def __init__(
Expand All @@ -55,13 +54,24 @@ def __init__(
"""
self.repository = repository
self.user_db_service = user_db_serv
self._ai_server_url = None

def _get_ai_server_url(self) -> str:
    """Return the AI server's annotator endpoint URL, resolving it lazily.

    The base URL is read from the ``ENV_AI_SERVER_URL`` environment variable
    the first time this method is called and then cached on the instance in
    ``self._ai_server_url``; later calls return the cached value without
    re-reading the environment.

    The ``/chat`` segment is rewritten to ``/annotator`` in the *path*
    component only, so a host name that happens to start with ``chat``
    (e.g. ``http://chat.example.com/...``) is left intact.

    Returns:
        The annotator endpoint URL.

    Raises:
        ValueError: If ``ENV_AI_SERVER_URL`` is unset or empty.
    """
    if self._ai_server_url is None:
        # Function-scope import keeps this block self-contained.
        from urllib.parse import urlsplit, urlunsplit

        url = os.getenv("ENV_AI_SERVER_URL")
        if not url:
            raise ValueError("환경 변수 'ENV_AI_SERVER_URL'가 설정되지 않았거나 .env 파일 로드에 실패했습니다.")
        # Switch the endpoint path to the annotator route without touching
        # the scheme, host, query or fragment.
        parts = urlsplit(url)
        annotator_path = parts.path.replace("/chat", "/annotator")
        self._ai_server_url = urlunsplit(parts._replace(path=annotator_path))
    return self._ai_server_url

async def create_annotation(self, request: AnnotationCreateRequest) -> FullAnnotationResponse:
"""
어노테이션 생성을 위한 전체 프로세스를 관장합니다.
1. DB 프로필, 전체 스키마 정보, 샘플 데이터 조회
2. AI 서버에 요청할 데이터 모델 생성
3. TODO: AI 서버에 요청 (현재는 Mock 데이터 사용)
3. AI 서버에 요청
4. 트랜잭션 내에서 전체 어노테이션 정보 저장 및 DB 프로필 업데이트
"""
logging.info(f"Starting annotation creation for db_profile_id: {request.db_profile_id}")
Expand All @@ -83,12 +93,12 @@ async def create_annotation(self, request: AnnotationCreateRequest) -> FullAnnot
# 2. AI 서버에 요청할 데이터 모델 생성
ai_request_body = self._prepare_ai_request_body(db_profile, full_schema_info, sample_rows)
logging.info("Prepared AI request body.")
logging.debug(f"AI Request Body: {ai_request_body.model_dump_json(indent=2)}")
logging.info(f"AI Request Body: {ai_request_body.model_dump_json(indent=2)}")

# 3. AI 서버에 요청 (현재는 Mock 데이터 사용)
# 3. AI 서버에 요청
ai_response = await self._request_annotation_to_ai_server(ai_request_body)
logging.info("Received AI response.")
logging.debug(f"AI Response: {ai_response}")
logging.info(f"AI Response: {ai_response}")

# 4. 트랜잭션 내에서 전체 어노테이션 정보 저장 및 DB 프로필 업데이트
db_path = get_db_path()
Expand Down Expand Up @@ -223,14 +233,20 @@ def _transform_ai_response_to_db_models(
now = datetime.now()
annotation_id = generate_prefixed_uuid(DBSaveIdEnum.database_annotation.value)

# AI 응답에서 데이터베이스 레벨의 정보 추출
db_data = ai_response.get("databases", [{}])[0]
db_description = db_data.get("description")
tables_data = db_data.get("tables", [])
relationships_data = db_data.get("relationships", [])

# 원본 스키마 정보를 쉽게 조회할 수 있도록 룩업 테이블 생성
schema_lookup: dict[str, UserDBTableInfo] = {table.name: table for table in full_schema_info}

db_anno = DatabaseAnnotationInDB(
id=annotation_id,
db_profile_id=db_profile_id,
database_name=db_profile.name or db_profile.username,
description=ai_response.get("database_annotation"),
description=db_description,
created_at=now,
updated_at=now,
)
Expand All @@ -251,7 +267,7 @@ def _transform_ai_response_to_db_models(
[],
)

for tbl_data in ai_response.get("tables", []):
for tbl_data in tables_data:
original_table = schema_lookup.get(tbl_data["table_name"])
if not original_table:
logging.warning(
Expand All @@ -266,7 +282,7 @@ def _transform_ai_response_to_db_models(
constraint_col_annos,
index_annos,
index_col_annos,
) = self._create_annotations_for_table(tbl_data, original_table, annotation_id, now)
) = self._create_annotations_for_table(tbl_data, original_table, annotation_id, now, relationships_data)

all_table_annos.append(table_anno)
all_col_annos.extend(col_annos)
Expand All @@ -291,6 +307,7 @@ def _create_annotations_for_table(
original_table: UserDBTableInfo,
database_annotation_id: str,
now: datetime,
relationships_data: list[dict[str, Any]],
) -> tuple:
"""
단일 테이블에 대한 모든 하위 어노테이션(컬럼, 제약조건, 인덱스)을 생성합니다.
Expand All @@ -300,7 +317,7 @@ def _create_annotations_for_table(
id=table_id,
database_annotation_id=database_annotation_id,
table_name=original_table.name,
description=tbl_data.get("annotation"),
description=tbl_data.get("description"),
created_at=now,
updated_at=now,
)
Expand All @@ -311,7 +328,7 @@ def _create_annotations_for_table(

col_annos = self._process_columns(tbl_data, original_table, table_id, col_map, now)
constraint_annos, constraint_col_annos = self._process_constraints(
tbl_data, original_table, table_id, col_map, now
original_table, table_id, col_map, now, relationships_data
)
index_annos, index_col_annos = self._process_indexes(tbl_data, original_table, table_id, col_map, now)

Expand All @@ -324,10 +341,12 @@ def _process_columns(
테이블의 컬럼 어노테이션 모델 리스트를 생성합니다.
"""
col_annos = []
for col_data in tbl_data.get("columns", []):
original_column = next((c for c in original_table.columns if c.name == col_data["column_name"]), None)
if not original_column:
continue
ai_columns_lookup = {c["column_name"]: c for c in tbl_data.get("columns", [])}

for original_column in original_table.columns:
col_data = ai_columns_lookup.get(original_column.name)
description = col_data.get("description") if col_data else None

col_annos.append(
ColumnAnnotationInDB(
id=col_map[original_column.name],
Expand All @@ -336,7 +355,7 @@ def _process_columns(
data_type=original_column.type,
is_nullable=1 if original_column.nullable else 0,
default_value=original_column.default,
description=col_data.get("annotation"),
description=description,
ordinal_position=original_column.ordinal_position,
created_at=now,
updated_at=now,
Expand All @@ -345,18 +364,32 @@ def _process_columns(
return col_annos

def _process_constraints(
self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime
self,
original_table: UserDBTableInfo,
table_id: str,
col_map: dict,
now: datetime,
relationships_data: list[dict[str, Any]],
) -> tuple[list[TableConstraintInDB], list[ConstraintColumnInDB]]:
"""
테이블의 제약조건 및 제약조건 컬럼 어노테이션 모델 리스트를 생성합니다.
AI 응답이 아닌 원본 스키마의 모든 제약조건을 기준으로 처리합니다.
AI 응답의 'relationships'를 기반으로 FK 제약조건의 설명을 매칭합니다.
"""
constraint_annos, constraint_col_annos = [], []
ai_constraints_lookup = {c["name"]: c for c in tbl_data.get("constraints", [])}

for original_constraint in original_table.constraints:
const_data = ai_constraints_lookup.get(original_constraint.name)
annotation = const_data.get("annotation") if const_data else None
annotation = None
# 외래 키 제약조건인 경우, AI 응답의 relationships에서 설명을 찾습니다.
if original_constraint.type == ConstraintTypeEnum.FOREIGN_KEY.value:
for rel in relationships_data:
if (
rel.get("from_table") == original_table.name
and set(rel.get("from_columns", [])) == set(original_constraint.columns)
and rel.get("to_table") == original_constraint.referenced_table
and set(rel.get("to_columns", [])) == set(original_constraint.referenced_columns)
):
annotation = rel.get("description")
break

const_id = generate_prefixed_uuid(DBSaveIdEnum.table_constraint.value)
constraint_annos.append(
Expand Down Expand Up @@ -458,20 +491,30 @@ def delete_annotation(self, annotation_id: str) -> AnnotationDeleteResponse:

async def _request_annotation_to_ai_server(self, ai_request: AIAnnotationRequest) -> dict:
    """Send the schema information to the AI server and return its annotation response.

    The request model is serialized with ``model_dump()`` and POSTed as JSON
    to the URL returned by ``self._get_ai_server_url()`` with a 60 second
    timeout.

    Args:
        ai_request: Request model describing the database to annotate.

    Returns:
        The decoded JSON body of the AI server's response.

    Raises:
        APIException: With ``CommonCode.FAIL_AI_SERVER_PROCESSING`` when the
            server answers with an error status or a body that is not valid
            JSON, and with ``CommonCode.FAIL_AI_SERVER_CONNECTION`` when the
            request itself fails.
        ValueError: Propagated from ``_get_ai_server_url`` when the server
            URL is not configured.
    """
    ai_server_url = self._get_ai_server_url()
    request_body = ai_request.model_dump()

    logging.info(f"Requesting annotation to AI server at {ai_server_url}")
    logging.info(f"Request Body: {request_body}")

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(ai_server_url, json=request_body, timeout=60.0)
            response.raise_for_status()
            ai_response = response.json()
        except httpx.HTTPStatusError as e:
            logging.error(f"AI server returned an error: {e.response.status_code} - {e.response.text}")
            raise APIException(
                CommonCode.FAIL_AI_SERVER_PROCESSING, detail=f"AI server error: {e.response.text}"
            ) from e
        except httpx.RequestError as e:
            logging.error(f"Failed to connect to AI server: {e}")
            raise APIException(
                CommonCode.FAIL_AI_SERVER_CONNECTION, detail=f"AI server connection failed: {e}"
            ) from e
        except ValueError as e:
            # response.json() raises a ValueError subclass when the body is not
            # valid JSON; report it as a processing failure like other bad replies.
            logging.error(f"AI server returned a non-JSON response: {e}")
            raise APIException(
                CommonCode.FAIL_AI_SERVER_PROCESSING, detail=f"AI server error: invalid JSON response ({e})"
            ) from e

    logging.info("Successfully received annotation response from AI server.")
    logging.info(f"AI Response JSON: {ai_response}")
    return ai_response

def _get_mock_ai_response(self, ai_request: AIAnnotationRequest) -> dict:
"""테스트를 위한 Mock AI 서버 응답 생성"""
Expand Down