From 4360393dc10b2f1cac279aa2a00cc14d3825da4d Mon Sep 17 00:00:00 2001
From: Seefs
Date: Tue, 17 Feb 2026 15:45:14 +0800
Subject: [PATCH] fix: unify usage mapping and include toolUsePromptTokenCount
 in input tokens

---
 dto/gemini.go                                 |  14 +-
 relay/channel/gemini/relay-gemini-native.go   |  17 +-
 relay/channel/gemini/relay-gemini.go          |  93 +++--
 .../channel/gemini/relay_gemini_usage_test.go | 333 ++++++++++++++++++
 4 files changed, 386 insertions(+), 71 deletions(-)
 create mode 100644 relay/channel/gemini/relay_gemini_usage_test.go

diff --git a/dto/gemini.go b/dto/gemini.go
index 0fd74c639d..c963960e5d 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -453,12 +453,14 @@ type GeminiChatResponse struct {
 }
 
 type GeminiUsageMetadata struct {
-	PromptTokenCount        int                         `json:"promptTokenCount"`
-	CandidatesTokenCount    int                         `json:"candidatesTokenCount"`
-	TotalTokenCount         int                         `json:"totalTokenCount"`
-	ThoughtsTokenCount      int                         `json:"thoughtsTokenCount"`
-	CachedContentTokenCount int                         `json:"cachedContentTokenCount"`
-	PromptTokensDetails     []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+	PromptTokenCount           int                         `json:"promptTokenCount"`
+	ToolUsePromptTokenCount    int                         `json:"toolUsePromptTokenCount"`
+	CandidatesTokenCount       int                         `json:"candidatesTokenCount"`
+	TotalTokenCount            int                         `json:"totalTokenCount"`
+	ThoughtsTokenCount         int                         `json:"thoughtsTokenCount"`
+	CachedContentTokenCount    int                         `json:"cachedContentTokenCount"`
+	PromptTokensDetails        []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+	ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"`
 }
 
 type GeminiPromptTokensDetails struct {
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 39485b16f1..1a434a4327 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 	}
 
 	// 计算使用量(基于 UsageMetadata)
-	usage := dto.Usage{
-		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
-		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
-		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
-	}
-
-	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-	usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-
-	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-		if detail.Modality == "AUDIO" {
-			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-		} else if detail.Modality == "TEXT" {
-			usage.PromptTokensDetails.TextTokens = detail.TokenCount
-		}
-	}
+	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 	service.IOCopyBytesGracefully(c, resp, responseBody)
 
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index b10ec06c7b..b81a148a3d 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
 	}
 }
 
+func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
+	promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
+	if promptTokens <= 0 && fallbackPromptTokens > 0 {
+		promptTokens = fallbackPromptTokens
+	}
+
+	usage := dto.Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
+		TotalTokens:      metadata.TotalTokenCount,
+	}
+	usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
+	usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
+
+	for _, detail := range metadata.PromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens += detail.TokenCount
+		}
+	}
+	for _, detail := range metadata.ToolUsePromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens += detail.TokenCount
+		}
+	}
+
+	if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
+		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+	}
+
+	if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
+		usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	}
+
+	return usage
+}
+
 func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      helper.GetResponseID(c),
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 
 		// 更新使用量统计
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
-			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
-			usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-				if detail.Modality == "AUDIO" {
-					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-				} else if detail.Modality == "TEXT" {
-					usage.PromptTokensDetails.TextTokens = detail.TokenCount
-				}
-			}
+			mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
+			*usage = mappedUsage
 		}
 
 		return callback(data, &geminiResponse)
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 		}
 	}
 
-	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
-	if usage.TotalTokens > 0 {
-		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-	}
-
 	if usage.CompletionTokens <= 0 {
 		if info.ReceivedResponseCount > 0 {
 			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		usage := dto.Usage{
-			PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
-		}
-		usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-		usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-		for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-			if detail.Modality == "AUDIO" {
-				usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-			} else if detail.Modality == "TEXT" {
-				usage.PromptTokensDetails.TextTokens = detail.TokenCount
-			}
-		}
-		if usage.PromptTokens <= 0 {
-			usage.PromptTokens = info.GetEstimatePromptTokens()
-		}
+		usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 		var newAPIError *types.NewAPIError
 		if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
-	usage := dto.Usage{
-		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
-		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
-		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
-	}
-
-	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-	usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
-	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-		if detail.Modality == "AUDIO" {
-			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-		} else if detail.Modality == "TEXT" {
-			usage.PromptTokensDetails.TextTokens = detail.TokenCount
-		}
-	}
+	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 	fullTextResponse.Usage = usage
 
diff --git a/relay/channel/gemini/relay_gemini_usage_test.go b/relay/channel/gemini/relay_gemini_usage_test.go
new file mode 100644
index 0000000000..c8f9f83430
--- /dev/null
+++ b/relay/channel/gemini/relay_gemini_usage_test.go
@@ -0,0 +1,333 @@
+package gemini
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		RelayFormat:     types.RelayFormatGemini,
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiChatHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldStreamingTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 300
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldStreamingTimeout
+	})
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	chunk := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "partial"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	chunkData, err := common.Marshal(chunk)
+	require.NoError(t, err)
+
+	streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(streamBody)),
+	}
+
+	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+		return true
+	})
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		RelayFormat:     types.RelayFormatGemini,
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiChatHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldStreamingTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 300
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldStreamingTimeout
+	})
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	chunk := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "partial"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	chunkData, err := common.Marshal(chunk)
+	require.NoError(t, err)
+
+	streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(streamBody)),
+	}
+
+	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+		return true
+	})
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
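
Note (not part of the patch): a minimal, illustrative Go sketch of the arithmetic that the unified buildUsageFromGeminiMetadata helper performs, using the UsageMetadata fixture values asserted in the tests above. The helper name and token figures come from the diff itself; the standalone program below is only a hand-check of the mapping, not code from the repository.

	package main

	import "fmt"

	func main() {
		// Fixture values from the tests above:
		// promptTokenCount=151, toolUsePromptTokenCount=18329,
		// candidatesTokenCount=1089, thoughtsTokenCount=1120, totalTokenCount=20689.
		promptTokens := 151 + 18329     // tool-use prompt tokens now counted as input => 18480
		completionTokens := 1089 + 1120 // candidates + thoughts (reasoning) => 2209

		// Prints: 18480 2209 20689, matching the require.Equal assertions in the tests.
		fmt.Println(promptTokens, completionTokens, promptTokens+completionTokens)
	}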