Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: microsoft/semantic-kernel
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 20f5270b21fb66826696c731c0aab112abd959b9
Choose a base ref
..
head repository: microsoft/semantic-kernel
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 2afa19f6d0feff177887e37aa4460865d7c3eaec
Choose a head ref
4 changes: 2 additions & 2 deletions dotnet/nuget/nuget-package.props
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
<Project>
<PropertyGroup>
<!-- Central version prefix - applies to all nuget packages. -->
<VersionPrefix>1.33.0</VersionPrefix>
<VersionPrefix>1.34.0</VersionPrefix>
<PackageVersion Condition="'$(VersionSuffix)' != ''">$(VersionPrefix)-$(VersionSuffix)</PackageVersion>
<PackageVersion Condition="'$(VersionSuffix)' == ''">$(VersionPrefix)</PackageVersion>

<Configurations>Debug;Release;Publish</Configurations>
<IsPackable>true</IsPackable>

<!-- Package validation. Baseline Version should be the latest version available on NuGet. -->
<PackageValidationBaselineVersion>1.29.0</PackageValidationBaselineVersion>
<PackageValidationBaselineVersion>1.33.0</PackageValidationBaselineVersion>
<!-- Validate assembly attributes only for Publish builds -->
<NoWarn Condition="'$(Configuration)' != 'Publish'">$(NoWarn);CP0003</NoWarn>
<!-- Do not validate reference assemblies -->
2 changes: 2 additions & 0 deletions dotnet/samples/Demos/OnnxSimpleRAG/Program.cs
Original file line number Diff line number Diff line change
@@ -14,6 +14,8 @@
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.PromptTemplates.Handlebars;

Console.OutputEncoding = System.Text.Encoding.UTF8;

// Ensure you follow the preparation steps provided in the README.md
var config = new ConfigurationBuilder().AddUserSecrets<Program>().Build();

Original file line number Diff line number Diff line change
@@ -108,8 +108,10 @@ private async IAsyncEnumerable<string> RunInferenceAsync(ChatHistory chatHistory
generator.GenerateNextToken();

var outputTokens = generator.GetSequence(0);
var newToken = outputTokens.Slice(outputTokens.Length - 1, 1);
string output = this.GetTokenizer().Decode(newToken);
var newToken = outputTokens[outputTokens.Length - 1];

using var tokenizerStream = this.GetTokenizer().CreateStream();
string output = tokenizerStream.Decode(newToken);

if (removeNextTokenStartingWithSpace && output[0] == ' ')
{
Original file line number Diff line number Diff line change
@@ -3,20 +3,30 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Http;
using Microsoft.SemanticKernel.Services;
using Moq;
using OpenAI;
using OpenAI.Chat;
using Xunit;
using BinaryContent = System.ClientModel.BinaryContent;
using ChatMessageContent = Microsoft.SemanticKernel.ChatMessageContent;

namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core;

public partial class ClientCoreTests
{
[Fact]
@@ -240,4 +250,90 @@ public void ItDoesNotThrowWhenUsingCustomEndpointAndApiKeyIsNotProvided()
clientCore = new ClientCore("modelId", "", endpoint: new Uri("http://localhost"));
clientCore = new ClientCore("modelId", apiKey: null!, endpoint: new Uri("http://localhost"));
}

[Theory]
[ClassData(typeof(ChatMessageContentWithFunctionCalls))]
public async Task ItShouldReplaceDisallowedCharactersInFunctionName(ChatMessageContent chatMessageContent, bool nameContainsDisallowedCharacter)
{
// Arrange
using var responseMessage = new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json"))
};

using HttpMessageHandlerStub handler = new();
handler.ResponseToReturn = responseMessage;
using HttpClient client = new(handler);

var clientCore = new ClientCore("modelId", "apikey", httpClient: client);

ChatHistory chatHistory = [chatMessageContent];

// Act
await clientCore.GetChatMessageContentsAsync("gpt-4", chatHistory, new OpenAIPromptExecutionSettings(), new Kernel());

// Assert
JsonElement jsonString = JsonSerializer.Deserialize<JsonElement>(handler.RequestContent);

var function = jsonString.GetProperty("messages")[0].GetProperty("tool_calls")[0].GetProperty("function");

if (nameContainsDisallowedCharacter)
{
// The original name specified in function calls is "bar.foo", which contains a disallowed character '.'.
Assert.Equal("bar_foo", function.GetProperty("name").GetString());
}
else
{
// The original name specified in function calls is "bar-foo" and contains no disallowed characters.
Assert.Equal("bar-foo", function.GetProperty("name").GetString());
}
}

internal sealed class ChatMessageContentWithFunctionCalls : TheoryData<ChatMessageContent, bool>
{
private static readonly ChatToolCall s_functionCallWithInvalidFunctionName = ChatToolCall.CreateFunctionToolCall(id: "call123", functionName: "bar.foo", functionArguments: BinaryData.FromString("{}"));

private static readonly ChatToolCall s_functionCallWithValidFunctionName = ChatToolCall.CreateFunctionToolCall(id: "call123", functionName: "bar-foo", functionArguments: BinaryData.FromString("{}"));

public ChatMessageContentWithFunctionCalls()
{
this.AddMessagesWithFunctionCallsWithInvalidFunctionName();
}

private void AddMessagesWithFunctionCallsWithInvalidFunctionName()
{
// Case when function calls are available via the `Tools` property.
this.Add(new OpenAIChatMessageContent(AuthorRole.Assistant, "", "", [s_functionCallWithInvalidFunctionName]), true);

// Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of ChatToolCall type.
this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
{
[OpenAIChatMessageContent.FunctionToolCallsProperty] = new ChatToolCall[] { s_functionCallWithInvalidFunctionName }
}), true);

// Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of JsonElement type.
this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
{
[OpenAIChatMessageContent.FunctionToolCallsProperty] = JsonSerializer.Deserialize<JsonElement>($$"""[{"Id": "{{s_functionCallWithInvalidFunctionName.Id}}", "Name": "{{s_functionCallWithInvalidFunctionName.FunctionName}}", "Arguments": "{{s_functionCallWithInvalidFunctionName.FunctionArguments}}"}]""")
}), true);
}

private void AddMessagesWithFunctionCallsWithValidFunctionName()
{
// Case when function calls are available via the `Tools` property.
this.Add(new OpenAIChatMessageContent(AuthorRole.Assistant, "", "", [s_functionCallWithValidFunctionName]), false);

// Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of ChatToolCall type.
this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
{
[OpenAIChatMessageContent.FunctionToolCallsProperty] = new ChatToolCall[] { s_functionCallWithValidFunctionName }
}), false);

// Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of JsonElement type.
this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
{
[OpenAIChatMessageContent.FunctionToolCallsProperty] = JsonSerializer.Deserialize<JsonElement>($$"""[{"Id": "{{s_functionCallWithValidFunctionName.Id}}", "Name": "{{s_functionCallWithValidFunctionName.FunctionName}}", "Arguments": "{{s_functionCallWithValidFunctionName.FunctionArguments}}"}]""")
}), false);
}
}
}
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
@@ -26,6 +27,13 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI;
/// </summary>
internal partial class ClientCore
{
#if NET
[GeneratedRegex("[^a-zA-Z0-9_-]")]
private static partial Regex DisallowedFunctionNameCharactersRegex();
#else
private static Regex DisallowedFunctionNameCharactersRegex() => new("[^a-zA-Z0-9_-]", RegexOptions.Compiled);
#endif

protected const string ModelProvider = "openai";
protected record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice? Choice, bool AutoInvoke, bool AllowAnyRequestedKernelFunction, FunctionChoiceBehaviorOptions? Options);

@@ -752,7 +760,7 @@ private static List<ChatMessage> CreateRequestMessages(ChatMessageContent messag
return [new AssistantChatMessage(message.Content) { ParticipantName = message.AuthorName }];
}

var assistantMessage = new AssistantChatMessage(toolCalls) { ParticipantName = message.AuthorName };
var assistantMessage = new AssistantChatMessage(SanitizeFunctionNames(toolCalls)) { ParticipantName = message.AuthorName };

// If message content is null, adding it as empty string,
// because chat message content must be string.
@@ -1054,4 +1062,27 @@ private void ProcessNonFunctionToolCalls(IEnumerable<ChatToolCall> toolCalls, Ch
chatHistory.Add(message);
}
}

/// <summary>
/// Sanitizes function names by replacing disallowed characters.
/// </summary>
/// <param name="toolCalls">The function calls containing the function names which need to be sanitized.</param>
/// <returns>The function calls with sanitized function names.</returns>
private static List<ChatToolCall> SanitizeFunctionNames(List<ChatToolCall> toolCalls)
{
for (int i = 0; i < toolCalls.Count; i++)
{
ChatToolCall tool = toolCalls[i];

// Check if function name contains disallowed characters and replace them with '_'.
if (DisallowedFunctionNameCharactersRegex().IsMatch(tool.FunctionName))
{
var sanitizedName = DisallowedFunctionNameCharactersRegex().Replace(tool.FunctionName, "_");

toolCalls[i] = ChatToolCall.CreateFunctionToolCall(tool.Id, sanitizedName, tool.FunctionArguments);
}
}

return toolCalls;
}
}