Skip to content

Commit a76229b

Browse files
authored
.Net: Google Gemini - Adding response schema (Structured Outputs support) (#10135)
### Motivation and Context - Resolves #9501 ### Description Allow schema definition for the LLM response. Similar to `Structured Output` concept from OpenAI.
1 parent e8b31a2 commit a76229b

File tree

6 files changed

+248
-6
lines changed

6 files changed

+248
-6
lines changed

.editorconfig

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ dotnet_diagnostic.IDE0005.severity = warning # Remove unnecessary using directiv
136136
dotnet_diagnostic.IDE0009.severity = warning # Add this or Me qualification
137137
dotnet_diagnostic.IDE0011.severity = warning # Add braces
138138
dotnet_diagnostic.IDE0018.severity = warning # Inline variable declaration
139+
139140
dotnet_diagnostic.IDE0032.severity = warning # Use auto-implemented property
140141
dotnet_diagnostic.IDE0034.severity = warning # Simplify 'default' expression
141142
dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code
@@ -221,20 +222,29 @@ dotnet_diagnostic.RCS1241.severity = none # Implement IComparable when implement
221222
dotnet_diagnostic.IDE0001.severity = none # Simplify name
222223
dotnet_diagnostic.IDE0002.severity = none # Simplify member access
223224
dotnet_diagnostic.IDE0004.severity = none # Remove unnecessary cast
225+
dotnet_diagnostic.IDE0010.severity = none # Populate switch
226+
dotnet_diagnostic.IDE0021.severity = none # Use block body for constructors
227+
dotnet_diagnostic.IDE0022.severity = none # Use block body for methods
228+
dotnet_diagnostic.IDE0024.severity = none # Use block body for operator
224229
dotnet_diagnostic.IDE0035.severity = none # Remove unreachable code
225230
dotnet_diagnostic.IDE0051.severity = none # Remove unused private member
226231
dotnet_diagnostic.IDE0052.severity = none # Remove unread private member
227232
dotnet_diagnostic.IDE0058.severity = none # Remove unused expression value
228233
dotnet_diagnostic.IDE0059.severity = none # Unnecessary assignment of a value
229234
dotnet_diagnostic.IDE0060.severity = none # Remove unused parameter
235+
dotnet_diagnostic.IDE0061.severity = none # Use block body for local function
230236
dotnet_diagnostic.IDE0079.severity = none # Remove unnecessary suppression.
231237
dotnet_diagnostic.IDE0080.severity = none # Remove unnecessary suppression operator.
232238
dotnet_diagnostic.IDE0100.severity = none # Remove unnecessary equality operator
233239
dotnet_diagnostic.IDE0110.severity = none # Remove unnecessary discards
234240
dotnet_diagnostic.IDE0130.severity = none # Namespace does not match folder structure
241+
dotnet_diagnostic.IDE0290.severity = none # Use primary constructor
235242
dotnet_diagnostic.IDE0032.severity = none # Use auto property
236243
dotnet_diagnostic.IDE0160.severity = none # Use block-scoped namespace
237244
dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations
245+
dotnet_diagnostic.IDE0046.severity = suggestion # If statement can be simplified
246+
dotnet_diagnostic.IDE0056.severity = suggestion # Indexing can be simplified
247+
dotnet_diagnostic.IDE0057.severity = suggestion # Substring can be simplified
238248

239249
###############################
240250
# Naming Conventions #

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.IO;
66
using System.Linq;
77
using System.Net.Http;
8+
using System.Text;
89
using System.Text.Json;
910
using System.Threading.Tasks;
1011
using Microsoft.SemanticKernel.ChatCompletion;
@@ -419,13 +420,34 @@ public async Task ItCreatesPostRequestWithSemanticKernelVersionHeaderAsync()
419420
Assert.Equal(expectedVersion, header);
420421
}
421422

423+
[Fact]
424+
public async Task ItCreatesPostRequestWithResponseSchemaPropertyAsync()
425+
{
426+
// Arrange
427+
var client = this.CreateChatCompletionClient();
428+
var chatHistory = CreateSampleChatHistory();
429+
var settings = new GeminiPromptExecutionSettings { ResponseMimeType = "application/json", ResponseSchema = typeof(List<int>) };
430+
431+
// Act
432+
await client.GenerateChatMessageAsync(chatHistory, settings);
433+
434+
// Assert
435+
Assert.NotNull(this._messageHandlerStub.RequestHeaders);
436+
437+
var responseBody = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
438+
439+
Assert.Contains("responseSchema", responseBody, StringComparison.Ordinal);
440+
Assert.Contains("\"responseSchema\":{\"type\":\"array\",\"items\":{\"type\":\"integer\"}}", responseBody, StringComparison.Ordinal);
441+
Assert.Contains("\"responseMimeType\":\"application/json\"", responseBody, StringComparison.Ordinal);
442+
}
443+
422444
[Fact]
423445
public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
424446
{
425447
// Arrange
426448
var bearerTokenGenerator = new BearerTokenGenerator()
427449
{
428-
BearerKeys = new List<string> { "key1", "key2", "key3" }
450+
BearerKeys = ["key1", "key2", "key3"]
429451
};
430452

431453
var responseContent = File.ReadAllText(ChatTestDataFilePath);
@@ -442,7 +464,7 @@ public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
442464
httpClient: httpClient,
443465
modelId: "fake-model",
444466
apiVersion: VertexAIVersion.V1,
445-
bearerTokenProvider: () => bearerTokenGenerator.GetBearerToken(),
467+
bearerTokenProvider: bearerTokenGenerator.GetBearerToken,
446468
location: "fake-location",
447469
projectId: "fake-project-id");
448470

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Collections.Generic;
55
using System.Linq;
6+
using System.Text.Json;
67
using System.Text.Json.Nodes;
78
using Microsoft.SemanticKernel;
89
using Microsoft.SemanticKernel.ChatCompletion;
@@ -25,7 +26,8 @@ public void FromPromptItReturnsWithConfiguration()
2526
MaxTokens = 10,
2627
TopP = 0.9,
2728
AudioTimestamp = true,
28-
ResponseMimeType = "application/json"
29+
ResponseMimeType = "application/json",
30+
ResponseSchema = JsonSerializer.Deserialize<JsonElement>(@"{""schema"":""schema""}")
2931
};
3032

3133
// Act
@@ -37,9 +39,120 @@ public void FromPromptItReturnsWithConfiguration()
3739
Assert.Equal(executionSettings.MaxTokens, request.Configuration.MaxOutputTokens);
3840
Assert.Equal(executionSettings.AudioTimestamp, request.Configuration.AudioTimestamp);
3941
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
42+
Assert.Equal(executionSettings.ResponseSchema, request.Configuration.ResponseSchema);
4043
Assert.Equal(executionSettings.TopP, request.Configuration.TopP);
4144
}
4245

46+
[Fact]
47+
public void JsonElementResponseSchemaFromPromptReturnsAsExpected()
48+
{
49+
// Arrange
50+
var prompt = "prompt-example";
51+
var executionSettings = new GeminiPromptExecutionSettings
52+
{
53+
ResponseMimeType = "application/json",
54+
ResponseSchema = Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int))
55+
};
56+
57+
// Act
58+
var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);
59+
60+
// Assert
61+
Assert.NotNull(request.Configuration);
62+
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
63+
Assert.Equal(executionSettings.ResponseSchema, request.Configuration.ResponseSchema);
64+
}
65+
66+
[Fact]
67+
public void KernelJsonSchemaFromPromptReturnsAsExpected()
68+
{
69+
// Arrange
70+
var prompt = "prompt-example";
71+
var executionSettings = new GeminiPromptExecutionSettings
72+
{
73+
ResponseMimeType = "application/json",
74+
ResponseSchema = KernelJsonSchemaBuilder.Build(typeof(int))
75+
};
76+
77+
// Act
78+
var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);
79+
80+
// Assert
81+
Assert.NotNull(request.Configuration);
82+
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
83+
Assert.Equal(((KernelJsonSchema)executionSettings.ResponseSchema).RootElement, request.Configuration.ResponseSchema);
84+
}
85+
86+
[Fact]
87+
public void JsonNodeResponseSchemaFromPromptReturnsAsExpected()
88+
{
89+
// Arrange
90+
var prompt = "prompt-example";
91+
var executionSettings = new GeminiPromptExecutionSettings
92+
{
93+
ResponseMimeType = "application/json",
94+
ResponseSchema = JsonNode.Parse(Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)).GetRawText())
95+
};
96+
97+
// Act
98+
var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);
99+
100+
// Assert
101+
Assert.NotNull(request.Configuration);
102+
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
103+
Assert.NotNull(request.Configuration.ResponseSchema);
104+
Assert.Equal(JsonSerializer.SerializeToElement(executionSettings.ResponseSchema).GetRawText(), request.Configuration.ResponseSchema.Value.GetRawText());
105+
}
106+
107+
[Fact]
108+
public void JsonDocumentResponseSchemaFromPromptReturnsAsExpected()
109+
{
110+
// Arrange
111+
var prompt = "prompt-example";
112+
var executionSettings = new GeminiPromptExecutionSettings
113+
{
114+
ResponseMimeType = "application/json",
115+
ResponseSchema = JsonDocument.Parse(Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)).GetRawText())
116+
};
117+
118+
// Act
119+
var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);
120+
121+
// Assert
122+
Assert.NotNull(request.Configuration);
123+
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
124+
Assert.NotNull(request.Configuration.ResponseSchema);
125+
Assert.Equal(JsonSerializer.SerializeToElement(executionSettings.ResponseSchema).GetRawText(), request.Configuration.ResponseSchema.Value.GetRawText());
126+
}
127+
128+
[Theory]
129+
[InlineData(typeof(int), "integer")]
130+
[InlineData(typeof(bool), "boolean")]
131+
[InlineData(typeof(string), "string")]
132+
[InlineData(typeof(double), "number")]
133+
[InlineData(typeof(GeminiRequest), "object")]
134+
[InlineData(typeof(List<int>), "array")]
135+
public void TypeResponseSchemaFromPromptReturnsAsExpected(Type type, string expectedSchemaType)
136+
{
137+
// Arrange
138+
var prompt = "prompt-example";
139+
var executionSettings = new GeminiPromptExecutionSettings
140+
{
141+
ResponseMimeType = "application/json",
142+
ResponseSchema = type
143+
};
144+
145+
// Act
146+
var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);
147+
148+
// Assert
149+
Assert.NotNull(request.Configuration);
150+
var schemaType = request.Configuration.ResponseSchema?.GetProperty("type").GetString();
151+
152+
Assert.Equal(expectedSchemaType, schemaType);
153+
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
154+
}
155+
43156
[Fact]
44157
public void FromPromptItReturnsWithSafetySettings()
45158
{

dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public void ItCreatesGeminiExecutionSettingsWithCorrectDefaults()
2828
Assert.Null(executionSettings.SafetySettings);
2929
Assert.Null(executionSettings.AudioTimestamp);
3030
Assert.Null(executionSettings.ResponseMimeType);
31+
Assert.Null(executionSettings.ResponseSchema);
3132
Assert.Equal(GeminiPromptExecutionSettings.DefaultTextMaxTokens, executionSettings.MaxTokens);
3233
}
3334

@@ -70,7 +71,8 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase()
7071
{ "max_tokens", 1000 },
7172
{ "temperature", 0 },
7273
{ "audio_timestamp", true },
73-
{ "response_mimetype", "application/json" }
74+
{ "response_mimetype", "application/json" },
75+
{ "response_schema", JsonSerializer.Serialize(new { }) }
7476
}
7577
};
7678

@@ -81,6 +83,9 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase()
8183
Assert.NotNull(executionSettings);
8284
Assert.Equal(1000, executionSettings.MaxTokens);
8385
Assert.Equal(0, executionSettings.Temperature);
86+
Assert.Equal("application/json", executionSettings.ResponseMimeType);
87+
Assert.NotNull(executionSettings.ResponseSchema);
88+
Assert.Equal(typeof(JsonElement), executionSettings.ResponseSchema.GetType());
8489
Assert.True(executionSettings.AudioTimestamp);
8590
}
8691

dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,25 @@
44
using System.Collections.Generic;
55
using System.Linq;
66
using System.Text.Json;
7+
using System.Text.Json.Nodes;
78
using System.Text.Json.Serialization;
9+
using System.Text.Json.Serialization.Metadata;
10+
using Microsoft.Extensions.AI;
811
using Microsoft.SemanticKernel.ChatCompletion;
912

1013
namespace Microsoft.SemanticKernel.Connectors.Google.Core;
1114

1215
internal sealed class GeminiRequest
1316
{
17+
private static JsonSerializerOptions? s_options;
18+
private static readonly AIJsonSchemaCreateOptions s_schemaOptions = new()
19+
{
20+
IncludeSchemaKeyword = false,
21+
IncludeTypeInEnumSchemas = true,
22+
RequireAllProperties = false,
23+
DisallowAdditionalProperties = false,
24+
};
25+
1426
[JsonPropertyName("contents")]
1527
public IList<GeminiContent> Contents { get; set; } = null!;
1628

@@ -249,10 +261,57 @@ private static void AddConfiguration(GeminiPromptExecutionSettings executionSett
249261
StopSequences = executionSettings.StopSequences,
250262
CandidateCount = executionSettings.CandidateCount,
251263
AudioTimestamp = executionSettings.AudioTimestamp,
252-
ResponseMimeType = executionSettings.ResponseMimeType
264+
ResponseMimeType = executionSettings.ResponseMimeType,
265+
ResponseSchema = GetResponseSchemaConfig(executionSettings.ResponseSchema)
253266
};
254267
}
255268

269+
private static JsonElement? GetResponseSchemaConfig(object? responseSchemaSettings)
270+
{
271+
if (responseSchemaSettings is null)
272+
{
273+
return null;
274+
}
275+
276+
var jsonElement = responseSchemaSettings switch
277+
{
278+
JsonElement element => element,
279+
Type type => CreateSchema(type, GetDefaultOptions()),
280+
KernelJsonSchema kernelJsonSchema => kernelJsonSchema.RootElement,
281+
JsonNode jsonNode => JsonSerializer.SerializeToElement(jsonNode, GetDefaultOptions()),
282+
JsonDocument jsonDocument => JsonSerializer.SerializeToElement(jsonDocument, GetDefaultOptions()),
283+
_ => CreateSchema(responseSchemaSettings.GetType(), GetDefaultOptions())
284+
};
285+
286+
return jsonElement;
287+
}
288+
289+
private static JsonElement CreateSchema(
290+
Type type,
291+
JsonSerializerOptions options,
292+
string? description = null,
293+
AIJsonSchemaCreateOptions? configuration = null)
294+
{
295+
configuration ??= s_schemaOptions;
296+
return AIJsonUtilities.CreateJsonSchema(type, description, serializerOptions: options, inferenceOptions: configuration);
297+
}
298+
299+
private static JsonSerializerOptions GetDefaultOptions()
300+
{
301+
if (s_options is null)
302+
{
303+
JsonSerializerOptions options = new()
304+
{
305+
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
306+
Converters = { new JsonStringEnumConverter() },
307+
};
308+
options.MakeReadOnly();
309+
s_options = options;
310+
}
311+
312+
return s_options;
313+
}
314+
256315
private static void AddSafetySettings(GeminiPromptExecutionSettings executionSettings, GeminiRequest request)
257316
{
258317
request.SafetySettings = executionSettings.SafetySettings?.Select(s
@@ -292,5 +351,9 @@ internal sealed class ConfigurationElement
292351
[JsonPropertyName("responseMimeType")]
293352
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
294353
public string? ResponseMimeType { get; set; }
354+
355+
[JsonPropertyName("responseSchema")]
356+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
357+
public JsonElement? ResponseSchema { get; set; }
295358
}
296359
}

0 commit comments

Comments
 (0)