Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Backend/main.py
import os
from dotenv import load_dotenv
# 환경 변수를 최대한 빨리 로드하여 GPU 설정(CUDA_VISIBLE_DEVICES)이 라우터 임포트 전에 적용되도록 함
load_dotenv()
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
Expand All @@ -12,8 +14,12 @@
from routers.checklist import router as checklist_router
from routers.file import router as file_router


# 1) 환경변수 로드 (상단에서 선 로드됨)

import uvicorn


load_dotenv()

app = FastAPI()
Expand Down
2 changes: 2 additions & 0 deletions models/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base


Base = declarative_base()
18 changes: 18 additions & 0 deletions models/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

from sqlalchemy.orm import relationship

from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text
from sqlalchemy.orm import relationship
from .base import Base
Expand All @@ -6,6 +9,20 @@ class File(Base):
__tablename__ = "file"

id = Column(Integer, primary_key=True, autoincrement=True)

user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False)
folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True)
note_id = Column(Integer, ForeignKey('note.id', ondelete='CASCADE'), nullable=True)
original_name = Column(String(255), nullable=False)
saved_path = Column(String(512), nullable=False)
content_type = Column(String(100), nullable=False)
created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))

# ✅ 관계
user = relationship("User", back_populates="files")
folder = relationship("Folder", back_populates="files")
note = relationship("Note", back_populates="files")

user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False)
folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True)
note_id = Column(Integer, ForeignKey("note.id", ondelete="SET NULL"), nullable=True)
Expand All @@ -17,3 +34,4 @@ class File(Base):
# relations
user = relationship("User", back_populates="files")
note = relationship("Note", back_populates="files")

11 changes: 11 additions & 0 deletions models/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,21 @@ class Folder(Base):
parent_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True)
created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = Column(TIMESTAMP, nullable=False,

server_default=text('CURRENT_TIMESTAMP'),
onupdate=text('CURRENT_TIMESTAMP'))

# ✅ 관계
user = relationship("User", back_populates="folders")
parent = relationship("Folder", remote_side=[id], backref="children")
notes = relationship("Note", back_populates="folder", cascade="all, delete")
files = relationship("File", back_populates="folder", cascade="all, delete")

server_default=text("CURRENT_TIMESTAMP"),
onupdate=text("CURRENT_TIMESTAMP"))

# relations
user = relationship("User")
parent = relationship("Folder", remote_side=[id], backref="children")
notes = relationship("Note", back_populates="folder", cascade="all, delete")

7 changes: 7 additions & 0 deletions models/note.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ class Note(Base):
last_accessed = Column(TIMESTAMP, nullable=True)
created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = Column(TIMESTAMP, nullable=False,

server_default=text('CURRENT_TIMESTAMP'),
onupdate=text('CURRENT_TIMESTAMP'))

# ✅ 관계

server_default=text("CURRENT_TIMESTAMP"),
onupdate=text("CURRENT_TIMESTAMP"))

# relations

user = relationship("User", back_populates="notes")
folder = relationship("Folder", back_populates="notes")
files = relationship("File", back_populates="note", cascade="all, delete")
16 changes: 16 additions & 0 deletions models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@ class User(Base):
id = Column(String(50), nullable=False, unique=True) # 로그인 ID 또는 소셜 ID
email = Column(String(150), nullable=False, unique=True)
password = Column(String(255), nullable=False)

provider = Column(
Enum('local','google','kakao','naver', name='provider_enum'),
nullable=False,
server_default=text("'local'")
)
created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))
updated_at = Column(TIMESTAMP, nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
onupdate=text('CURRENT_TIMESTAMP'))

# ✅ 관계
folders = relationship("Folder", back_populates="user", cascade="all, delete")
notes = relationship("Note", back_populates="user", cascade="all, delete")
files = relationship("File", back_populates="user", cascade="all, delete")

provider = Column(Enum("local", "google", "kakao", "naver", name="provider_enum"),
nullable=False, server_default=text("'local'"))
created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
Expand Down
119 changes: 117 additions & 2 deletions routers/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,52 @@
from models.note import Note as NoteModel
from utils.jwt_utils import get_current_user

cuda_gpu
# 추가: 파일명 인코딩용
import urllib.parse

# -------------------------------
# 1) EasyOCR 라이브러리 임포트 (GPU 모드 활성화)
# -------------------------------
import easyocr
reader = easyocr.Reader(["ko", "en"], gpu=True)

# -------------------------------
# 2) Hugging Face TrOCR 모델용 파이프라인 (GPU 사용)
# -------------------------------
from transformers import pipeline

hf_trocr_printed = pipeline(
"image-to-text",
model="microsoft/trocr-base-printed",
device=0,
trust_remote_code=True
)
hf_trocr_handwritten = pipeline(
"image-to-text",
model="microsoft/trocr-base-handwritten",
device=0,
trust_remote_code=True
)
hf_trocr_small_printed = pipeline(
"image-to-text",
model="microsoft/trocr-small-printed",
device=0,
trust_remote_code=True
)
hf_trocr_large_printed = pipeline(
"image-to-text",
model="microsoft/trocr-large-printed",
device=0,
trust_remote_code=True
)

BASE_UPLOAD_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"uploads"
)

# 공통 OCR 파이프라인
from utils.ocr import run_pipeline, detect_type
from schemas.file import OCRResponse
Expand Down Expand Up @@ -79,11 +125,10 @@ async def upload_file(
orig_filename: str = upload_file.filename or "unnamed"
content_type: str = upload_file.content_type or "application/octet-stream"

# 사용자별 디렉토리 생성
user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id))
os.makedirs(user_dir, exist_ok=True)

# 원본 파일명 유지 (중복 방지)

saved_filename = orig_filename
saved_path = os.path.join(user_dir, saved_filename)
if os.path.exists(saved_path):
Expand All @@ -98,14 +143,17 @@ async def upload_file(
break
counter += 1


# 저장

try:
with open(saved_path, "wb") as buffer:
content = await upload_file.read()
buffer.write(content)
except Exception as e:
raise HTTPException(status_code=500, detail=f"파일 저장 실패: {e}")


# note_id가 있으면 해당 노트 확인
note_obj = None
if note_id is not None:
Expand All @@ -118,6 +166,7 @@ async def upload_file(
raise HTTPException(status_code=404, detail="해당 노트를 찾을 수 없습니다.")

# DB 메타 기록

new_file = FileModel(
user_id=current_user.u_id,
folder_id=None if note_id else folder_id,
Expand Down Expand Up @@ -202,6 +251,10 @@ def download_file(
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="서버에 파일이 존재하지 않습니다.")

# 원본 파일명 UTF-8 URL 인코딩 처리
quoted_name = urllib.parse.quote(file_obj.original_name, safe='')
content_disposition = f"inline; filename*=UTF-8''{quoted_name}"

return FileResponse(
path=file_path,
media_type=file_obj.content_type,
Expand Down Expand Up @@ -254,6 +307,67 @@ async def ocr_and_create_note(
db: Session = Depends(get_db),
current_user = Depends(get_current_user)
):

"""
• ocr_file: 이미지 파일(UploadFile)
• 1) EasyOCR로 기본 텍스트 추출 (GPU 모드)
• 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU)
• 3) 가장 긴 결과를 최종 OCR 결과로 선택
• 4) Note로 저장 및 결과 반환
"""

# 1) 이미지 로드 (PIL)
contents = await ocr_file.read()
try:
image = Image.open(io.BytesIO(contents)).convert("RGB")
except Exception as e:
raise HTTPException(status_code=400, detail=f"이미지 처리 실패: {e}")

# 2) EasyOCR로 텍스트 추출
try:
image_np = np.array(image)
easy_results = reader.readtext(image_np) # GPU 모드 사용
easy_text = " ".join([res[1] for res in easy_results])
except Exception:
easy_text = ""

# 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input)
hf_texts: List[str] = []
try:
out1 = hf_trocr_printed(image)
if isinstance(out1, list) and "generated_text" in out1[0]:
hf_texts.append(out1[0]["generated_text"].strip())

out2 = hf_trocr_handwritten(image)
if isinstance(out2, list) and "generated_text" in out2[0]:
hf_texts.append(out2[0]["generated_text"].strip())

out3 = hf_trocr_small_printed(image)
if isinstance(out3, list) and "generated_text" in out3[0]:
hf_texts.append(out3[0]["generated_text"].strip())

out4 = hf_trocr_large_printed(image)
if isinstance(out4, list) and "generated_text" in out4[0]:
hf_texts.append(out4[0]["generated_text"].strip())
except Exception:
# TrOCR 중 오류 발생 시 무시하고 계속 진행
pass

# 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택
candidates = [t for t in [easy_text] + hf_texts if t and t.strip()]
if not candidates:
raise HTTPException(status_code=500, detail="텍스트를 인식할 수 없습니다.")

ocr_text = max(candidates, key=lambda s: len(s))

# 5) 새 노트 생성 및 DB에 저장
try:
new_note = NoteModel(
user_id=current_user.u_id,
folder_id=folder_id,
title="OCR 결과",
content=ocr_text # **원본 OCR 텍스트만 저장**

# 422 방지: 파일 필드명 유연 처리
upload = file or ocr_file
if upload is None:
Expand All @@ -277,6 +391,7 @@ async def ocr_and_create_note(
warnings=[f"허용되지 않는 확장자({ext}). 허용: {sorted(ALLOWED_ALL_EXTS)}"],
note_id=None,
text=None,

)

# 타입 판별
Expand Down
Loading