
Commit e8f14ce

gjpower and abetlen authored Jan 8, 2025
fix: streaming resource lock (#1879)
* fix: correct issue with handling lock during streaming

  move locking for streaming into get_event_publisher call so it is locked and unlocked in the correct task for the streaming response

* fix: simplify exit stack management for create_chat_completion and create_completion

* fix: correct missing `async with` and format code

* fix: remove unnecessary explicit use of AsyncExitStack

  fix: correct type hints for body_model

---------

Co-authored-by: Andrei <[email protected]>
1 parent 1d5f534 commit e8f14ce
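
For context: an anyio lock may only be released by the task that acquired it, and the old streaming path acquired the llama proxy in the request handler's task but released it from the response's sender task via on_complete=exit_stack.aclose. Below is a minimal sketch (hypothetical names, not code from this repository) of the pattern this commit moves to, where the streaming task itself opens and closes the locked resource:

    # Minimal sketch (hypothetical names) of the pattern the commit adopts:
    # the resource guarded by an anyio.Lock is acquired *and* released inside
    # the task that streams the response.
    import anyio


    class Resource:
        """Stand-in for the llama proxy that is guarded by an anyio.Lock."""

        def __init__(self) -> None:
            self._lock = anyio.Lock()

        async def __aenter__(self) -> "Resource":
            await self._lock.acquire()
            return self

        async def __aexit__(self, *exc) -> None:
            # anyio raises RuntimeError if a lock is released from a task other
            # than the one that acquired it, which is why acquire and release
            # must both happen in the streaming task.
            self._lock.release()


    async def event_publisher(resource: Resource, send_chan) -> None:
        # Acquire and release in this task, mirroring how get_event_publisher
        # now enters get_llama_proxy itself instead of receiving an exit stack
        # created in the request handler's task.
        async with resource, send_chan:
            for chunk in ("one", "two", "three"):
                await send_chan.send(chunk)


    async def main() -> None:
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        async with anyio.create_task_group() as tg:
            tg.start_soon(event_publisher, Resource(), send_chan)
            async with recv_chan:
                async for item in recv_chan:
                    print(item)


    anyio.run(main)
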

File tree: 1 file changed, +103 −121 lines

 

llama_cpp/server/app.py

+103 −121 lines changed
@@ -7,7 +7,7 @@
 
 from anyio import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import List, Optional, Union, Dict
 
 import llama_cpp
 
@@ -155,34 +155,71 @@ def create_app(
     return app
 
 
+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str | None,
+    kwargs,
+) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
+
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str | None,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
-    async with inner_send_chan:
-        try:
-            async for chunk in iterate_in_threadpool(iterator):
-                await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                if await request.is_disconnected():
-                    raise anyio.get_cancelled_exc_class()()
-                if interrupt_requests and llama_outer_lock.locked():
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                    raise anyio.get_cancelled_exc_class()()
-            await inner_send_chan.send(dict(data="[DONE]"))
-        except anyio.get_cancelled_exc_class() as e:
-            print("disconnected")
-            with anyio.move_on_after(1, shield=True):
-                print(f"Disconnected from client (via refresh/close) {request.client}")
-            raise e
-        finally:
-            if on_complete:
-                await on_complete()
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+        async with inner_send_chan:
+            try:
+                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
+                async for chunk in iterate_in_threadpool(iterator):
+                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                    if await request.is_disconnected():
+                        raise anyio.get_cancelled_exc_class()()
+                    if interrupt_requests and llama_outer_lock.locked():
+                        await inner_send_chan.send(dict(data="[DONE]"))
+                        raise anyio.get_cancelled_exc_class()()
+                await inner_send_chan.send(dict(data="[DONE]"))
+            except anyio.get_cancelled_exc_class() as e:
+                print("disconnected")
+                with anyio.move_on_after(1, shield=True):
+                    print(
+                        f"Disconnected from client (via refresh/close) {request.client}"
+                    )
+                raise e
 
 
 def _logit_bias_tokens_to_input_ids(
@@ -267,18 +304,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +323,38 @@
     }
     kwargs = body.model_dump(exclude=exclude)
 
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama, **kwargs)
 
 
 @router.post(
@@ -474,74 +482,48 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
 
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
            ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
 
 @router.get(
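
For reference, a usage sketch of the path this change affects (the server URL, API key placeholder, and model alias below are assumptions for illustration, not part of this commit): a streaming chat completion is served by get_event_publisher, which now acquires the llama proxy inside the streaming task and releases it when the SSE response finishes.

    # Hedged client-side sketch: assumes a llama-cpp-python server running
    # locally on http://localhost:8000 with its OpenAI-compatible routes, and
    # the `openai` Python client installed. The model alias "default" is a
    # placeholder.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-no-key-required")

    stream = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,  # exercises the streaming branch and get_event_publisher
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)
    print()
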
