Skip to content

Commit f770ff8

Browse files
committed
feat/1st_presentation
1 parent 1d6056c commit f770ff8

File tree

1 file changed

+38
-30
lines changed

1 file changed

+38
-30
lines changed

routers/file.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
# ~/noteflow/Backend/routers/file.py
2-
1+
# routers/file.py
32
import os
43
import io
54
import whisper
65
model = whisper.load_model("base")
76
from datetime import datetime
87
import numpy as np
98
from typing import Optional, List
10-
from urllib.parse import quote
119

1210
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, status
1311
from fastapi.responses import FileResponse
@@ -19,6 +17,9 @@
1917
from models.note import Note as NoteModel
2018
from utils.jwt_utils import get_current_user
2119

20+
# 추가: 파일명 인코딩용
21+
import urllib.parse
22+
2223
# -------------------------------
2324
# 1) EasyOCR 라이브러리 임포트 (GPU 모드 활성화)
2425
# -------------------------------
@@ -55,7 +56,6 @@
5556
trust_remote_code=True
5657
)
5758

58-
# 업로드 디렉토리 설정
5959
BASE_UPLOAD_DIR = os.path.join(
6060
os.path.dirname(os.path.abspath(__file__)),
6161
"..",
@@ -80,11 +80,9 @@ async def upload_file(
8080
orig_filename: str = upload_file.filename or "unnamed"
8181
content_type: str = upload_file.content_type or "application/octet-stream"
8282

83-
# 사용자별 디렉토리 생성
8483
user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id))
8584
os.makedirs(user_dir, exist_ok=True)
8685

87-
# 원본 파일명 그대로 저장 (동명이인 방지)
8886
saved_filename = orig_filename
8987
saved_path = os.path.join(user_dir, saved_filename)
9088
if os.path.exists(saved_path):
@@ -99,15 +97,13 @@ async def upload_file(
9997
break
10098
counter += 1
10199

102-
# 파일 저장
103100
try:
104101
with open(saved_path, "wb") as buffer:
105102
content = await upload_file.read()
106103
buffer.write(content)
107104
except Exception as e:
108105
raise HTTPException(status_code=500, detail=f"파일 저장 실패: {e}")
109106

110-
# DB에 메타데이터 기록
111107
new_file = FileModel(
112108
user_id=current_user.u_id,
113109
folder_id=folder_id,
@@ -177,9 +173,9 @@ def download_file(
177173
if not os.path.exists(file_path):
178174
raise HTTPException(status_code=404, detail="서버에 파일이 존재하지 않습니다.")
179175

180-
# original_name 을 percent-encoding 해서 ASCII 만으로 헤더 구성
181-
filename_quoted = quote(file_obj.original_name)
182-
content_disposition = f"inline; filename*=UTF-8''{filename_quoted}"
176+
# 원본 파일명 UTF-8 URL 인코딩 처리
177+
quoted_name = urllib.parse.quote(file_obj.original_name, safe='')
178+
content_disposition = f"inline; filename*=UTF-8''{quoted_name}"
183179

184180
return FileResponse(
185181
path=file_path,
@@ -200,52 +196,64 @@ async def ocr_and_create_note(
200196
current_user = Depends(get_current_user)
201197
):
202198
"""
203-
• EasyOCR + TrOCR 모델로 이미지에서 텍스트 추출
204-
• 가장 긴 결과를 선택해 새 노트로 저장
199+
• ocr_file: 이미지 파일(UploadFile)
200+
• 1) EasyOCR로 기본 텍스트 추출 (GPU 모드)
201+
• 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU)
202+
• 3) 가장 긴 결과를 최종 OCR 결과로 선택
203+
• 4) Note로 저장 및 결과 반환
205204
"""
206-
# 1) 이미지 로드
205+
206+
# 1) 이미지 로드 (PIL)
207207
contents = await ocr_file.read()
208208
try:
209209
image = Image.open(io.BytesIO(contents)).convert("RGB")
210210
except Exception as e:
211211
raise HTTPException(status_code=400, detail=f"이미지 처리 실패: {e}")
212212

213-
# 2) EasyOCR
213+
# 2) EasyOCR로 텍스트 추출
214214
try:
215215
image_np = np.array(image)
216-
easy_results = reader.readtext(image_np)
216+
easy_results = reader.readtext(image_np) # GPU 모드 사용
217217
easy_text = " ".join([res[1] for res in easy_results])
218218
except Exception:
219219
easy_text = ""
220220

221-
# 3) TrOCR 4개 모델
221+
# 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input)
222222
hf_texts: List[str] = []
223223
try:
224-
for pipe in (
225-
hf_trocr_printed,
226-
hf_trocr_handwritten,
227-
hf_trocr_small_printed,
228-
hf_trocr_large_printed
229-
):
230-
out = pipe(image)
231-
if isinstance(out, list) and "generated_text" in out[0]:
232-
hf_texts.append(out[0]["generated_text"].strip())
224+
out1 = hf_trocr_printed(image)
225+
if isinstance(out1, list) and "generated_text" in out1[0]:
226+
hf_texts.append(out1[0]["generated_text"].strip())
227+
228+
out2 = hf_trocr_handwritten(image)
229+
if isinstance(out2, list) and "generated_text" in out2[0]:
230+
hf_texts.append(out2[0]["generated_text"].strip())
231+
232+
out3 = hf_trocr_small_printed(image)
233+
if isinstance(out3, list) and "generated_text" in out3[0]:
234+
hf_texts.append(out3[0]["generated_text"].strip())
235+
236+
out4 = hf_trocr_large_printed(image)
237+
if isinstance(out4, list) and "generated_text" in out4[0]:
238+
hf_texts.append(out4[0]["generated_text"].strip())
233239
except Exception:
240+
# TrOCR 중 오류 발생 시 무시하고 계속 진행
234241
pass
235242

236-
# 4) 가장 긴 결과 선택
243+
# 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택
237244
candidates = [t for t in [easy_text] + hf_texts if t and t.strip()]
238245
if not candidates:
239246
raise HTTPException(status_code=500, detail="텍스트를 인식할 수 없습니다.")
240-
ocr_text = max(candidates, key=len)
241247

242-
# 5) Note 생성
248+
ocr_text = max(candidates, key=lambda s: len(s))
249+
250+
# 5) 새 노트 생성 및 DB에 저장
243251
try:
244252
new_note = NoteModel(
245253
user_id=current_user.u_id,
246254
folder_id=folder_id,
247255
title="OCR 결과",
248-
content=ocr_text
256+
content=ocr_text # **원본 OCR 텍스트만 저장**
249257
)
250258
db.add(new_note)
251259
db.commit()

0 commit comments

Comments
 (0)