Skip to content

[VertexAI] Add support for token-based usage metrics #14406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 11, 2025
1 change: 1 addition & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Note: This feature is in Public Preview, which means that it is not
subject to any SLA or deprecation policy and could change in
backwards-incompatible ways.
- [feature] Added support for modality-based token count. (#14406)

# 11.6.0
- [changed] The token counts from `GenerativeModel.countTokens(...)` now include
Expand Down
3 changes: 3 additions & 0 deletions FirebaseVertexAI/Sources/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public struct CountTokensResponse {
/// > Important: This does not include billable image, video or other non-text input. See
/// [Vertex AI pricing](https://cloud.google.com/vertex-ai/generative-ai/pricing) for details.
public let totalBillableCharacters: Int?

/// The breakdown, by modality, of how many tokens are consumed by the prompt.
public let promptTokensDetails: [ModalityTokenCount]
}

// MARK: - Codable Conformances
Expand Down
14 changes: 14 additions & 0 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public struct GenerateContentResponse: Sendable {

/// The total number of tokens in both the request and response.
public let totalTokenCount: Int

/// The breakdown, by modality, of how many tokens are consumed by the prompt.
public let promptTokensDetails: [ModalityTokenCount]

/// The breakdown, by modality, of how many tokens are consumed by the candidates.
public let candidatesTokensDetails: [ModalityTokenCount]
}

/// A list of candidate response content, ordered from best to worst.
Expand Down Expand Up @@ -299,6 +305,8 @@ extension GenerateContentResponse.UsageMetadata: Decodable {
case promptTokenCount
case candidatesTokenCount
case totalTokenCount
case promptTokensDetails
case candidatesTokensDetails
}

public init(from decoder: any Decoder) throws {
Expand All @@ -307,6 +315,12 @@ extension GenerateContentResponse.UsageMetadata: Decodable {
candidatesTokenCount = try container
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
promptTokensDetails = try container
.decodeIfPresent([ModalityTokenCount].self, forKey: .promptTokensDetails) ??
[ModalityTokenCount]()
candidatesTokensDetails = try container
.decodeIfPresent([ModalityTokenCount].self, forKey: .candidatesTokensDetails) ??
[ModalityTokenCount]()
}
}

Expand Down
61 changes: 61 additions & 0 deletions FirebaseVertexAI/Sources/ModalityTokenCount.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// Represents token counting info for a single modality.
///
/// Each value pairs one ``ContentModality`` with the number of tokens attributed to it.
/// Arrays of this type are used for per-modality breakdowns such as `promptTokensDetails`
/// and `candidatesTokensDetails`.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ModalityTokenCount: Sendable {
/// The modality associated with this token count.
public let modality: ContentModality

/// The number of tokens counted.
public let tokenCount: Int
}

/// Content part modality.
///
/// The known modalities are exposed as static constants (``text``, ``image``, ``video``,
/// ``audio`` and ``document``), each constructed from a case of the internal `Kind` enum.
///
/// NOTE(review): `init(kind:)` is not declared in this type; presumably it is supplied by
/// `DecodableProtoEnum` — confirm against that protocol's definition.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ContentModality: DecodableProtoEnum, Hashable, Sendable {
// The modalities known to this SDK version; each case's raw value is the upper-case string
// spelled out on the case.
enum Kind: String {
case text = "TEXT"
case image = "IMAGE"
case video = "VIDEO"
case audio = "AUDIO"
case document = "DOCUMENT"
}

/// Plain text.
public static let text = ContentModality(kind: .text)

/// Image.
public static let image = ContentModality(kind: .image)

/// Video.
public static let video = ContentModality(kind: .video)

/// Audio.
public static let audio = ContentModality(kind: .audio)

/// Document, e.g. PDF.
public static let document = ContentModality(kind: .document)

/// Returns the raw string representation of the `ContentModality` value.
public let rawValue: String

// Log message code associated with unrecognized content modality values.
// NOTE(review): presumably read by `DecodableProtoEnum` when a decoded raw value matches no
// `Kind` case — confirm against the protocol's decoding implementation.
static let unrecognizedValueMessageCode =
VertexLog.MessageCode.generateContentResponseUnrecognizedContentModality
}

// MARK: - Codable Conformances

// The extension body is empty, so the `Decodable` implementation is compiler-synthesized from
// the stored properties `modality` and `tokenCount`, using those names as the coding keys.
// NOTE(review): synthesis requires `ContentModality` to be `Decodable`; presumably that comes
// from `DecodableProtoEnum` — confirm.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ModalityTokenCount: Decodable {}
1 change: 1 addition & 0 deletions FirebaseVertexAI/Sources/VertexLog.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ enum VertexLog {
case decodedInvalidProtoDateMonth = 3009
case decodedInvalidProtoDateDay = 3010
case decodedInvalidCitationPublicationDate = 3011
case generateContentResponseUnrecognizedContentModality = 3012

// SDK State Errors
case generateContentResponseNoCandidates = 4000
Expand Down
43 changes: 43 additions & 0 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,30 @@
XCTAssertEqual(response.functionCalls, [])
}

/// Verifies that a successful `generateContent` response exposes the per-modality token
/// breakdown for both the prompt and the candidates in its usage metadata.
func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
  MockURLProtocol
    .requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json"
    )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let candidate = try XCTUnwrap(response.candidates.first)
  let finishReason = try XCTUnwrap(candidate.finishReason)
  XCTAssertEqual(finishReason, .stop)
  let usageMetadata = try XCTUnwrap(response.usageMetadata)

  // Compare whole arrays instead of subscripting after a count assertion. A failed
  // `XCTAssertEqual` does not stop the test by default, so indexing into a shorter-than-expected
  // array would trap and abort the test run instead of reporting an ordinary failure.
  let promptDetails = usageMetadata.promptTokensDetails
  XCTAssertEqual(promptDetails.map(\.modality), [.image, .text])
  XCTAssertEqual(promptDetails.map(\.tokenCount), [1806, 76])

  let candidatesDetails = usageMetadata.candidatesTokensDetails
  XCTAssertEqual(candidatesDetails.map(\.modality), [.text])
  XCTAssertEqual(candidatesDetails.map(\.tokenCount), [76])
}

func testGenerateContent_success_citations() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
Expand Down Expand Up @@ -488,6 +512,8 @@
XCTAssertEqual(usageMetadata.promptTokenCount, 6)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
XCTAssertEqual(usageMetadata.totalTokenCount, 13)
XCTAssertEqual(usageMetadata.promptTokensDetails.isEmpty, true)
XCTAssertEqual(usageMetadata.candidatesTokensDetails.isEmpty, true)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
Expand Down Expand Up @@ -1326,6 +1352,23 @@
XCTAssertEqual(response.totalBillableCharacters, 16)
}

/// Verifies that a `countTokens` response carrying detailed token information decodes the
/// totals as well as the per-modality prompt token breakdown.
func testCountTokens_succeeds_detailed() async throws {
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "unary-success-detailed-token-response",
    withExtension: "json"
  )

  let response = try await model.countTokens("Why is the sky blue?")

  XCTAssertEqual(response.totalTokens, 1837)
  XCTAssertEqual(response.totalBillableCharacters, 117)

  // Compare whole arrays instead of subscripting after a count assertion. A failed
  // `XCTAssertEqual` does not stop the test by default, so indexing into a shorter-than-expected
  // array would trap and abort the test run instead of reporting an ordinary failure.
  let promptDetails = response.promptTokensDetails
  XCTAssertEqual(promptDetails.map(\.modality), [.image, .text])
  XCTAssertEqual(promptDetails.map(\.tokenCount), [1806, 31])
}

func testCountTokens_succeeds_allOptions() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "unary-success-total-tokens",
Expand Down Expand Up @@ -1435,7 +1478,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
return { request in

Check warning on line 1481 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

View workflow job for this annotation

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
// This is *not* an HTTPURLResponse
let response = URLResponse(
url: request.url!,
Expand All @@ -1461,7 +1504,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
let bundle = BundleTestUtil.bundle()

Check warning on line 1507 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

View workflow job for this annotation

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let fileURL = try XCTUnwrap(bundle.url(forResource: name, withExtension: ext))
return { request in
let requestURL = try XCTUnwrap(request.url)
Expand Down
Loading