Skip to content

Commit e05cb38

Browse files
authored
Merge pull request #26 from KKU-NoteFlow/fix/cuda_gpu
Fix/cuda gpu
2 parents d70cd87 + ef9bdac commit e05cb38

File tree

7 files changed

+177
-2
lines changed

7 files changed

+177
-2
lines changed

main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Backend/main.py
22
import os
33
from dotenv import load_dotenv
4+
# Load environment variables as early as possible so that GPU settings (CUDA_VISIBLE_DEVICES) take effect before the routers are imported
5+
load_dotenv()
46
from fastapi import FastAPI
57
from fastapi.middleware.cors import CORSMiddleware
68
from fastapi.staticfiles import StaticFiles
@@ -12,8 +14,12 @@
1214
from routers.checklist import router as checklist_router
1315
from routers.file import router as file_router
1416

17+
18+
# 1) Load environment variables (already loaded at the top of this file)
19+
1520
import uvicorn
1621

22+
1723
load_dotenv()
1824

1925
app = FastAPI()

models/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from sqlalchemy.ext.declarative import declarative_base
12
from sqlalchemy.orm import declarative_base
23

4+
35
Base = declarative_base()

models/file.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
from sqlalchemy.orm import relationship
3+
14
from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text
25
from sqlalchemy.orm import relationship
36
from .base import Base
@@ -6,6 +9,20 @@ class File(Base):
69
__tablename__ = "file"
710

811
id = Column(Integer, primary_key=True, autoincrement=True)
12+
13+
user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False)
14+
folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True)
15+
note_id = Column(Integer, ForeignKey('note.id', ondelete='CASCADE'), nullable=True)
16+
original_name = Column(String(255), nullable=False)
17+
saved_path = Column(String(512), nullable=False)
18+
content_type = Column(String(100), nullable=False)
19+
created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))
20+
21+
# ✅ 관계
22+
user = relationship("User", back_populates="files")
23+
folder = relationship("Folder", back_populates="files")
24+
note = relationship("Note", back_populates="files")
25+
926
user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False)
1027
folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True)
1128
note_id = Column(Integer, ForeignKey("note.id", ondelete="SET NULL"), nullable=True)
@@ -17,3 +34,4 @@ class File(Base):
1734
# relations
1835
user = relationship("User", back_populates="files")
1936
note = relationship("Note", back_populates="files")
37+

models/folder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,21 @@ class Folder(Base):
1111
parent_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True)
1212
created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
1313
updated_at = Column(TIMESTAMP, nullable=False,
14+
15+
server_default=text('CURRENT_TIMESTAMP'),
16+
onupdate=text('CURRENT_TIMESTAMP'))
17+
18+
# ✅ 관계
19+
user = relationship("User", back_populates="folders")
20+
parent = relationship("Folder", remote_side=[id], backref="children")
21+
notes = relationship("Note", back_populates="folder", cascade="all, delete")
22+
files = relationship("File", back_populates="folder", cascade="all, delete")
23+
1424
server_default=text("CURRENT_TIMESTAMP"),
1525
onupdate=text("CURRENT_TIMESTAMP"))
1626

1727
# relations
1828
user = relationship("User")
1929
parent = relationship("Folder", remote_side=[id], backref="children")
2030
notes = relationship("Note", back_populates="folder", cascade="all, delete")
31+

models/note.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,17 @@ class Note(Base):
1414
last_accessed = Column(TIMESTAMP, nullable=True)
1515
created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
1616
updated_at = Column(TIMESTAMP, nullable=False,
17+
18+
server_default=text('CURRENT_TIMESTAMP'),
19+
onupdate=text('CURRENT_TIMESTAMP'))
20+
21+
# ✅ 관계
22+
1723
server_default=text("CURRENT_TIMESTAMP"),
1824
onupdate=text("CURRENT_TIMESTAMP"))
1925

2026
# relations
27+
2128
user = relationship("User", back_populates="notes")
2229
folder = relationship("Folder", back_populates="notes")
2330
files = relationship("File", back_populates="note", cascade="all, delete")

models/user.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,22 @@ class User(Base):
99
id = Column(String(50), nullable=False, unique=True) # 로그인 ID 또는 소셜 ID
1010
email = Column(String(150), nullable=False, unique=True)
1111
password = Column(String(255), nullable=False)
12+
13+
provider = Column(
14+
Enum('local','google','kakao','naver', name='provider_enum'),
15+
nullable=False,
16+
server_default=text("'local'")
17+
)
18+
created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))
19+
updated_at = Column(TIMESTAMP, nullable=False,
20+
server_default=text('CURRENT_TIMESTAMP'),
21+
onupdate=text('CURRENT_TIMESTAMP'))
22+
23+
# ✅ 관계
24+
folders = relationship("Folder", back_populates="user", cascade="all, delete")
25+
notes = relationship("Note", back_populates="user", cascade="all, delete")
26+
files = relationship("File", back_populates="user", cascade="all, delete")
27+
1228
provider = Column(Enum("local", "google", "kakao", "naver", name="provider_enum"),
1329
nullable=False, server_default=text("'local'"))
1430
created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"))

routers/file.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,52 @@
1212
from models.note import Note as NoteModel
1313
from utils.jwt_utils import get_current_user
1414

15+
cuda_gpu
16+
# Added: used for encoding file names
17+
import urllib.parse
18+
19+
# -------------------------------
20+
# 1) Import the EasyOCR library (GPU mode enabled)
21+
# -------------------------------
22+
import easyocr
23+
reader = easyocr.Reader(["ko", "en"], gpu=True)
24+
25+
# -------------------------------
26+
# 2) Hugging Face pipelines for the TrOCR models (running on GPU)
27+
# -------------------------------
28+
from transformers import pipeline
29+
30+
hf_trocr_printed = pipeline(
31+
"image-to-text",
32+
model="microsoft/trocr-base-printed",
33+
device=0,
34+
trust_remote_code=True
35+
)
36+
hf_trocr_handwritten = pipeline(
37+
"image-to-text",
38+
model="microsoft/trocr-base-handwritten",
39+
device=0,
40+
trust_remote_code=True
41+
)
42+
hf_trocr_small_printed = pipeline(
43+
"image-to-text",
44+
model="microsoft/trocr-small-printed",
45+
device=0,
46+
trust_remote_code=True
47+
)
48+
hf_trocr_large_printed = pipeline(
49+
"image-to-text",
50+
model="microsoft/trocr-large-printed",
51+
device=0,
52+
trust_remote_code=True
53+
)
54+
55+
BASE_UPLOAD_DIR = os.path.join(
56+
os.path.dirname(os.path.abspath(__file__)),
57+
"..",
58+
"uploads"
59+
)
60+
1561
# 공통 OCR 파이프라인
1662
from utils.ocr import run_pipeline, detect_type
1763
from schemas.file import OCRResponse
@@ -79,11 +125,10 @@ async def upload_file(
79125
orig_filename: str = upload_file.filename or "unnamed"
80126
content_type: str = upload_file.content_type or "application/octet-stream"
81127

82-
# 사용자별 디렉토리 생성
83128
user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id))
84129
os.makedirs(user_dir, exist_ok=True)
85130

86-
# 원본 파일명 유지 (중복 방지)
131+
87132
saved_filename = orig_filename
88133
saved_path = os.path.join(user_dir, saved_filename)
89134
if os.path.exists(saved_path):
@@ -98,14 +143,17 @@ async def upload_file(
98143
break
99144
counter += 1
100145

146+
101147
# 저장
148+
102149
try:
103150
with open(saved_path, "wb") as buffer:
104151
content = await upload_file.read()
105152
buffer.write(content)
106153
except Exception as e:
107154
raise HTTPException(status_code=500, detail=f"파일 저장 실패: {e}")
108155

156+
109157
# note_id가 있으면 해당 노트 확인
110158
note_obj = None
111159
if note_id is not None:
@@ -118,6 +166,7 @@ async def upload_file(
118166
raise HTTPException(status_code=404, detail="해당 노트를 찾을 수 없습니다.")
119167

120168
# DB 메타 기록
169+
121170
new_file = FileModel(
122171
user_id=current_user.u_id,
123172
folder_id=None if note_id else folder_id,
@@ -202,6 +251,10 @@ def download_file(
202251
if not os.path.exists(file_path):
203252
raise HTTPException(status_code=404, detail="서버에 파일이 존재하지 않습니다.")
204253

254+
# 원본 파일명 UTF-8 URL 인코딩 처리
255+
quoted_name = urllib.parse.quote(file_obj.original_name, safe='')
256+
content_disposition = f"inline; filename*=UTF-8''{quoted_name}"
257+
205258
return FileResponse(
206259
path=file_path,
207260
media_type=file_obj.content_type,
@@ -254,6 +307,67 @@ async def ocr_and_create_note(
254307
db: Session = Depends(get_db),
255308
current_user = Depends(get_current_user)
256309
):
310+
311+
"""
312+
• ocr_file: 이미지 파일(UploadFile)
313+
• 1) EasyOCR로 기본 텍스트 추출 (GPU 모드)
314+
• 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU)
315+
• 3) 가장 긴 결과를 최종 OCR 결과로 선택
316+
• 4) Note로 저장 및 결과 반환
317+
"""
318+
319+
# 1) 이미지 로드 (PIL)
320+
contents = await ocr_file.read()
321+
try:
322+
image = Image.open(io.BytesIO(contents)).convert("RGB")
323+
except Exception as e:
324+
raise HTTPException(status_code=400, detail=f"이미지 처리 실패: {e}")
325+
326+
# 2) EasyOCR로 텍스트 추출
327+
try:
328+
image_np = np.array(image)
329+
easy_results = reader.readtext(image_np) # GPU 모드 사용
330+
easy_text = " ".join([res[1] for res in easy_results])
331+
except Exception:
332+
easy_text = ""
333+
334+
# 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input)
335+
hf_texts: List[str] = []
336+
try:
337+
out1 = hf_trocr_printed(image)
338+
if isinstance(out1, list) and "generated_text" in out1[0]:
339+
hf_texts.append(out1[0]["generated_text"].strip())
340+
341+
out2 = hf_trocr_handwritten(image)
342+
if isinstance(out2, list) and "generated_text" in out2[0]:
343+
hf_texts.append(out2[0]["generated_text"].strip())
344+
345+
out3 = hf_trocr_small_printed(image)
346+
if isinstance(out3, list) and "generated_text" in out3[0]:
347+
hf_texts.append(out3[0]["generated_text"].strip())
348+
349+
out4 = hf_trocr_large_printed(image)
350+
if isinstance(out4, list) and "generated_text" in out4[0]:
351+
hf_texts.append(out4[0]["generated_text"].strip())
352+
except Exception:
353+
# TrOCR 중 오류 발생 시 무시하고 계속 진행
354+
pass
355+
356+
# 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택
357+
candidates = [t for t in [easy_text] + hf_texts if t and t.strip()]
358+
if not candidates:
359+
raise HTTPException(status_code=500, detail="텍스트를 인식할 수 없습니다.")
360+
361+
ocr_text = max(candidates, key=lambda s: len(s))
362+
363+
# 5) 새 노트 생성 및 DB에 저장
364+
try:
365+
new_note = NoteModel(
366+
user_id=current_user.u_id,
367+
folder_id=folder_id,
368+
title="OCR 결과",
369+
content=ocr_text # **원본 OCR 텍스트만 저장**
370+
257371
# 422 방지: 파일 필드명 유연 처리
258372
upload = file or ocr_file
259373
if upload is None:
@@ -277,6 +391,7 @@ async def ocr_and_create_note(
277391
warnings=[f"허용되지 않는 확장자({ext}). 허용: {sorted(ALLOWED_ALL_EXTS)}"],
278392
note_id=None,
279393
text=None,
394+
280395
)
281396

282397
# 타입 판별

0 commit comments

Comments
 (0)