 from anyio import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import List, Optional, Union, Dict

 import llama_cpp
@@ -155,34 +155,71 @@ def create_app(
     return app


+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str | None,
+    kwargs,
+) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
+
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str | None,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
-    async with inner_send_chan:
-        try:
-            async for chunk in iterate_in_threadpool(iterator):
-                await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                if await request.is_disconnected():
-                    raise anyio.get_cancelled_exc_class()()
-                if interrupt_requests and llama_outer_lock.locked():
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                    raise anyio.get_cancelled_exc_class()()
-            await inner_send_chan.send(dict(data="[DONE]"))
-        except anyio.get_cancelled_exc_class() as e:
-            print("disconnected")
-            with anyio.move_on_after(1, shield=True):
-                print(f"Disconnected from client (via refresh/close) {request.client}")
-                raise e
-        finally:
-            if on_complete:
-                await on_complete()
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+        async with inner_send_chan:
+            try:
+                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
+                async for chunk in iterate_in_threadpool(iterator):
+                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                    if await request.is_disconnected():
+                        raise anyio.get_cancelled_exc_class()()
+                    if interrupt_requests and llama_outer_lock.locked():
+                        await inner_send_chan.send(dict(data="[DONE]"))
+                        raise anyio.get_cancelled_exc_class()()
+                await inner_send_chan.send(dict(data="[DONE]"))
+            except anyio.get_cancelled_exc_class() as e:
+                print("disconnected")
+                with anyio.move_on_after(1, shield=True):
+                    print(
+                        f"Disconnected from client (via refresh/close) {request.client}"
+                    )
+                    raise e


 def _logit_bias_tokens_to_input_ids(
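For context on the `contextlib.asynccontextmanager(get_llama_proxy)()` call introduced above: it wraps the async-generator dependency into an async context manager, so the model proxy is acquired and released inside the publisher task itself rather than by FastAPI's dependency teardown (which can run before a streaming response has finished). A minimal sketch of the pattern, using illustrative names rather than this module's real dependency:

```python
import asyncio
import contextlib


async def get_resource():
    # Stand-in for an async-generator dependency such as get_llama_proxy:
    # acquire before yield, release after the caller is done.
    resource = "model-proxy"
    try:
        yield resource
    finally:
        pass  # release (unlock, free the model, ...) would go here


async def publish():
    # asynccontextmanager() turns the generator into a context-manager factory,
    # so both acquisition and release happen inside this task.
    async with contextlib.asynccontextmanager(get_resource)() as resource:
        print(f"streaming with {resource}")


asyncio.run(publish())
```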
@@ -267,18 +304,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +323,38 @@ async def create_completion(
     }
     kwargs = body.model_dump(exclude=exclude)

-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama, **kwargs)


 @router.post(
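With this change the branch is decided up front from the request body: `stream=true` goes through the SSE publisher above, anything else is executed once in the threadpool and returned as plain JSON. A rough client-side sketch of both paths; the base URL, port, and route here are assumptions for illustration (a locally running server with defaults), not something this diff defines:

```python
import json
import requests

BASE = "http://localhost:8000"  # assumed default host/port

# Non-streaming: one JSON completion object comes back.
resp = requests.post(f"{BASE}/v1/completions",
                     json={"prompt": "Hello", "max_tokens": 16})
print(resp.json()["choices"][0]["text"])

# Streaming: server-sent events, one chunk per "data:" line, ended by the
# "[DONE]" sentinel emitted in get_event_publisher.
with requests.post(f"{BASE}/v1/completions",
                   json={"prompt": "Hello", "max_tokens": 16, "stream": True},
                   stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)
```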
@@ -474,74 +482,48 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion

+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama.create_chat_completion, **kwargs)


 @router.get(
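Note that the publisher is handed an unbound method (`llama_cpp.Llama.__call__` for completions, `llama_cpp.Llama.create_chat_completion` for chat) and only combines it with the freshly acquired `llama` instance inside the worker thread, via `run_in_threadpool(llama_call, llama, **kwargs)`. A small sketch of why that is equivalent to a bound call, using a toy class instead of the real `Llama`:

```python
class Toy:
    def create_chat_completion(self, **kwargs):
        return {"echo": kwargs}


toy = Toy()
llama_call = Toy.create_chat_completion  # unbound: no instance captured yet

# Calling the unbound method with the instance as the first positional argument
# is the same as calling the bound method on the instance, which is exactly
# what run_in_threadpool(llama_call, llama, **kwargs) does.
assert llama_call(toy, messages=[]) == toy.create_chat_completion(messages=[])
```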