Skip to content

Commit a79cd22

Browse files
authored
Merge pull request #366 from majorcheng/main
fix: add estimated openai usage stats
2 parents 7922ac6 + 927b927 commit a79cd22

File tree

4 files changed

+420
-28
lines changed

4 files changed

+420
-28
lines changed

app/services/grok/services/chat.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
parse_tool_call_block,
3333
format_tool_history,
3434
)
35+
from app.services.grok.utils.usage import estimate_chat_usage, estimate_prompt_tokens
3536
from app.services.token import get_token_manager, EffortType
3637

3738

@@ -373,7 +374,8 @@ async def chat_openai(
373374
model_config_override=model_config_override,
374375
)
375376

376-
return response, stream, model
377+
prompt_tokens = estimate_prompt_tokens(message)
378+
return response, stream, model, prompt_tokens
377379

378380

379381
class ChatService:
@@ -426,7 +428,7 @@ async def completions(
426428
try:
427429
# 请求 Grok
428430
service = GrokChatService()
429-
response, _, model_name = await service.chat_openai(
431+
response, _, model_name, prompt_tokens = await service.chat_openai(
430432
token,
431433
model,
432434
messages,
@@ -442,14 +444,27 @@ async def completions(
442444
# 处理响应
443445
if is_stream:
444446
logger.debug(f"Processing stream response: model={model}")
445-
processor = StreamProcessor(model_name, token, show_think, tools=tools, tool_choice=tool_choice)
447+
processor = StreamProcessor(
448+
model_name,
449+
token,
450+
show_think,
451+
tools=tools,
452+
tool_choice=tool_choice,
453+
prompt_tokens=prompt_tokens,
454+
)
446455
return wrap_stream_with_usage(
447456
processor.process(response), token_mgr, token, model
448457
)
449458

450459
# 非流式
451460
logger.debug(f"Processing non-stream response: model={model}")
452-
result = await CollectProcessor(model_name, token, tools=tools, tool_choice=tool_choice).process(response)
461+
result = await CollectProcessor(
462+
model_name,
463+
token,
464+
tools=tools,
465+
tool_choice=tool_choice,
466+
prompt_tokens=prompt_tokens,
467+
).process(response)
453468
try:
454469
model_info = ModelService.get(model)
455470
effort = (
@@ -506,7 +521,15 @@ async def completions(
506521
class StreamProcessor(proc_base.BaseProcessor):
507522
"""Stream response processor."""
508523

509-
def __init__(self, model: str, token: str = "", show_think: bool = None, tools: List[Dict[str, Any]] = None, tool_choice: Any = None):
524+
def __init__(
525+
self,
526+
model: str,
527+
token: str = "",
528+
show_think: bool = None,
529+
tools: List[Dict[str, Any]] = None,
530+
tool_choice: Any = None,
531+
prompt_tokens: int = 0,
532+
):
510533
super().__init__(model, token)
511534
self.response_id: str = None
512535
self.fingerprint: str = ""
@@ -531,6 +554,17 @@ def __init__(self, model: str, token: str = "", show_think: bool = None, tools:
531554
self._tool_partial = ""
532555
self._tool_calls_seen = False
533556
self._tool_call_index = 0
557+
self.prompt_tokens = max(0, int(prompt_tokens or 0))
558+
self._completion_parts: list[str] = []
559+
self._completion_tool_calls: list[dict[str, Any]] = []
560+
561+
def _record_content(self, content: str) -> None:
562+
if content:
563+
self._completion_parts.append(content)
564+
565+
def _record_tool_call(self, tool_call: Any) -> None:
566+
if isinstance(tool_call, dict):
567+
self._completion_tool_calls.append(tool_call)
534568

535569
def _with_tool_index(self, tool_call: Any) -> Any:
536570
if not isinstance(tool_call, dict):
@@ -691,7 +725,14 @@ def _flush_tool_stream(self) -> list[tuple[str, Any]]:
691725
self._tool_state = "text"
692726
return events
693727

694-
def _sse(self, content: str = "", role: str = None, finish: str = None, tool_calls: list = None) -> str:
728+
def _sse(
729+
self,
730+
content: str = "",
731+
role: str = None,
732+
finish: str = None,
733+
tool_calls: list = None,
734+
usage: dict | None = None,
735+
) -> str:
695736
"""Build SSE response."""
696737
delta = {}
697738
if role:
@@ -712,6 +753,8 @@ def _sse(self, content: str = "", role: str = None, finish: str = None, tool_cal
712753
{"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}
713754
],
714755
}
756+
if usage is not None:
757+
chunk["usage"] = usage
715758
return f"data: {orjson.dumps(chunk).decode()}\n\n"
716759

717760
async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
@@ -780,6 +823,7 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N
780823
rendered = await dl_service.render_image(
781824
url, self.token, img_id
782825
)
826+
self._record_content(f"{rendered}\n")
783827
yield self._sse(f"{rendered}\n")
784828

785829
if (
@@ -804,8 +848,10 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N
804848
if original:
805849
title_safe = title.replace("\n", " ").strip()
806850
if title_safe:
851+
self._record_content(f"![{title_safe}]({original})\n")
807852
yield self._sse(f"![{title_safe}]({original})\n")
808853
else:
854+
self._record_content(f"![image]({original})\n")
809855
yield self._sse(f"![image]({original})\n")
810856
continue
811857

@@ -834,17 +880,21 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N
834880
self.think_closed_once = True
835881

836882
if in_think:
883+
self._record_content(filtered)
837884
yield self._sse(filtered)
838885
continue
839886

840887
if self._tool_stream_enabled:
841888
for kind, payload in self._handle_tool_stream(filtered):
842889
if kind == "text":
890+
self._record_content(payload)
843891
yield self._sse(payload)
844892
elif kind == "tool":
893+
self._record_tool_call(payload)
845894
yield self._sse(tool_calls=[payload])
846895
continue
847896

897+
self._record_content(filtered)
848898
yield self._sse(filtered)
849899

850900
if self.think_opened:
@@ -854,13 +904,29 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N
854904
if self._tool_stream_enabled:
855905
for kind, payload in self._flush_tool_stream():
856906
if kind == "text":
907+
self._record_content(payload)
857908
yield self._sse(payload)
858909
elif kind == "tool":
910+
self._record_tool_call(payload)
859911
yield self._sse(tool_calls=[payload])
860912
finish_reason = "tool_calls" if self._tool_calls_seen else "stop"
861-
yield self._sse(finish=finish_reason)
913+
yield self._sse(
914+
finish=finish_reason,
915+
usage=estimate_chat_usage(
916+
prompt_tokens=self.prompt_tokens,
917+
content="".join(self._completion_parts),
918+
tool_calls=self._completion_tool_calls or None,
919+
),
920+
)
862921
else:
863-
yield self._sse(finish="stop")
922+
yield self._sse(
923+
finish="stop",
924+
usage=estimate_chat_usage(
925+
prompt_tokens=self.prompt_tokens,
926+
content="".join(self._completion_parts),
927+
tool_calls=self._completion_tool_calls or None,
928+
),
929+
)
864930

865931
yield "data: [DONE]\n\n"
866932
except asyncio.CancelledError:
@@ -902,11 +968,19 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N
902968
class CollectProcessor(proc_base.BaseProcessor):
903969
"""Non-stream response processor."""
904970

905-
def __init__(self, model: str, token: str = "", tools: List[Dict[str, Any]] = None, tool_choice: Any = None):
971+
def __init__(
972+
self,
973+
model: str,
974+
token: str = "",
975+
tools: List[Dict[str, Any]] = None,
976+
tool_choice: Any = None,
977+
prompt_tokens: int = 0,
978+
):
906979
super().__init__(model, token)
907980
self.filter_tags = get_config("app.filter_tags")
908981
self.tools = tools
909982
self.tool_choice = tool_choice
983+
self.prompt_tokens = max(0, int(prompt_tokens or 0))
910984

911985
def _filter_content(self, content: str) -> str:
912986
"""Filter special tags in content."""
@@ -1098,22 +1172,11 @@ def _render_card(match: re.Match) -> str:
10981172
"finish_reason": finish_reason,
10991173
}
11001174
],
1101-
"usage": {
1102-
"prompt_tokens": 0,
1103-
"completion_tokens": 0,
1104-
"total_tokens": 0,
1105-
"prompt_tokens_details": {
1106-
"cached_tokens": 0,
1107-
"text_tokens": 0,
1108-
"audio_tokens": 0,
1109-
"image_tokens": 0,
1110-
},
1111-
"completion_tokens_details": {
1112-
"text_tokens": 0,
1113-
"audio_tokens": 0,
1114-
"reasoning_tokens": 0,
1115-
},
1116-
},
1175+
"usage": estimate_chat_usage(
1176+
prompt_tokens=self.prompt_tokens,
1177+
content=content,
1178+
tool_calls=tool_calls_result,
1179+
),
11171180
}
11181181

11191182

app/services/grok/services/responses.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import orjson
1010

1111
from app.services.grok.services.chat import ChatService
12+
from app.services.grok.utils.usage import to_responses_usage
1213
from app.services.grok.utils import process as proc_base
1314

1415

@@ -725,8 +726,7 @@ async def create(
725726
model=model,
726727
output_text=content,
727728
tool_calls=tool_calls,
728-
usage=result.get("usage")
729-
or {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
729+
usage=to_responses_usage(result.get("usage")),
730730
status="completed",
731731
instructions=instructions,
732732
max_output_tokens=max_output_tokens,
@@ -768,6 +768,7 @@ async def create(
768768
)
769769

770770
async def _stream() -> AsyncGenerator[str, None]:
771+
final_usage: Optional[Dict[str, Any]] = None
771772
yield adapter.created_event()
772773
yield adapter.in_progress_event()
773774
async for chunk in result:
@@ -780,6 +781,8 @@ async def _stream() -> AsyncGenerator[str, None]:
780781
continue
781782

782783
if data.get("object") == "chat.completion.chunk":
784+
if data.get("usage"):
785+
final_usage = to_responses_usage(data.get("usage"))
783786
delta = (data.get("choices") or [{}])[0].get("delta") or {}
784787
if "content" in delta and delta["content"]:
785788
for event in adapter.ensure_message_started():
@@ -815,7 +818,7 @@ async def _stream() -> AsyncGenerator[str, None]:
815818
yield event
816819
for event in adapter.tool_arguments_done_events():
817820
yield event
818-
yield adapter.completed_event()
821+
yield adapter.completed_event(final_usage)
819822

820823
return _stream()
821824

0 commit comments

Comments (0)