64 changes: 37 additions & 27 deletions api.py
@@ -1,48 +1,58 @@
from fastapi import FastAPI, Request
import argparse
import datetime
import logging
import typing as t

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

import models
import uvicorn, json, datetime
import torch

logger = logging.getLogger(__name__)

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"

app = FastAPI()

@app.post("/")
async def create_item(request: Request):
global model
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
# history = json_post_list.get('history')
# max_length = json_post_list.get('max_length')
# top_p = json_post_list.get('top_p')
# temperature = json_post_list.get('temperature')
output = model.run(prompt)
model: models.LLMModel


class ChatRequest(BaseModel):
prompt: str
history: t.Optional[list] = []


class ChatResponse(BaseModel):
status: t.Optional[int] = 200
response: str
history: t.Optional[list] = []
time: str


@app.post("/", response_model=ChatResponse)
async def chat_completions(request: ChatRequest) -> ChatResponse:
prompt = request.prompt
output = model.run(prompt, request.history)
if isinstance(output, tuple):
response, history = output
else:
response = output
history = []
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
return answer
logger.info(f"prompt: {prompt}, response: {response}")
return ChatResponse(response=response, history=history, time=datetime.datetime.now().strftime(DATETIME_FORMAT))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("model", choices=models.availabel_models)
args = parser.parse_args()
model = models.get_model(args)
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
logger.info(f"model<{args}> load success")

uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
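With this change the endpoint accepts a ChatRequest body and returns a ChatResponse. A minimal client sketch (the prompt text and local URL are assumptions for illustration, not part of this PR):

import requests

payload = {"prompt": "Hello", "history": []}
resp = requests.post("http://localhost:8000/", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()  # fields follow ChatResponse: status, response, history, time
print(data["response"])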
3 changes: 2 additions & 1 deletion models/__init__.py
@@ -3,6 +3,7 @@
import sys
from pathlib import Path
import traceback
import typing as t

sys.path.append(str((Path(__file__).parent / 'llama' ).absolute()))
sys.path.append(str((Path(__file__).parent / 'pangualpha').absolute()))
@@ -24,7 +25,7 @@ class LLMModel:
def __init__(self) -> None:
pass

def run(self, input_text: str) -> str:
def run(self, input_text: str, history: t.Optional[list] = None, **kwargs) -> str:
pass


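Since LLMModel.run now takes an optional history plus **kwargs, every backend overrides the same signature. A minimal sketch of a hypothetical subclass (EchoModel is illustrative only, not part of this PR):

import typing as t

from models import LLMModel


class EchoModel(LLMModel):
    def run(self, input_text: str, history: t.Optional[list] = None, **kwargs) -> str:
        # Hypothetical backend used only to show the unified signature.
        history = history or []
        return f"echo: {input_text}"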
5 changes: 3 additions & 2 deletions models/chatglm/__init__.py
@@ -3,6 +3,7 @@
from models import LLMModel
import jittor as jt
import torch
import typing as t

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
@@ -41,8 +42,8 @@ def run_web_demo(self, input_text, history=[]):
while True:
yield self.run(input_text, history=history)

def run(self, text, history=[]):
return self.model.chat(self.tokenizer, text, history=history)
def run(self, input_text: str, history: t.Optional[list] = None, **kwargs) -> str:
return self.model.chat(self.tokenizer, input_text, history=history)

def get_model(args):
return ChatGLMMdoel(args)
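ChatGLM's chat() returns a (response, history) pair rather than a bare string, which is why the endpoint in api.py branches on isinstance(output, tuple). A minimal sketch of that normalization, assuming only these two return shapes:

def normalize_output(output):
    # Mirrors the branch in api.py: tuple -> (response, history), str -> (response, []).
    if isinstance(output, tuple):
        response, history = output
    else:
        response, history = output, []
    return response, history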
13 changes: 7 additions & 6 deletions models/chatrwkv/__init__.py
@@ -1,6 +1,7 @@
import os, copy, types, gc, sys
import numpy as np
from prompt_toolkit import prompt
import typing as t

from models import LLMModel
import jittor as jt
@@ -175,11 +176,11 @@ def load_all_stat(self, srv, name):
self.model_tokens = copy.deepcopy(self.all_state[n]['token'])
return self.all_state[n]['out']


def run(self, message: str, is_web=False) -> str:
def run(self, input_text: str, history: t.Optional[list] = None, **kwargs) -> str:
is_web = kwargs.get("is_web", False)
srv = 'dummy_server'

msg = message.replace('\\n','\n').strip()
msg = input_text.replace('\\n','\n').strip()

x_temp = self.GEN_TEMP
x_top_p = self.GEN_TOP_P
@@ -197,7 +198,7 @@ def run(self, message: str, is_web=False) -> str:
x_temp = 5
if x_top_p <= 0:
x_top_p = 0

if msg == '+reset':
out = self.load_all_stat('', 'chat_init')
self.save_all_stat(srv, 'chat', out)
@@ -227,7 +228,7 @@ def run(self, message: str, is_web=False) -> str:
real_msg = msg[4:].strip()
new = f"{self.user}{self.interface} {real_msg}\n\n{self.bot}{self.interface}"
# print(f'### qa ###\n[{new}]')

out = self.run_rnn(self.tokenizer.encode(new))
self.save_all_stat(srv, 'gen_0', out)

@@ -259,7 +260,7 @@ def run(self, message: str, is_web=False) -> str:
out = self.run_rnn([token], newline_adj=-2)
else:
out = self.run_rnn([token])

xxx = self.tokenizer.decode(self.model_tokens[out_last:])
if '\ufffd' not in xxx: # avoid utf-8 display issues
print(xxx, end='', flush=True)
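Because is_web is now read from **kwargs instead of being a positional parameter, callers have to pass it by keyword. A minimal sketch, assuming model is the loaded ChatRWKV backend:

reply = model.run("Hello")                   # is_web defaults to False
web_reply = model.run("Hello", is_web=True)  # forwarded through **kwargs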
3 changes: 2 additions & 1 deletion models/llama/__init__.py
@@ -3,6 +3,7 @@
from pathlib import Path
import os
import platform
import typing as t

import jittor as jt
jt.flags.use_cuda = 1
@@ -56,7 +57,7 @@ def __init__(self, args) -> None:
)
jt.gc()

def run(self, input_text: str) -> str:
def run(self, input_text: str, history: t.Optional[list] = None, **kwargs) -> str:
with jt.no_grad():
output = self.generator.generate([input_text], max_gen_len=256, temperature=0.8, top_p=0.95)
text_out = ""
17 changes: 10 additions & 7 deletions models/pangualpha/__init__.py
@@ -1,17 +1,19 @@
import os, sys
import os
import sys
import typing as t

import jittor as jt
import numpy as np
import torch
import jittor as jt

from megatron.text_generation_utils import pad_batch, get_batch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.text_generation_utils import pad_batch, get_batch
from megatron.training import get_model as megatron_get_model

from models import LLMModel


@@ -153,11 +155,12 @@ def chat(self) -> str:
generate(self.model, context_tokens, self.args, tokenizer, 100, len(text))
print("")

def run(self, text, tokenizer=None, history=[]):
def run(self, input_text: str, history: t.Optional[list] = None, **kwargs) -> str:
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is None:
tokenizer = get_tokenizer()
tokenizer.tokenize("init")
text = "问:" + text + "?答:"
text = "问:" + input_text + "?答:"
context_tokens = tokenizer.tokenize(text)
return generate(self.model, context_tokens, self.args, tokenizer, 100, len(text))

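Here the tokenizer is taken from **kwargs and falls back to megatron's get_tokenizer() when omitted, with the prompt still wrapped as "问:...?答:". A minimal call sketch, assuming model is the loaded PanGu-Alpha backend and megatron is already initialized:

answer = model.run("What is machine learning?")  # uses the get_tokenizer() fallback
# answer = model.run("What is machine learning?", tokenizer=my_tokenizer)  # inject an explicit tokenizer via **kwargs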