Skip to content

Commit 0fd788a

Browse files
committed
Fix codegate version and similar commands.
While refactoring, I removed three lines of code managing short-circuited requests. We short-circuit requests to implement `codegate version`, `codegate workspace`, and similar commands. Given we now have a provider-native representation of the messages, it is necessary to produce the right message for the given request and provider, so some code must be added in provider-specific modules to handle that. Fixes #1362
1 parent b8d7b65 commit 0fd788a

File tree

6 files changed

+65
-1
lines changed

6 files changed

+65
-1
lines changed

src/codegate/providers/anthropic/provider.py

+2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ async def process_request(
7777
client_type: ClientType,
7878
completion_handler: Callable | None = None,
7979
stream_generator: Callable | None = None,
80+
short_circuiter: Callable | None = None,
8081
):
8182
try:
8283
stream = await self.complete(
@@ -86,6 +87,7 @@ async def process_request(
8687
is_fim_request,
8788
client_type,
8889
completion_handler=completion_handler,
90+
short_circuiter=short_circuiter,
8991
)
9092
except Exception as e:
9193
# check if we have an status code there

src/codegate/providers/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ async def complete(
259259
is_fim_request: bool,
260260
client_type: ClientType,
261261
completion_handler: Callable | None = None,
262+
short_circuiter: Callable | None = None,
262263
) -> Union[Any, AsyncIterator[Any]]:
263264
"""
264265
Main completion flow with pipeline integration
@@ -287,6 +288,10 @@ async def complete(
287288
is_fim_request,
288289
)
289290

291+
if input_pipeline_result.response and input_pipeline_result.context:
292+
if short_circuiter: # this if should be removed eventually
293+
return short_circuiter(input_pipeline_result)
294+
290295
provider_request = normalized_request # default value
291296
if input_pipeline_result.request:
292297
provider_request = self._input_normalizer.denormalize(input_pipeline_result.request)

src/codegate/providers/openai/provider.py

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from codegate.types.openai import (
1515
ChatCompletionRequest,
1616
completions_streaming,
17+
short_circuiter,
1718
stream_generator,
1819
)
1920

@@ -72,6 +73,7 @@ async def process_request(
7273
client_type: ClientType,
7374
completion_handler: Callable | None = None,
7475
stream_generator: Callable | None = None,
76+
short_circuiter: Callable | None = None,
7577
):
7678
try:
7779
stream = await self.complete(
@@ -81,6 +83,7 @@ async def process_request(
8183
is_fim_request=is_fim_request,
8284
client_type=client_type,
8385
completion_handler=completion_handler,
86+
short_circuiter=short_circuiter,
8487
)
8588
except Exception as e:
8689
# Check if we have an status code there
@@ -130,4 +133,5 @@ async def create_completion(
130133
self.base_url,
131134
is_fim_request,
132135
request.state.detected_client,
136+
short_circuiter=short_circuiter,
133137
)

src/codegate/types/openai/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._generators import (
33
completions_streaming,
44
message_wrapper,
5+
short_circuiter,
56
single_response_generator,
67
stream_generator,
78
streaming,
@@ -74,6 +75,7 @@
7475
"completions_streaming",
7576
"message_wrapper",
7677
"single_response_generator",
78+
"short_circuiter",
7779
"stream_generator",
7880
"streaming",
7981
"LegacyCompletion",

src/codegate/types/openai/_generators.py

+51
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import time
23
from typing import (
4+
Any,
35
AsyncIterator,
46
)
57

@@ -9,9 +11,16 @@
911
from ._legacy_models import (
1012
LegacyCompletionRequest,
1113
)
14+
from ._request_models import (
15+
ChatCompletionRequest,
16+
)
1217
from ._response_models import (
1318
ChatCompletion,
19+
Choice,
20+
ChoiceDelta,
1421
ErrorDetails,
22+
Message,
23+
MessageDelta,
1524
MessageError,
1625
StreamingChatCompletion,
1726
VllmMessageError,
@@ -20,6 +29,48 @@
2029
logger = structlog.get_logger("codegate")
2130

2231

32+
async def short_circuiter(pipeline_result) -> AsyncIterator[Any]:
33+
# NOTE: This routine MUST be called only when we short-circuit the
34+
# request.
35+
assert pipeline_result.context.shortcut_response # nosec
36+
37+
match pipeline_result.context.input_request.request:
38+
case ChatCompletionRequest(stream=True):
39+
yield StreamingChatCompletion(
40+
id="codegate",
41+
model=pipeline_result.response.model,
42+
created=int(time.time()),
43+
choices=[
44+
ChoiceDelta(
45+
finish_reason="stop",
46+
index=0,
47+
delta=MessageDelta(
48+
content=pipeline_result.response.content,
49+
),
50+
),
51+
],
52+
)
53+
case ChatCompletionRequest(stream=False):
54+
yield ChatCompletion(
55+
id="codegate",
56+
model=pipeline_result.response.model,
57+
created=int(time.time()),
58+
choices=[
59+
Choice(
60+
finish_reason="stop",
61+
index=0,
62+
message=Message(
63+
content=pipeline_result.response.content,
64+
),
65+
),
66+
],
67+
)
68+
case _:
69+
raise ValueError(
70+
f"invalid input request: {pipeline_result.context.input_request.request}"
71+
)
72+
73+
2374
async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> AsyncIterator[str]:
2475
"""OpenAI-style SSE format"""
2576
try:

src/codegate/updates/client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import os
12
from enum import Enum
23

34
import requests
45
import structlog
5-
import os
66

77
logger = structlog.get_logger("codegate")
88

0 commit comments

Comments
 (0)