 import typing
 import contextlib

-from threading import Lock
+from anyio import Lock
 from functools import partial
 from typing import Iterator, List, Optional, Union, Dict

@@ -70,14 +70,14 @@ def set_llama_proxy(model_settings: List[ModelSettings]):
     _llama_proxy = LlamaProxy(models=model_settings)


-def get_llama_proxy():
+async def get_llama_proxy():
     # NOTE: This double lock allows the currently streaming llama model to
     # check if any other requests are pending in the same thread and cancel
     # the stream if so.
-    llama_outer_lock.acquire()
+    await llama_outer_lock.acquire()
     release_outer_lock = True
     try:
-        llama_inner_lock.acquire()
+        await llama_inner_lock.acquire()
         try:
             llama_outer_lock.release()
             release_outer_lock = False
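Switching from `threading.Lock` to `anyio.Lock` makes both `acquire()` calls awaitable, so `get_llama_proxy` has to become an `async def` generator. The hunk stops at line 83; the sketch below fills in the remainder of the double-lock pattern as an assumption based on the visible acquire/release structure, so treat everything after `release_outer_lock = False` as illustrative rather than the exact source.

```python
# Sketch of the double-lock pattern with anyio.Lock; the body past the yield
# is assumed, not taken from this hunk.
from anyio import Lock

llama_outer_lock = Lock()
llama_inner_lock = Lock()
_llama_proxy = None  # set by set_llama_proxy() in the real module


async def get_llama_proxy():
    # A new request queues on the outer lock; the request currently holding
    # the inner lock can observe this and stop streaming early.
    await llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        await llama_inner_lock.acquire()
        try:
            # Hand the outer lock back so the next request can queue behind us.
            llama_outer_lock.release()
            release_outer_lock = False
            yield _llama_proxy
        finally:
            llama_inner_lock.release()
    finally:
        if release_outer_lock:
            llama_outer_lock.release()
```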
@@ -159,7 +159,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
     iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], None]] = None,
+    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
@@ -182,7 +182,7 @@ async def get_event_publisher(
                 raise e
         finally:
             if on_complete:
-                on_complete()
+                await on_complete()


 def _logit_bias_tokens_to_input_ids(
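With `on_complete` now typed as `typing.Callable[[], typing.Awaitable[None]]`, the publisher has to await the callback, and a zero-argument coroutine function such as `AsyncExitStack.aclose` fits that contract. A minimal, self-contained sketch of the idea; the `publish` and `main` names are illustrative, not from the PR.

```python
import asyncio
import contextlib
import typing


async def publish(
    events: typing.Iterator[str],
    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
) -> None:
    # Stand-in for get_event_publisher: drain the iterator, then run the async
    # cleanup callback exactly once, even if streaming raised.
    try:
        for _event in events:
            pass
    finally:
        if on_complete:
            await on_complete()


async def main() -> None:
    stack = contextlib.AsyncExitStack()
    # AsyncExitStack.aclose is a zero-argument coroutine function, so passing
    # it unawaited satisfies Callable[[], Awaitable[None]].
    await publish(iter(["a", "b"]), on_complete=stack.aclose)


asyncio.run(main())
```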
@@ -267,10 +267,8 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
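The old path drove the synchronous generator dependency through `run_in_threadpool` and a plain `ExitStack`; now that `get_llama_proxy` is an async generator, it is wrapped with `contextlib.asynccontextmanager` and entered on an `AsyncExitStack` directly on the event loop. A hedged sketch of that acquisition pattern with a stubbed dependency; the stub body and the `handler` name are assumptions for illustration.

```python
import asyncio
import contextlib
import typing


async def get_llama_proxy() -> typing.AsyncIterator[object]:
    # Stub of the real async generator dependency: acquire the locks, yield
    # the proxy, release on cleanup (elided in this sketch).
    yield object()


async def handler() -> None:
    exit_stack = contextlib.AsyncExitStack()
    # asynccontextmanager turns the async generator function into an async
    # context manager factory; enter_async_context runs it up to its yield
    # and registers the cleanup on the stack.
    llama_proxy = await exit_stack.enter_async_context(
        contextlib.asynccontextmanager(get_llama_proxy)()
    )
    try:
        _ = llama_proxy  # use the proxy to build the completion
    finally:
        # Non-streaming path: cleanup is awaited before returning. For
        # streaming, the handler instead hands exit_stack.aclose to the
        # publisher as on_complete (see the later hunks).
        await exit_stack.aclose()


asyncio.run(handler())
```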
@@ -332,7 +330,6 @@ async def create_completion(
         def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -342,13 +339,13 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion


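One subtlety in these two hunks: the `exit_stack.close()` call inside the synchronous `iterator()` generator is dropped rather than converted, because a plain generator cannot `await exit_stack.aclose()`. Cleanup is instead handed to the async publisher through `on_complete=exit_stack.aclose`, and awaited inline on the non-streaming branch. A small illustrative sketch of that ownership split; the names are hypothetical.

```python
import asyncio
import contextlib
import typing


async def stream_and_clean(
    chunks: typing.List[str], exit_stack: contextlib.AsyncExitStack
) -> None:
    def iterator() -> typing.Iterator[str]:
        # A sync generator has no way to await exit_stack.aclose(), so it
        # only yields data and leaves cleanup to its async caller.
        yield from chunks

    try:
        for _chunk in iterator():
            pass
    finally:
        # The async caller owns the stack, mirroring on_complete=exit_stack.aclose.
        await exit_stack.aclose()


asyncio.run(stream_and_clean(["a", "b"], contextlib.AsyncExitStack()))
```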
@@ -477,10 +474,8 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -530,7 +525,6 @@ async def create_chat_completion(
         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -540,13 +534,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion

