
Commit 9bd0c95

gjpower and abetlen authored Dec 6, 2024
fix: Avoid thread starvation on many concurrent requests by making use of asyncio to lock llama_proxy context (abetlen#1798)
* fix: make use of asyncio to lock llama_proxy context
* fix: use aclose instead of close for AsyncExitStack
* fix: don't call exit stack close in stream iterator as it will be called by finally from on_complete anyway
* fix: use anyio.Lock instead of asyncio.Lock

--------

Co-authored-by: Andrei <[email protected]>
1 parent 073b7e4 commit 9bd0c95
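The essence of the change: the `get_llama_proxy` dependency previously grabbed a blocking `threading.Lock`, so every request waiting for the model pinned a worker thread; with `anyio.Lock` the wait is awaited and the event loop stays free. Below is a minimal standalone sketch of that pattern, not taken from the repo — `shared_resource`, `get_resource`, and `handle_request` are made-up names for illustration only.

```python
import contextlib

import anyio

# Hypothetical stand-in for the llama_proxy object guarded by the real code.
shared_resource = {"model": "stub"}

# anyio.Lock is async-aware: a waiter suspends its task instead of blocking a thread.
resource_lock = anyio.Lock()


async def get_resource():
    # Async generator dependency, analogous in shape to the reworked get_llama_proxy().
    await resource_lock.acquire()  # yields control to the event loop while waiting
    try:
        yield shared_resource
    finally:
        resource_lock.release()


async def handle_request(i: int) -> None:
    # Enter the dependency through an AsyncExitStack, mirroring the commit's approach.
    async with contextlib.AsyncExitStack() as stack:
        resource = await stack.enter_async_context(
            contextlib.asynccontextmanager(get_resource)()
        )
        print(f"request {i} got {resource['model']}")


async def main() -> None:
    # Many concurrent "requests" simply queue on the lock; no thread pool is starved.
    async with anyio.create_task_group() as tg:
        for i in range(20):
            tg.start_soon(handle_request, i)


if __name__ == "__main__":
    anyio.run(main)
```

The same `contextlib.asynccontextmanager(...)` plus `AsyncExitStack` combination is what the handlers in the diff below use, so that cleanup can be deferred until a streaming response finishes.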

File tree

llama_cpp/server/app.py

1 file changed: +14 -20 lines changed
 

llama_cpp/server/app.py  (+14 -20)
@@ -5,7 +5,7 @@
 import typing
 import contextlib

-from threading import Lock
+from anyio import Lock
 from functools import partial
 from typing import Iterator, List, Optional, Union, Dict

@@ -70,14 +70,14 @@ def set_llama_proxy(model_settings: List[ModelSettings]):
     _llama_proxy = LlamaProxy(models=model_settings)


-def get_llama_proxy():
+async def get_llama_proxy():
     # NOTE: This double lock allows the currently streaming llama model to
     # check if any other requests are pending in the same thread and cancel
     # the stream if so.
-    llama_outer_lock.acquire()
+    await llama_outer_lock.acquire()
     release_outer_lock = True
     try:
-        llama_inner_lock.acquire()
+        await llama_inner_lock.acquire()
         try:
             llama_outer_lock.release()
             release_outer_lock = False
@@ -159,7 +159,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
     iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], None]] = None,
+    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
@@ -182,7 +182,7 @@ async def get_event_publisher(
                 raise e
         finally:
             if on_complete:
-                on_complete()
+                await on_complete()


 def _logit_bias_tokens_to_input_ids(
@@ -267,10 +267,8 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -332,7 +330,6 @@ async def create_completion(
         def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -342,13 +339,13 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion


@@ -477,10 +474,8 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -530,7 +525,6 @@ async def create_chat_completion(
         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -540,13 +534,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion


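One detail worth calling out from the hunks above: when the response streams, the handler can no longer close the `AsyncExitStack` itself, so it hands the bound `exit_stack.aclose` coroutine function to the publisher as `on_complete`, and the publisher awaits it in its `finally` block once the stream ends (which is also why the explicit close inside `iterator()` was dropped). A rough, self-contained illustration of that hand-off follows; `publish_events` and `streaming_handler` are invented for the example and simplify away SSE and memory channels.

```python
import contextlib
import typing

import anyio


async def publish_events(
    iterator: typing.Iterator[str],
    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
) -> None:
    # Simplified stand-in for get_event_publisher(): drain the iterator, then await
    # the async cleanup callback exactly once, even if streaming raises.
    try:
        for chunk in iterator:
            print("sending:", chunk)
    finally:
        if on_complete:
            await on_complete()


async def streaming_handler() -> None:
    exit_stack = contextlib.AsyncExitStack()
    exit_stack.callback(lambda: print("resources released"))  # placeholder cleanup

    def iterator() -> typing.Iterator[str]:
        # Note: no exit_stack close here; the publisher's finally block handles it.
        yield "first chunk"
        yield "second chunk"

    # Pass the coroutine function itself (exit_stack.aclose), not the result of calling it.
    await publish_events(iterator(), on_complete=exit_stack.aclose)


if __name__ == "__main__":
    anyio.run(streaming_handler)
```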