6 files changed: +65 −1 lines changed

```diff
@@ -77,6 +77,7 @@ async def process_request(
     client_type: ClientType,
     completion_handler: Callable | None = None,
     stream_generator: Callable | None = None,
+    short_circuiter: Callable | None = None,
 ):
     try:
         stream = await self.complete(
@@ -86,6 +87,7 @@ async def process_request(
             is_fim_request,
             client_type,
             completion_handler=completion_handler,
+            short_circuiter=short_circuiter,
         )
     except Exception as e:
         # check if we have a status code there
```
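The new optional `short_circuiter` callable is threaded from `process_request` into `complete` (see the next file's hunks). A minimal sketch of how a caller might wire it through; everything here other than the keyword arguments named in the diff (`completion_handler`, `stream_generator`, `short_circuiter`) is a hypothetical stand-in, since the full `process_request` signature is not visible in the hunk:

```python
# Hypothetical wiring sketch, not the repo's actual code: a provider
# forwards a short-circuit handler into process_request alongside the
# other optional callables.
async def handle_completion(provider, data, base_url, is_fim_request, client_type):
    return await provider.process_request(
        data,                             # hypothetical: request payload
        base_url,
        is_fim_request,
        client_type,
        completion_handler=None,          # None: keep the provider default
        stream_generator=None,            # None: keep the provider default
        short_circuiter=short_circuiter,  # new: consumes pipeline short-circuits
    )
```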
```diff
@@ -259,6 +259,7 @@ async def complete(
     is_fim_request: bool,
     client_type: ClientType,
     completion_handler: Callable | None = None,
+    short_circuiter: Callable | None = None,
 ) -> Union[Any, AsyncIterator[Any]]:
     """
     Main completion flow with pipeline integration
@@ -287,6 +288,10 @@ async def complete(
         is_fim_request,
     )
 
+    if input_pipeline_result.response and input_pipeline_result.context:
+        if short_circuiter:  # this if should be removed eventually
+            return short_circuiter(input_pipeline_result)
+
     provider_request = normalized_request  # default value
     if input_pipeline_result.request:
         provider_request = self._input_normalizer.denormalize(input_pipeline_result.request)
```
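With this change, `complete` returns early whenever the input pipeline produced both a response and a context: the upstream provider is never called, and the injected `short_circuiter` turns the pipeline result into a client-facing reply. The inline comment hints that the `if short_circuiter` guard is temporary scaffolding until all callers pass one. A self-contained sketch of the guard condition, using a hypothetical stand-in for the real pipeline result type:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class FakePipelineResult:
    # Hypothetical stand-in; mirrors only the fields complete() inspects.
    response: Any = None  # canned answer produced by an input pipeline step
    context: Any = None   # pipeline context (carries shortcut_response)
    request: Any = None   # possibly rewritten request for the provider

def should_short_circuit(result: FakePipelineResult) -> bool:
    # Same condition as the branch added in complete(): both must be set.
    return bool(result.response and result.context)

assert should_short_circuit(FakePipelineResult(response="blocked", context=object()))
assert not should_short_circuit(FakePipelineResult(response="blocked"))
```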
```diff
@@ -14,5 +14,6 @@
 from codegate.types.openai import (
     ChatCompletionRequest,
     completions_streaming,
+    short_circuiter,
     stream_generator,
 )
@@ -72,6 +73,7 @@ async def process_request(
     client_type: ClientType,
     completion_handler: Callable | None = None,
     stream_generator: Callable | None = None,
+    short_circuiter: Callable | None = None,
 ):
     try:
         stream = await self.complete(
@@ -81,6 +83,7 @@ async def process_request(
             is_fim_request=is_fim_request,
             client_type=client_type,
             completion_handler=completion_handler,
+            short_circuiter=short_circuiter,
         )
     except Exception as e:
         # Check if we have a status code there
@@ -130,4 +133,5 @@ async def create_completion(
         self.base_url,
         is_fim_request,
         request.state.detected_client,
+        short_circuiter=short_circuiter,
     )
```
```diff
@@ -2,6 +2,7 @@
 from ._generators import (
     completions_streaming,
     message_wrapper,
+    short_circuiter,
     single_response_generator,
     stream_generator,
     streaming,
@@ -74,6 +75,7 @@
     "completions_streaming",
     "message_wrapper",
     "single_response_generator",
+    "short_circuiter",
     "stream_generator",
     "streaming",
     "LegacyCompletion",
```
```diff
@@ -1,5 +1,7 @@
 import os
+import time
 from typing import (
+    Any,
     AsyncIterator,
 )
 
@@ -9,9 +11,16 @@
 from ._legacy_models import (
     LegacyCompletionRequest,
 )
+from ._request_models import (
+    ChatCompletionRequest,
+)
 from ._response_models import (
     ChatCompletion,
+    Choice,
+    ChoiceDelta,
     ErrorDetails,
+    Message,
+    MessageDelta,
     MessageError,
     StreamingChatCompletion,
     VllmMessageError,
@@ -20,6 +29,48 @@
 logger = structlog.get_logger("codegate")
 
 
+async def short_circuiter(pipeline_result) -> AsyncIterator[Any]:
+    # NOTE: This routine MUST be called only when we short-circuit the
+    # request.
+    assert pipeline_result.context.shortcut_response  # nosec
+
+    match pipeline_result.context.input_request.request:
+        case ChatCompletionRequest(stream=True):
+            yield StreamingChatCompletion(
+                id="codegate",
+                model=pipeline_result.response.model,
+                created=int(time.time()),
+                choices=[
+                    ChoiceDelta(
+                        finish_reason="stop",
+                        index=0,
+                        delta=MessageDelta(
+                            content=pipeline_result.response.content,
+                        ),
+                    ),
+                ],
+            )
+        case ChatCompletionRequest(stream=False):
+            yield ChatCompletion(
+                id="codegate",
+                model=pipeline_result.response.model,
+                created=int(time.time()),
+                choices=[
+                    Choice(
+                        finish_reason="stop",
+                        index=0,
+                        message=Message(
+                            content=pipeline_result.response.content,
+                        ),
+                    ),
+                ],
+            )
+        case _:
+            raise ValueError(
+                f"invalid input request: {pipeline_result.context.input_request.request}"
+            )
+
+
 async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> AsyncIterator[str]:
     """OpenAI-style SSE format"""
     try:
```
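The `match` statement in `short_circuiter` relies on Python 3.10+ class patterns: `ChatCompletionRequest(stream=True)` matches any instance of that class whose `stream` attribute equals `True`, choosing between a single streaming chunk and a full response body. A self-contained sketch of the same dispatch technique, with a local `Req` dataclass standing in for the real request model:

```python
from dataclasses import dataclass

@dataclass
class Req:
    # Stand-in for ChatCompletionRequest; only the attribute the
    # match statement inspects is modelled here.
    stream: bool

def dispatch(request: object) -> str:
    # Class patterns combine an isinstance() check with attribute
    # constraints, the same shape as the match in short_circuiter().
    match request:
        case Req(stream=True):
            return "yield one StreamingChatCompletion chunk"
        case Req(stream=False):
            return "yield one complete ChatCompletion"
        case _:
            raise ValueError(f"invalid input request: {request}")

assert dispatch(Req(stream=True)) == "yield one StreamingChatCompletion chunk"
assert dispatch(Req(stream=False)) == "yield one complete ChatCompletion"
```

Either way the generator yields exactly one object, which `stream_generator` (shown as context above) can then serialize as SSE for streaming clients.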
```diff
@@ -1,8 +1,8 @@
+import os
 from enum import Enum
 
 import requests
 import structlog
-import os
 
 logger = structlog.get_logger("codegate")
 
```