|
1 | | -# ~/noteflow/Backend/utils/llm.py |
| 1 | +from langchain.callbacks import AsyncIteratorCallbackHandler |
| 2 | +from langchain_ollama import ChatOllama |
| 3 | +from langchain.schema import HumanMessage, SystemMessage |
| 4 | +import re, asyncio |
2 | 5 |
|
3 | | -import torch |
4 | | -from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM |
| 6 | +_THOUGHT_PAT = re.compile( |
| 7 | + r"^\s*(okay|let\s*me|i\s*need\s*to|first[, ]|then[, ]|next[, ]|in summary|먼저|그\s*다음|요약하면)", |
| 8 | + re.I, |
| 9 | +) |
5 | 10 |
|
6 | | -_MODEL_NAME = "Qwen/Qwen3-8B" |
7 | | - |
8 | | -# 전역 변수: 최초에는 토크나이저/모델이 None |
9 | | -_tokenizer = None |
10 | | -_model = None |
11 | | - |
12 | | -def _load_model(): |
13 | | - """ |
14 | | - summarize_with_qwen3()가 최초 호출될 때만 Qwen3-8B 모델과 토크나이저를 메모리에 로드합니다. |
15 | | - """ |
16 | | - global _tokenizer, _model |
17 | | - if _model is None or _tokenizer is None: |
18 | | - # 1) Config 불러와서 parallel_style 지정 |
19 | | - config = AutoConfig.from_pretrained( |
20 | | - _MODEL_NAME, |
21 | | - trust_remote_code=True |
22 | | - ) |
23 | | - # 반드시 "auto"로 지정 (NoneType 오류 방지) |
24 | | - config.parallel_style = "auto" |
25 | | - |
26 | | - # 2) 토크나이저 로드 |
27 | | - _tokenizer = AutoTokenizer.from_pretrained( |
28 | | - _MODEL_NAME, |
29 | | - trust_remote_code=True |
30 | | - ) |
31 | | - |
32 | | - # 3) 모델 로드 시 config 인자 추가 |
33 | | - _model = AutoModelForCausalLM.from_pretrained( |
34 | | - _MODEL_NAME, |
35 | | - config=config, # custom config 전달 |
36 | | - torch_dtype="auto", |
37 | | - device_map="auto", |
38 | | - trust_remote_code=True |
39 | | - ) |
40 | | - _model.eval() |
41 | | - |
42 | | - |
43 | | -def summarize_with_qwen3( |
44 | | - text: str, |
45 | | - max_new_tokens: int = 256, |
46 | | - temperature: float = 0.6 |
47 | | -) -> str: |
| 11 | +async def stream_summary_with_langchain(text: str): |
48 | 12 | """ |
49 | | - - 한국어 문서를 간결하고 핵심적으로 요약 |
50 | | - - 반환값: 요약된 한국어 문자열 |
| 13 | + LangChain + Ollama에서 토큰을 비동기로 받아 |
| 14 | + SSE("data: ...\\n\\n") 형식으로 yield 하는 async generator |
51 | 15 | """ |
52 | | - # 모델/토크나이저가 아직 로드되지 않았다면, 이 시점에만 로드 |
53 | | - if _model is None or _tokenizer is None: |
54 | | - _load_model() |
| 16 | + # 1) LangChain용 콜백 핸들러 |
| 17 | + cb = AsyncIteratorCallbackHandler() |
| 18 | + |
| 19 | + # 2) Ollama Chat 모델 (streaming=True) |
| 20 | + llm = ChatOllama( |
| 21 | + base_url="http://localhost:11434", |
| 22 | + model="qwen3:8b", |
| 23 | + streaming=True, |
| 24 | + callbacks=[cb], |
| 25 | + temperature=0.6, |
| 26 | + ) |
55 | 27 |
|
56 | | - # Chat-format prompt 생성 |
| 28 | + # 3) 프롬프트 |
57 | 29 | messages = [ |
58 | | - { |
59 | | - "role": "system", |
60 | | - "content": ( |
61 | | - "당신은 한국어 문서를 간결하고 핵심적으로 요약하는 전문가입니다. " |
62 | | - "요약 외에는 절대 다른 말을 하지 마세요." |
63 | | - ) |
64 | | - }, |
65 | | - { |
66 | | - "role": "user", |
67 | | - "content": text |
68 | | - } |
| 30 | + SystemMessage( |
| 31 | + content="다음 텍스트를 한국어로 간결하게 요약하세요. " |
| 32 | + "사고 과정(Chain‑of‑Thought)은 절대 출력하지 마세요./no_think" |
| 33 | + ), |
| 34 | + HumanMessage(content=text), |
69 | 35 | ] |
70 | 36 |
|
71 | | - # tokenizer.apply_chat_template()를 통해 모델 친화적인 프롬프트 생성 |
72 | | - prompt = _tokenizer.apply_chat_template( |
73 | | - messages, |
74 | | - tokenize=False, |
75 | | - add_generation_prompt=True, |
76 | | - enable_thinking=False |
77 | | - ) |
78 | | - |
79 | | - # 입력 토크나이즈 후 모델 디바이스로 이동 |
80 | | - inputs = _tokenizer(prompt, return_tensors="pt").to(_model.device) |
| 37 | + # 4) LLM 호출 비동기 실행 |
| 38 | + task = asyncio.create_task(llm.agenerate([messages])) |
81 | 39 |
|
82 | | - # 모델 generate 호출 |
83 | | - outputs = _model.generate( |
84 | | - **inputs, |
85 | | - max_new_tokens=max_new_tokens, |
86 | | - temperature=temperature, |
87 | | - top_p=0.95, |
88 | | - top_k=20, |
89 | | - do_sample=False, # 안정적인 요약을 위해 샘플링 끄기 |
90 | | - eos_token_id=_tokenizer.eos_token_id |
91 | | - ) |
| 40 | + buffer = "" |
| 41 | + async for token in cb.aiter(): |
| 42 | + buffer += token |
| 43 | + if buffer.endswith(("\n", "。", ".", "…")): |
| 44 | + line = buffer.strip() |
| 45 | + buffer = "" |
92 | 46 |
|
93 | | - # 입력 프롬프트 뒤에 생성된 토큰만 디코딩 |
94 | | - gen_tokens = outputs[0].tolist()[len(inputs.input_ids[0]):] |
95 | | - decoded = _tokenizer.decode(gen_tokens, skip_special_tokens=True) |
| 47 | + if not _THOUGHT_PAT.match(line): |
| 48 | + yield f"data: {line}\n\n" # SSE 청크 전송 |
96 | 49 |
|
97 | | - return decoded.strip() |
| 50 | + await task # 예외 전파 |
0 commit comments