31 changes: 23 additions & 8 deletions app/api/news_transform.py
@@ -6,15 +6,24 @@

router = APIRouter()

@router.post("/transfer", response_model=NewsTransferResponse)
@router.post("/news/transfer", response_model=NewsTransferResponse)
async def transform_news(request: NewsTransferRequest):
transform_prompt = build_transform_prompt(request.title, request.originalContent, request.level)
summary_prompt = build_summary_prompt(request.title, request.originalContent)
word_prompt = build_difficult_word_prompt(request.originalContent, request.level)
# Classify the news type and select a model internally
from app.services.mcp import classify_news_type, select_model_by_news_type
news_type = classify_news_type(request.title, request.originalContent)
model = select_model_by_news_type(news_type)

transformed_content = generate_content(transform_prompt)
summary = generate_content(summary_prompt)
difficult_words_raw = generate_content(word_prompt)
# Branch on the selected model
if model == "gemini":
# Pass the prompts to the Gemini API call helper
transformed_content = call_gemini_api(build_transform_prompt(request.title, request.originalContent, request.level))
summary = call_gemini_api(build_summary_prompt(request.title, request.originalContent))
difficult_words_raw = call_gemini_api(build_difficult_word_prompt(request.originalContent, request.level))
else:
# Use the local model
transformed_content = generate_content(build_transform_prompt(request.title, request.originalContent, request.level))
summary = generate_content(build_summary_prompt(request.title, request.originalContent))
difficult_words_raw = generate_content(build_difficult_word_prompt(request.originalContent, request.level))

difficult_words = []
for line in difficult_words_raw.splitlines():
@@ -33,4 +42,10 @@ async def transform_news(request: NewsTransferRequest):
difficultWords=difficult_words
),
success=True
)
)

@router.post("/news/auto_generate")
async def auto_generate_news(title: str = Body(...), content: str = Body(...), level: str = Body(...)):
mcp_request = build_mcp_request_auto(title, content, level)
result = await call_local_mcp(mcp_request)
return result
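Note: `call_gemini_api` is referenced above but not defined in this hunk. A minimal sketch of such a helper, assuming the `google-generativeai` SDK, a `GEMINI_API_KEY` environment variable, and a hypothetical model name, might look like:

```python
# Hypothetical helper, not part of this diff: minimal Gemini call wrapper.
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
_gemini = genai.GenerativeModel("gemini-1.5-flash")  # model name is an assumption

def call_gemini_api(prompt: str) -> str:
    # Send one prompt and return the plain-text completion.
    response = _gemini.generate_content(prompt)
    return response.text
```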
2 changes: 1 addition & 1 deletion app/main.py
@@ -2,4 +2,4 @@
from app.api.v1.news_transform import router

app = FastAPI()
app.include_router(router, prefix="/api/news")
app.include_router(router, prefix="/api")
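With the router prefix moved from `/api/news` to `/api` and the route renamed to `/news/transfer`, the public path stays `/api/news/transfer`. A quick smoke test with FastAPI's TestClient (hypothetical, not part of this PR):

```python
# Hypothetical smoke test, not part of this PR: the combined path is unchanged.
from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

def test_transfer_route_registered():
    # A 422 validation error (rather than 404) shows the route is still
    # registered at /api/news/transfer after the prefix change.
    response = client.post("/api/news/transfer", json={})
    assert response.status_code != 404
```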
24 changes: 24 additions & 0 deletions app/services/mcp.py
@@ -0,0 +1,24 @@
# app/services/mcp.py

from app.models.mcp import MCPRequest, MCPRequestItem
from app.services.summarizer import build_transform_prompt

def classify_news_type(title: str, content: str) -> str:
keywords_politics = ["총선", "외교", "갈등", "정책", "정치", "국회", "대통령", "정부", "외교부"]
text = (title + " " + content)
if any(keyword in text for keyword in keywords_politics):
return "정치, 외교, 사회 이슈"
return "기타"

def select_model_by_news_type(news_type: str) -> str:
if news_type == "정치, 외교, 사회 이슈":
return "gemini"
else:
return "kullm3"

def build_mcp_request_auto(title: str, content: str, level: str) -> MCPRequest:
news_type = classify_news_type(title, content)
model = select_model_by_news_type(news_type)
prompt = build_transform_prompt(title, content, level)
item = MCPRequestItem(prompt=prompt, model=model, metadata={"news_type": news_type, "level": level})
return MCPRequest(items=[item])
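A quick illustration of the keyword-based routing above (titles and bodies are made-up examples):

```python
# Illustration only: keyword-based model routing from app/services/mcp.py.
from app.services.mcp import classify_news_type, select_model_by_news_type

political = classify_news_type("대통령, 외교부에 정책 검토 지시", "정부와 국회가 ...")
assert political == "정치, 외교, 사회 이슈"
assert select_model_by_news_type(political) == "gemini"

# Text without any political keyword falls back to the local model.
other = classify_news_type("프로야구 개막전 하이라이트", "경기 결과 정리 ...")
assert other == "기타"
assert select_model_by_news_type(other) == "kullm3"
```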
60 changes: 5 additions & 55 deletions app/services/summarizer.py
@@ -1,70 +1,33 @@
# Business logic / AI inference module
# app/services/summarizer.py

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time
from typing import List

# 1. Load the model and configure 4-bit quantization
model_id = "nlpai-lab/KULLM3"
# Model loading and 4-bit quantization settings
model_id = "sunnyanna/KULLM3-AWQ"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
print(f"Initial VRAM usage: {torch.cuda.memory_allocated() / (1024**3):.2f} GB")

start_load_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto"
)
end_load_time = time.time()

print(f"\nModel loaded in {end_load_time - start_load_time:.2f} seconds")

if torch.cuda.is_available():
initial_vram_after_load = torch.cuda.memory_allocated()
peak_vram_after_load = torch.cuda.max_memory_allocated()
print(f"VRAM allocated after model load: {initial_vram_after_load / (1024**3):.2f} GB")
print(f"Peak VRAM used during model load: {peak_vram_after_load / (1024**3):.2f} GB")

# 2. LLM chat prompt format

def build_chat_prompt(prompt: str):
return f"<s>[INST] {prompt.strip()} [/INST]"

# 3. Prompt-building functions (unchanged)
def build_transform_prompt(title: str, content: str, level: str) -> str:
base = f"다음 뉴스 제목과 본문을 사용자의 이해 수준에 맞게 다시 써줘.\n\n뉴스 제목: {title}\n뉴스 본문: {content}\n"
if level == "상":
instruction = "원문에 가깝게 유지해."
elif level == "중":
instruction = "간결하고 이해하기 쉬운 문장으로 재구성해줘. 핵심 내용만 유지해도 좋아."
else:
instruction = "초등학생도 이해할 수 있도록 아주 쉽게 설명해줘. 쉬운 단어와 짧은 문장을 써줘."
return base + "\n요청 사항: " + instruction

def build_summary_prompt(title: str, content: str) -> str:
return f"다음 뉴스 제목과 본문을 한문장으로 간단히 요약해줘.\n\n뉴스 제목: {title}\n뉴스 본문: {content}"

# 4. Batch inference function
def kullm_batch_generate(prompts: List[str], max_new_tokens=512):
chat_prompts = [build_chat_prompt(p) for p in prompts]
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
inputs = tokenizer(chat_prompts, return_tensors="pt", padding=True).to(model.device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
start_infer_time = time.time()
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -74,24 +37,11 @@ def kullm_batch_generate(prompts: List[str], max_new_tokens=512):
top_p=0.2,
pad_token_id=tokenizer.eos_token_id
)
end_infer_time = time.time()
generation_time = end_infer_time - start_infer_time
decoded_results = []
generated_tokens_list = []
for i in range(len(prompts)):
original_input_len = (input_ids[i] != tokenizer.pad_token_id).sum().item()
generated_tokens = output[i].shape[0] - original_input_len
generated_tokens_list.append(generated_tokens)
result_text = tokenizer.decode(output[i], skip_special_tokens=True)
decoded_results.append(result_text.split('[/INST]')[-1].strip())
current_vram = 0
peak_vram = 0
if torch.cuda.is_available():
current_vram = torch.cuda.memory_allocated()
peak_vram = torch.cuda.max_memory_allocated()
return decoded_results, generation_time, generated_tokens_list, current_vram, peak_vram
return decoded_results

# 5. generate_content for a single prompt
def generate_content(prompt: str, max_new_tokens=512) -> str:
results, _, _, _, _ = kullm_batch_generate([prompt], max_new_tokens=max_new_tokens)
return results[0]
return kullm_batch_generate([prompt], max_new_tokens=max_new_tokens)[0]
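Since `kullm_batch_generate` now returns only the decoded strings, the per-request prompts could also be sent as one batch instead of several sequential `generate_content` calls. A sketch under that assumption (not part of this PR):

```python
# Sketch, not part of this PR: batch the handler's prompts in one generate call
# now that kullm_batch_generate returns only the decoded strings.
from app.services.summarizer import (
    kullm_batch_generate,
    build_transform_prompt,
    build_summary_prompt,
)

def generate_transform_and_summary(title: str, content: str, level: str):
    prompts = [
        build_transform_prompt(title, content, level),
        build_summary_prompt(title, content),
    ]
    transformed, summary = kullm_batch_generate(prompts, max_new_tokens=512)
    return transformed, summary
```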