diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/AggregateMcpDiagnosticEventListener.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/AggregateMcpDiagnosticEventListener.cs new file mode 100644 index 00000000000..a929a57ba57 --- /dev/null +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/AggregateMcpDiagnosticEventListener.cs @@ -0,0 +1,55 @@ +namespace HotChocolate.ModelContextProtocol.Diagnostics; + +internal sealed class AggregateMcpDiagnosticEvents(IMcpDiagnosticEventListener[] listeners) + : IMcpDiagnosticEvents +{ + public IDisposable InitializeTools() + { + var scopes = new IDisposable[listeners.Length]; + + for (var i = 0; i < listeners.Length; i++) + { + scopes[i] = listeners[i].InitializeTools(); + } + + return new AggregateActivityScope(scopes); + } + + public IDisposable UpdateTools() + { + var scopes = new IDisposable[listeners.Length]; + + for (var i = 0; i < listeners.Length; i++) + { + scopes[i] = listeners[i].InitializeTools(); + } + + return new AggregateActivityScope(scopes); + } + + public void ValidationErrors(IReadOnlyList errors) + { + foreach (var listener in listeners) + { + listener.ValidationErrors(errors); + } + } + + private sealed class AggregateActivityScope(IDisposable[] scopes) : IDisposable + { + private bool _disposed; + + public void Dispose() + { + if (!_disposed) + { + foreach (var scope in scopes) + { + scope.Dispose(); + } + + _disposed = true; + } + } + } +} diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/IMcpDiagnosticEventListener.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/IMcpDiagnosticEventListener.cs new file mode 100644 index 00000000000..6ba3f9e3699 --- /dev/null +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/IMcpDiagnosticEventListener.cs @@ -0,0 +1,9 @@ +namespace HotChocolate.ModelContextProtocol.Diagnostics; + +/// +/// Register an implementation of this interface in the DI container to +/// listen to diagnostic events. Multiple implementations can be registered, +/// and they will all be called in the registration order. +/// +/// +public interface IMcpDiagnosticEventListener : IMcpDiagnosticEvents; diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/IMcpDiagnosticEvents.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/IMcpDiagnosticEvents.cs new file mode 100644 index 00000000000..b0ab998902a --- /dev/null +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/IMcpDiagnosticEvents.cs @@ -0,0 +1,29 @@ +namespace HotChocolate.ModelContextProtocol.Diagnostics; + +/// +/// Provides diagnostic events for the Model Context Protocol (MCP) integration. +/// +public interface IMcpDiagnosticEvents +{ + /// + /// Called when the MCP tools are being initialized. + /// + /// + /// Returns a scope that is disposed when the initialization is complete. + /// + IDisposable InitializeTools(); + + /// + /// Called when the MCP tools are being updated. + /// + /// + /// Returns a scope that is disposed when the update is complete. + /// + IDisposable UpdateTools(); + + /// + /// Called when errors occur while validating a tool document. + /// + /// The validation errors. + void ValidationErrors(IReadOnlyList errors); +} diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/McpDiagnosticEventListener.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/McpDiagnosticEventListener.cs new file mode 100644 index 00000000000..49fb3e60bab --- /dev/null +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/McpDiagnosticEventListener.cs @@ -0,0 +1,30 @@ +namespace HotChocolate.ModelContextProtocol.Diagnostics; + +/// +/// This class can be used as a base class for +/// implementations, so that they only have to override the methods they +/// are interested in instead of having to provide implementations for all of them. +/// +public class McpDiagnosticEventListener : IMcpDiagnosticEventListener +{ + /// + /// A no-op activity scope that can be returned from + /// event methods that are not interested in when the scope is disposed. + /// + private static IDisposable EmptyScope { get; } = new EmptyActivityScope(); + + public virtual IDisposable InitializeTools() => EmptyScope; + + public virtual IDisposable UpdateTools() => EmptyScope; + + public virtual void ValidationErrors(IReadOnlyList errors) + { + } + + private sealed class EmptyActivityScope : IDisposable + { + public void Dispose() + { + } + } +} diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/NoopMcpDiagnosticEventListener.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/NoopMcpDiagnosticEventListener.cs new file mode 100644 index 00000000000..fa77939edfc --- /dev/null +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Diagnostics/NoopMcpDiagnosticEventListener.cs @@ -0,0 +1,3 @@ +namespace HotChocolate.ModelContextProtocol.Diagnostics; + +public class NoopMcpDiagnosticEvents : McpDiagnosticEventListener; diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Extensions/RequestExecutorBuilderExtensions.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Extensions/RequestExecutorBuilderExtensions.cs index 201dfda65ae..8cd39ece4cc 100644 --- a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Extensions/RequestExecutorBuilderExtensions.cs +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/Extensions/RequestExecutorBuilderExtensions.cs @@ -1,5 +1,7 @@ +using System.Diagnostics.CodeAnalysis; using HotChocolate.Execution; using HotChocolate.Execution.Configuration; +using HotChocolate.ModelContextProtocol.Diagnostics; using HotChocolate.ModelContextProtocol.Handlers; using HotChocolate.ModelContextProtocol.Proxies; using HotChocolate.ModelContextProtocol.Storage; @@ -47,13 +49,26 @@ public static IRequestExecutorBuilder AddMcp( static sp => new OperationToolFactory( sp.GetRequiredService())); + services.TryAddSingleton(sp => + { + var listeners = sp.GetServices().ToArray(); + return listeners.Length switch + { + 0 => new NoopMcpDiagnosticEvents(), + 1 => listeners[0], + _ => new AggregateMcpDiagnosticEvents(listeners) + }; + }); + services .TryAddSingleton( static sp => new ToolStorageObserver( + sp.GetRequiredService(), sp.GetRequiredService(), sp.GetRequiredService(), sp.GetRequiredService(), - sp.GetRequiredService())); + sp.GetRequiredService(), + sp.GetRequiredService())); services .AddSingleton( @@ -99,6 +114,29 @@ public static IRequestExecutorBuilder AddMcp( return builder; } + public static IRequestExecutorBuilder AddMcpDiagnosticEventListener( + this IRequestExecutorBuilder builder, + IMcpDiagnosticEventListener listener) + { + ArgumentNullException.ThrowIfNull(builder); + ArgumentNullException.ThrowIfNull(listener); + + builder.ConfigureSchemaServices(s => s.AddSingleton(listener)); + + return builder; + } + + public static IRequestExecutorBuilder AddMcpDiagnosticEventListener< + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] T>( + this IRequestExecutorBuilder builder) where T : class, IMcpDiagnosticEventListener + { + ArgumentNullException.ThrowIfNull(builder); + + builder.ConfigureSchemaServices(s => s.AddSingleton()); + + return builder; + } + public static IRequestExecutorBuilder AddMcpToolStorage( this IRequestExecutorBuilder builder, IOperationToolStorage storage) diff --git a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/ToolStorageObserver.cs b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/ToolStorageObserver.cs index 3f8e27cea58..565e3745f1b 100644 --- a/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/ToolStorageObserver.cs +++ b/src/HotChocolate/ModelContextProtocol/src/HotChocolate.ModelContextProtocol/ToolStorageObserver.cs @@ -1,7 +1,10 @@ using System.Collections.Immutable; using System.Reactive.Linq; +using HotChocolate.ModelContextProtocol.Diagnostics; using HotChocolate.ModelContextProtocol.Storage; using HotChocolate.Utilities; +using HotChocolate.Validation; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol; using ModelContextProtocol.AspNetCore; using static ModelContextProtocol.Protocol.NotificationMethods; @@ -13,27 +16,35 @@ internal sealed class ToolStorageObserver : IDisposable private readonly SemaphoreSlim _semaphore = new(initialCount: 1, maxCount: 1); private readonly CancellationTokenSource _cts = new(); private readonly CancellationToken _ct; + private readonly ISchemaDefinition _schema; private readonly ToolRegistry _registry; private readonly OperationToolFactory _toolFactory; private readonly StreamableHttpHandler _httpHandler; private readonly IOperationToolStorage _storage; + private readonly IMcpDiagnosticEvents _diagnosticEvents; private IDisposable? _subscription; #if NET10_0_OR_GREATER private ImmutableDictionary _tools = []; #else private ImmutableDictionary _tools = ImmutableDictionary.Empty; #endif + private static readonly DocumentValidator s_documentValidator + = DocumentValidatorBuilder.New().AddDefaultRules().Build(); private bool _disposed; public ToolStorageObserver( + ISchemaDefinition schema, ToolRegistry registry, OperationToolFactory toolFactory, StreamableHttpHandler httpHandler, - IOperationToolStorage storage) + IOperationToolStorage storage, + IMcpDiagnosticEvents diagnosticEvents) { + _schema = schema; _registry = registry; _toolFactory = toolFactory; _storage = storage; + _diagnosticEvents = diagnosticEvents; _httpHandler = httpHandler; _ct = _cts.Token; } @@ -58,8 +69,18 @@ public async Task StartAsync(CancellationToken cancellationToken) { var tools = ImmutableDictionary.CreateBuilder(); + using var scope = _diagnosticEvents.InitializeTools(); + foreach (var toolDefinition in await _storage.GetToolsAsync(cancellationToken)) { + var validationResult = s_documentValidator.Validate(_schema, toolDefinition.Document); + + if (validationResult.HasErrors) + { + _diagnosticEvents.ValidationErrors(validationResult.Errors); + continue; + } + tools.Add(toolDefinition.Name, _toolFactory.CreateTool(toolDefinition)); } @@ -84,9 +105,21 @@ private void ProcessBatch(IList eventArgs) { case OperationToolStorageEventType.Added: case OperationToolStorageEventType.Modified: - var tool = _toolFactory.CreateTool(eventArg.ToolDefinition!); - _tools = _tools.SetItem(eventArg.Name, tool); - break; + using (_diagnosticEvents.UpdateTools()) + { + var validationResult = + s_documentValidator.Validate(_schema, eventArg.ToolDefinition!.Document); + + if (validationResult.HasErrors) + { + _diagnosticEvents.ValidationErrors(validationResult.Errors); + continue; + } + + var tool = _toolFactory.CreateTool(eventArg.ToolDefinition!); + _tools = _tools.SetItem(eventArg.Name, tool); + break; + } case OperationToolStorageEventType.Removed: _tools = _tools.Remove(eventArg.Name); diff --git a/src/HotChocolate/ModelContextProtocol/test/HotChocolate.ModelContextProtocol.Tests/IntegrationTests.cs b/src/HotChocolate/ModelContextProtocol/test/HotChocolate.ModelContextProtocol.Tests/IntegrationTests.cs index 4707fc6b1f1..10b6058e075 100644 --- a/src/HotChocolate/ModelContextProtocol/test/HotChocolate.ModelContextProtocol.Tests/IntegrationTests.cs +++ b/src/HotChocolate/ModelContextProtocol/test/HotChocolate.ModelContextProtocol.Tests/IntegrationTests.cs @@ -6,6 +6,7 @@ using HotChocolate.Execution; using HotChocolate.Execution.Configuration; using HotChocolate.Language; +using HotChocolate.ModelContextProtocol.Diagnostics; using HotChocolate.ModelContextProtocol.Extensions; using HotChocolate.Types; using HotChocolate.Types.Descriptors; @@ -254,6 +255,67 @@ mutation AddBook @mcpTool(destructiveHint: false, idempotentHint: true, openWorl Assert.Equal(false, tools[0].ProtocolTool.Annotations?.OpenWorldHint); } + [Fact] + public async Task ListTools_InitializeToolsInvalidDocument_ReturnsExpectedResult() + { + // arrange + var storage = new TestOperationToolStorage(); + await storage.AddOrUpdateToolAsync( + Utf8GraphQLParser.Parse("query Invalid { doesNotExist1, doesNotExist2 }")); + await storage.AddOrUpdateToolAsync( + Utf8GraphQLParser.Parse("query Valid { books { title } }")); + var listener = new TestMcpDiagnosticEventListener(); + var server = + CreateTestServer( + b => b + .AddMcpDiagnosticEventListener(listener) + .AddMcpToolStorage(storage)); + var mcpClient = await CreateMcpClientAsync(server.CreateClient()); + + // act + var result = await mcpClient.ListToolsAsync(); + + // assert + Assert.Single(result, tool => tool.Name == "valid"); // The invalid tool is ignored. + Assert.Collection( + listener.ValidationErrorLog, + firstError => + Assert.Equal("The field `doesNotExist1` does not exist on the type `Query`.", firstError.Message), + secondError => + Assert.Equal("The field `doesNotExist2` does not exist on the type `Query`.", secondError.Message)); + } + + [Fact] + public async Task ListTools_UpdateToolsInvalidDocument_ReturnsExpectedResult() + { + // arrange + var storage = new TestOperationToolStorage(); + await storage.AddOrUpdateToolAsync( + Utf8GraphQLParser.Parse("""query Tool @mcpTool(title: "BEFORE") { books { title } }""")); + var listener = new TestMcpDiagnosticEventListener(); + var server = + CreateTestServer( + b => b + .AddMcpDiagnosticEventListener(listener) + .AddMcpToolStorage(storage)); + var mcpClient = await CreateMcpClientAsync(server.CreateClient()); + + // act + await storage.AddOrUpdateToolAsync( + Utf8GraphQLParser.Parse("""query Tool @mcpTool(title: "AFTER") { doesNotExist1, doesNotExist2 }""")); + await Task.Delay(500); // Wait for the observer buffer to flush. + var result = await mcpClient.ListToolsAsync(); + + // assert + Assert.Single(result, tool => tool.Title == "BEFORE"); // The invalid update is ignored. + Assert.Collection( + listener.ValidationErrorLog, + firstError => + Assert.Equal("The field `doesNotExist1` does not exist on the type `Query`.", firstError.Message), + secondError => + Assert.Equal("The field `doesNotExist2` does not exist on the type `Query`.", secondError.Message)); + } + [Fact] public async Task CallTool_GetWithNullableVariables_ReturnsExpectedResult() { @@ -657,3 +719,13 @@ public static string GenerateToken() private const string TokenAudience = "test-audience"; private static readonly SymmetricSecurityKey s_tokenKey = new("test-secret-key-at-least-32-bytes"u8.ToArray()); } + +public sealed class TestMcpDiagnosticEventListener : McpDiagnosticEventListener +{ + public List ValidationErrorLog { get; } = []; + + public override void ValidationErrors(IReadOnlyList errors) + { + ValidationErrorLog.AddRange(errors); + } +}