1212from models .note import Note as NoteModel
1313from utils .jwt_utils import get_current_user
1414
15+ cuda_gpu
16+ # 추가: 파일명 인코딩용
17+ import urllib .parse
18+
19+ # -------------------------------
20+ # 1) EasyOCR 라이브러리 임포트 (GPU 모드 활성화)
21+ # -------------------------------
22+ import easyocr
23+ reader = easyocr .Reader (["ko" , "en" ], gpu = True )
24+
25+ # -------------------------------
26+ # 2) Hugging Face TrOCR 모델용 파이프라인 (GPU 사용)
27+ # -------------------------------
28+ from transformers import pipeline
29+
30+ hf_trocr_printed = pipeline (
31+ "image-to-text" ,
32+ model = "microsoft/trocr-base-printed" ,
33+ device = 0 ,
34+ trust_remote_code = True
35+ )
36+ hf_trocr_handwritten = pipeline (
37+ "image-to-text" ,
38+ model = "microsoft/trocr-base-handwritten" ,
39+ device = 0 ,
40+ trust_remote_code = True
41+ )
42+ hf_trocr_small_printed = pipeline (
43+ "image-to-text" ,
44+ model = "microsoft/trocr-small-printed" ,
45+ device = 0 ,
46+ trust_remote_code = True
47+ )
48+ hf_trocr_large_printed = pipeline (
49+ "image-to-text" ,
50+ model = "microsoft/trocr-large-printed" ,
51+ device = 0 ,
52+ trust_remote_code = True
53+ )
54+
55+ BASE_UPLOAD_DIR = os .path .join (
56+ os .path .dirname (os .path .abspath (__file__ )),
57+ ".." ,
58+ "uploads"
59+ )
60+
1561# 공통 OCR 파이프라인
1662from utils .ocr import run_pipeline , detect_type
1763from schemas .file import OCRResponse
@@ -79,11 +125,10 @@ async def upload_file(
79125 orig_filename : str = upload_file .filename or "unnamed"
80126 content_type : str = upload_file .content_type or "application/octet-stream"
81127
82- # 사용자별 디렉토리 생성
83128 user_dir = os .path .join (BASE_UPLOAD_DIR , str (current_user .u_id ))
84129 os .makedirs (user_dir , exist_ok = True )
85130
86- # 원본 파일명 유지 (중복 방지)
131+
87132 saved_filename = orig_filename
88133 saved_path = os .path .join (user_dir , saved_filename )
89134 if os .path .exists (saved_path ):
@@ -98,14 +143,17 @@ async def upload_file(
98143 break
99144 counter += 1
100145
146+
101147 # 저장
148+
102149 try :
103150 with open (saved_path , "wb" ) as buffer :
104151 content = await upload_file .read ()
105152 buffer .write (content )
106153 except Exception as e :
107154 raise HTTPException (status_code = 500 , detail = f"파일 저장 실패: { e } " )
108155
156+
109157 # note_id가 있으면 해당 노트 확인
110158 note_obj = None
111159 if note_id is not None :
@@ -118,6 +166,7 @@ async def upload_file(
118166 raise HTTPException (status_code = 404 , detail = "해당 노트를 찾을 수 없습니다." )
119167
120168 # DB 메타 기록
169+
121170 new_file = FileModel (
122171 user_id = current_user .u_id ,
123172 folder_id = None if note_id else folder_id ,
@@ -202,6 +251,10 @@ def download_file(
202251 if not os .path .exists (file_path ):
203252 raise HTTPException (status_code = 404 , detail = "서버에 파일이 존재하지 않습니다." )
204253
254+ # 원본 파일명 UTF-8 URL 인코딩 처리
255+ quoted_name = urllib .parse .quote (file_obj .original_name , safe = '' )
256+ content_disposition = f"inline; filename*=UTF-8''{ quoted_name } "
257+
205258 return FileResponse (
206259 path = file_path ,
207260 media_type = file_obj .content_type ,
@@ -254,6 +307,67 @@ async def ocr_and_create_note(
254307 db : Session = Depends (get_db ),
255308 current_user = Depends (get_current_user )
256309):
310+
311+ """
312+ • ocr_file: 이미지 파일(UploadFile)
313+ • 1) EasyOCR로 기본 텍스트 추출 (GPU 모드)
314+ • 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU)
315+ • 3) 가장 긴 결과를 최종 OCR 결과로 선택
316+ • 4) Note로 저장 및 결과 반환
317+ """
318+
319+ # 1) 이미지 로드 (PIL)
320+ contents = await ocr_file .read ()
321+ try :
322+ image = Image .open (io .BytesIO (contents )).convert ("RGB" )
323+ except Exception as e :
324+ raise HTTPException (status_code = 400 , detail = f"이미지 처리 실패: { e } " )
325+
326+ # 2) EasyOCR로 텍스트 추출
327+ try :
328+ image_np = np .array (image )
329+ easy_results = reader .readtext (image_np ) # GPU 모드 사용
330+ easy_text = " " .join ([res [1 ] for res in easy_results ])
331+ except Exception :
332+ easy_text = ""
333+
334+ # 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input)
335+ hf_texts : List [str ] = []
336+ try :
337+ out1 = hf_trocr_printed (image )
338+ if isinstance (out1 , list ) and "generated_text" in out1 [0 ]:
339+ hf_texts .append (out1 [0 ]["generated_text" ].strip ())
340+
341+ out2 = hf_trocr_handwritten (image )
342+ if isinstance (out2 , list ) and "generated_text" in out2 [0 ]:
343+ hf_texts .append (out2 [0 ]["generated_text" ].strip ())
344+
345+ out3 = hf_trocr_small_printed (image )
346+ if isinstance (out3 , list ) and "generated_text" in out3 [0 ]:
347+ hf_texts .append (out3 [0 ]["generated_text" ].strip ())
348+
349+ out4 = hf_trocr_large_printed (image )
350+ if isinstance (out4 , list ) and "generated_text" in out4 [0 ]:
351+ hf_texts .append (out4 [0 ]["generated_text" ].strip ())
352+ except Exception :
353+ # TrOCR 중 오류 발생 시 무시하고 계속 진행
354+ pass
355+
356+ # 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택
357+ candidates = [t for t in [easy_text ] + hf_texts if t and t .strip ()]
358+ if not candidates :
359+ raise HTTPException (status_code = 500 , detail = "텍스트를 인식할 수 없습니다." )
360+
361+ ocr_text = max (candidates , key = lambda s : len (s ))
362+
363+ # 5) 새 노트 생성 및 DB에 저장
364+ try :
365+ new_note = NoteModel (
366+ user_id = current_user .u_id ,
367+ folder_id = folder_id ,
368+ title = "OCR 결과" ,
369+ content = ocr_text # **원본 OCR 텍스트만 저장**
370+
257371 # 422 방지: 파일 필드명 유연 처리
258372 upload = file or ocr_file
259373 if upload is None :
@@ -277,6 +391,7 @@ async def ocr_and_create_note(
277391 warnings = [f"허용되지 않는 확장자({ ext } ). 허용: { sorted (ALLOWED_ALL_EXTS )} " ],
278392 note_id = None ,
279393 text = None ,
394+
280395 )
281396
282397 # 타입 판별
0 commit comments