diff --git a/firebase-vertexai/CHANGELOG.md b/firebase-vertexai/CHANGELOG.md index 319676102e5..b519e167307 100644 --- a/firebase-vertexai/CHANGELOG.md +++ b/firebase-vertexai/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased * [feature] added support for `responseSchema` in `GenerationConfig`. +* [changed] Made `FunctionCallPart.args` nullable. # 16.0.0-beta03 * [changed] Breaking Change: changed `Schema.int` to return 32 bit integers instead of 64 bit (long). diff --git a/firebase-vertexai/firebase-vertexai.gradle.kts b/firebase-vertexai/firebase-vertexai.gradle.kts index 01f46b60a5b..fcef66e1d6a 100644 --- a/firebase-vertexai/firebase-vertexai.gradle.kts +++ b/firebase-vertexai/firebase-vertexai.gradle.kts @@ -61,7 +61,7 @@ dependencies { implementation("com.google.firebase:firebase-components:18.0.0") implementation("com.google.firebase:firebase-annotations:16.2.0") implementation("com.google.firebase:firebase-appcheck-interop:17.1.0") - implementation("com.google.ai.client.generativeai:common:0.9.0") + implementation("com.google.ai.client.generativeai:common:0.10.0") implementation(libs.androidx.annotation) implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") implementation("androidx.core:core-ktx:1.12.0") diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index f6f77b42142..7d992962451 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -86,7 +86,7 @@ internal fun Part.toInternal(): com.google.ai.client.generativeai.common.shared. ) ) is com.google.firebase.vertexai.type.FunctionCallPart -> - FunctionCallPart(FunctionCall(name, args.orEmpty())) + FunctionCallPart(FunctionCall(name, args)) is com.google.firebase.vertexai.type.FunctionResponsePart -> FunctionResponsePart(FunctionResponse(name, response.toInternal())) is FileDataPart -> @@ -220,7 +220,7 @@ internal fun com.google.ai.client.generativeai.common.shared.Part.toPublic(): Pa is FunctionCallPart -> com.google.firebase.vertexai.type.FunctionCallPart( functionCall.name, - functionCall.args.orEmpty(), + functionCall.args, ) is FunctionResponsePart -> com.google.firebase.vertexai.type.FunctionResponsePart( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclarations.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclarations.kt index 3d836bad27b..7a57d508f4a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclarations.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclarations.kt @@ -368,7 +368,7 @@ fun defineFunction( ) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) private fun FunctionCallPart.getArgOrThrow(param: Schema): T { - return param.fromString(args[param.name]) + return param.fromString(args?.get(param.name)) ?: throw RuntimeException( "Missing argument for parameter \"${param.name}\" for function \"$name\"" ) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index 60b31b98a7a..61827e159b7 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -49,7 +49,7 @@ class BlobPart(val mimeType: String, val blob: ByteArray) : Part * @param name the name of the function to call * @param args the function parameters and values as a [Map] */ -class FunctionCallPart(val name: String, val args: Map) : Part +class FunctionCallPart(val name: String, val args: Map?) : Part /** * Represents function call output to be returned to the model when it requests a function call. diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt index eeb45bba43e..c4c2d4a8dda 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -335,7 +335,8 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart) - callPart.args["season"] shouldBe null + callPart.args shouldNotBe null + callPart.args?.get("seasons") shouldBe null } } @@ -352,7 +353,19 @@ internal class UnarySnapshotTests { it.parts.first().shouldBeInstanceOf() } - callPart.args["current"] shouldBe "true" + callPart.args?.get("current") shouldBe "true" + } + } + + @Test + fun `function call has no arguments field`() = + goldenUnaryFile("unary-success-function-call-empty-arguments.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val callPart = response.functionCalls.first() + + callPart.name shouldBe "current_time" + callPart.args shouldBe null } } @@ -364,7 +377,7 @@ internal class UnarySnapshotTests { val callPart = response.functionCalls.shouldNotBeEmpty().first() callPart.name shouldBe "current_time" - callPart.args.isEmpty() shouldBe true + callPart.args?.isEmpty() shouldBe true } } @@ -376,8 +389,8 @@ internal class UnarySnapshotTests { val callPart = response.functionCalls.shouldNotBeEmpty().first() callPart.name shouldBe "sum" - callPart.args["x"] shouldBe "4" - callPart.args["y"] shouldBe "5" + callPart.args?.get("x") shouldBe "4" + callPart.args?.get("y") shouldBe "5" } } @@ -391,7 +404,7 @@ internal class UnarySnapshotTests { callList.size shouldBe 3 callList.forEach { it.name shouldBe "sum" - it.args.size shouldBe 2 + it.args?.size shouldBe 2 } } } @@ -405,7 +418,7 @@ internal class UnarySnapshotTests { response.text shouldBe "The sum of [1, 2, 3] is" callList.size shouldBe 2 - callList.forEach { it.args.size shouldBe 2 } + callList.forEach { it.args?.size shouldBe 2 } } }