1
+ # -*- coding: utf-8 -*-
2
import asyncio
import functools
import json
import logging
import time
# import configparser
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import uvicorn
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from starlette.responses import PlainTextResponse

from modelcache import cache
from modelcache.adapter import adapter
from modelcache.manager import CacheBase, VectorBase, get_data_manager
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.processor.pre import query_multi_splicing
from modelcache.processor.pre import insert_multi_splicing
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio
22
+
23
# Create the FastAPI application instance that the route decorators in this
# module register their handlers on.
app = FastAPI()
25
+
26
class RequestData(BaseModel):
    """Shape of the JSON payload read by the ``/modelcache`` handler.

    The field names mirror the keys ``user_backend`` looks up in the request
    body. Only ``type`` is required; every other field may be omitted or sent
    as ``null``, which is why they are declared ``Optional`` rather than
    relying on a bare ``None`` default.
    """

    # Operation selector; the handler accepts 'query', 'insert', 'remove'
    # and 'detox'.
    type: str
    # Read by the handler for its 'model' entry.
    scope: Optional[dict] = None
    # Passed to the cache lookup for 'query' requests.
    query: Optional[str] = None
    # Passed to the cache insert for 'insert' requests.
    chat_info: Optional[list] = None
    # The next two are forwarded to the adapter for 'remove' requests.
    remove_type: Optional[str] = None
    id_list: list = []
33
+
34
# Embedding model: its `to_embeddings` method is registered as the cache's
# embedding function and its `dimension` sizes the faiss vector store.
data2vec = Data2VecAudio()

# Data manager built from a "sqlite" CacheBase and a "faiss" VectorBase.
data_manager = get_data_manager(
    CacheBase("sqlite"),
    VectorBase("faiss", dimension=data2vec.dimension),
)

# Wire the embedding function, storage, similarity evaluation and the
# pre-embedding hooks for queries and inserts into the shared cache object.
cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

# Thread pool with six workers, created at import time.
executor = ThreadPoolExecutor(max_workers=6)
47
+
48
# Asynchronously persist query information.
async def save_query_info_fastapi(result, model, query, delta_time_log):
    """Record a query and its response through the cache's data manager.

    ``cache.data_manager.save_query_resp`` is a synchronous call, so it is
    handed to a worker thread and awaited instead of being run on the event
    loop.

    Args:
        result: Response payload that was returned for the query.
        model: Model name the query was scoped to.
        query: Original query object; serialised with ``json.dumps`` (non-ASCII
            characters kept as-is) before being stored.
        delta_time_log: Elapsed lookup time passed through as ``delta_time``.
    """
    loop = asyncio.get_running_loop()
    save_call = functools.partial(
        cache.data_manager.save_query_resp,
        result,
        model=model,
        query=json.dumps(query, ensure_ascii=False),
        delta_time=delta_time_log,
    )
    # Run on the module's dedicated `executor` pool. Passing None here would
    # use the event loop's default executor and leave `executor` unused.
    await loop.run_in_executor(executor, save_call)
53
+
54
+
55
+
56
@app.get("/welcome", response_class=PlainTextResponse)
async def first_fastapi():
    """Return a fixed plain-text greeting for ``GET /welcome``."""
    greeting = "hello, modelcache!"
    return greeting
59
+
60
# Strong references to in-flight background tasks. The event loop only keeps
# weak references to tasks, so a task whose handle is dropped may be garbage
# collected before it completes.
_background_tasks = set()

# Request types accepted by the validation step in `user_backend`.
_SUPPORTED_REQUEST_TYPES = ('query', 'insert', 'remove', 'detox')


def _load_json_payload(raw_text):
    """Parse the request body text into a Python object.

    If the first pass yields a string, the body was itself a JSON-encoded
    string, so it is decoded once more.

    Raises:
        ValueError: with the message "Invalid JSON format" when either pass
            fails to parse.
    """
    try:
        payload = json.loads(raw_text)
        if isinstance(payload, str):
            payload = json.loads(payload)
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON format") from None
    return payload


def _request_error(error_desc, para_dict):
    """Build the errorCode-103 payload returned for unusable requests."""
    return {
        "errorCode": 103,
        "errorDesc": error_desc,
        "cacheHit": False,
        "delta_time": 0,
        "hit_query": '',
        "answer": '',
        "para_dict": para_dict,
    }


def _forget_background_task(task):
    """Done-callback: release the task reference and surface its failure."""
    _background_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        logging.error("saving query info failed: %r", task.exception())


@app.post("/modelcache")
async def user_backend(request: Request):
    """Handle ``POST /modelcache``: dispatch query / insert / remove requests.

    The raw body is parsed by hand, so problems with the request are reported
    in the response payload rather than as an HTTP error status.

    Returns:
        A dict whose shape depends on the request type:

        * request parsing/validation failure: ``errorCode`` 103, with the
          received body echoed in ``para_dict``;
        * ``query``: ``errorCode`` 0 with ``cacheHit`` True/False, 201 when the
          adapter returns ``'adapt_query_exception'``, 202 on an exception;
        * ``insert``: ``errorCode`` 0 on ``'success'``, 301 for any other
          adapter response, 303 on an exception;
        * ``remove``: ``errorCode`` 0 when the adapter reports status
          ``'success'``, 401 when the adapter response is not a dict, 402
          otherwise.

        If ``model_blacklist_filter`` returns a dict, that dict is returned
        as-is before any dispatch happens.
    """
    raw_body = None
    try:
        raw_body = await request.body()
        if isinstance(raw_body, bytes):
            raw_body = raw_body.decode("utf-8")
        request_data = _load_json_payload(raw_body)

        request_type = request_data.get('type')
        model = None
        if 'scope' in request_data:
            # Normalise the model name: '-' and '.' are replaced with '_'.
            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
        query = request_data.get('query')
        chat_info = request_data.get('chat_info')

        if request_type not in _SUPPORTED_REQUEST_TYPES:
            # Raised as ValueError (not HTTPException) so that str(e) below is
            # exactly this message regardless of the Starlette version.
            raise ValueError("Type exception, should be one of ['query', 'insert', 'remove', 'detox']")

    except Exception as e:
        if isinstance(raw_body, bytes):
            # Decoding failed; echo a lossy text form so the error payload
            # itself remains JSON-serialisable.
            raw_body = raw_body.decode("utf-8", errors="replace")
        return _request_error(str(e), raw_body)

    # model filter
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return filter_resp

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
            delta_time = f"{round(time.time() - start_time, 2)}s"

            if response is None:
                # Cache miss.
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            elif response in ['adapt_query_exception']:
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response['data']
                hit_query = response['hitQuery']
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
                          "hit_query": hit_query, "answer": answer}

            delta_time_log = round(time.time() - start_time, 2)
            # Fire-and-forget persistence of the query log; keep a reference
            # until the task finishes so it cannot be collected mid-flight.
            save_task = asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log))
            _background_tasks.add(save_task)
            save_task.add_done_callback(_forget_background_task)
            return result
        except Exception as e:
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            # Log with traceback; the client only receives str(e).
            logging.exception('query failed, result: %s', result)
            return result

    if request_type == 'insert':
        try:
            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
            if response == 'success':
                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
        except Exception as e:
            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}

    if request_type == 'remove':
        # NOTE(review): unlike 'query' and 'insert', an exception raised here
        # is not converted into an error payload and propagates to the
        # framework.
        response = adapter.ChatCompletion.create_remove(
            model=model,
            remove_type=request_data.get("remove_type"),
            id_list=request_data.get("id_list"),
        )
        if not isinstance(response, dict):
            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}

        state = response.get('status')
        if state == 'success':
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}

    # 'detox' passes validation but has no handler in this module. Previously
    # the function fell off the end and returned None; answer explicitly.
    # NOTE(review): reuses the 103 request-error payload -- confirm the
    # intended code for an unimplemented request type.
    return _request_error(f"Request type '{request_type}' is not implemented", raw_body)
159
+
160
# TODO: this could instead be started from the command line with
# `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5000)
0 commit comments