diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index c8af28ac..1d77bf29 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -12,6 +12,11 @@ import java.util.function.Consumer; import java.util.function.Function; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.spec.DefaultJsonSchemaValidator; +import io.modelcontextprotocol.spec.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; @@ -97,6 +102,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Anurag Pant * @see McpAsyncClient * @see McpSyncClient * @see McpTransport @@ -183,6 +189,8 @@ class SyncSpec { private Function elicitationHandler; + private JsonSchemaValidator jsonSchemaValidator; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -409,12 +417,27 @@ public SyncSpec progressConsumers(List> return this; } + /** + * Sets the JSON schema validator to use for validating tool responses against + * output schemas. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public SyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. * @return a new instance of {@link McpSyncClient}. */ public McpSyncClient build() { + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(); McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler, @@ -423,7 +446,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); return new McpSyncClient( - new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); + new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures), + jsonSchemaValidator); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 83c4900d..7544f299 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,10 +5,15 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; @@ -48,6 +53,7 @@ * @author Dariusz Jędrzejczyk * @author Christian Tzolov * @author Jihoon Kim + * @author Anurag Pant * @see McpClient * @see McpAsyncClient * @see McpSchema @@ -63,14 +69,23 @@ public class McpSyncClient implements AutoCloseable { private final McpAsyncClient delegate; + private final JsonSchemaValidator jsonSchemaValidator; + + /** + * Cached tool output schemas. + */ + private final ConcurrentHashMap>> toolsOutputSchemaCache; + /** * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. */ - McpSyncClient(McpAsyncClient delegate) { + McpSyncClient(McpAsyncClient delegate, JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(delegate, "The delegate can not be null"); this.delegate = delegate; + this.jsonSchemaValidator = jsonSchemaValidator; + this.toolsOutputSchemaCache = new ConcurrentHashMap<>(); } /** @@ -216,7 +231,37 @@ public Object ping() { * Boolean indicating if the execution failed (true) or succeeded (false/absent) */ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) { - return this.delegate.callTool(callToolRequest).block(); + if (!this.toolsOutputSchemaCache.containsKey(callToolRequest.name())) { + listTools(); // Ensure tools are cached before calling + } + + McpSchema.CallToolResult result = this.delegate.callTool(callToolRequest).block(); + Optional> optOutputSchema = toolsOutputSchemaCache.get(callToolRequest.name()); + + if (result != null && result.isError() != null && !result.isError()) { + if (optOutputSchema == null) { + // Should not be triggered but added for completeness + throw new McpError("Tool with name '" + callToolRequest.name() + "' not found"); + } + else { + if (optOutputSchema.isPresent()) { + // Validate the tool output against the cached output schema + var validation = this.jsonSchemaValidator.validate(optOutputSchema.get(), + result.structuredContent()); + if (!validation.valid()) { + throw new McpError("Tool call result validation failed: " + validation.errorMessage()); + } + } + else if (result.structuredContent() != null) { + logger.warn( + "Calling a tool with no outputSchema is not expected to return result with structured content, but got: {}", + result.structuredContent()); + } + + } + } + + return result; } /** @@ -226,7 +271,14 @@ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolReque * pagination if more tools are available */ public McpSchema.ListToolsResult listTools() { - return this.delegate.listTools().block(); + return this.delegate.listTools().doOnNext(result -> { + if (result.tools() != null) { + // Cache tools output schema + result.tools() + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), + Optional.ofNullable(tool.outputSchema()))); + } + }).block(); } /** @@ -237,7 +289,14 @@ public McpSchema.ListToolsResult listTools() { * pagination if more tools are available */ public McpSchema.ListToolsResult listTools(String cursor) { - return this.delegate.listTools(cursor).block(); + return this.delegate.listTools(cursor).doOnNext(result -> { + if (result.tools() != null) { + // Cache tools output schema + result.tools() + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), + Optional.ofNullable(tool.outputSchema()))); + } + }).block(); } // -------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpSyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpSyncClientResponseHandlerTests.java new file mode 100644 index 00000000..d8a53c40 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpSyncClientResponseHandlerTests.java @@ -0,0 +1,312 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.Tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import reactor.core.publisher.Mono; + +import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for the {@link McpSyncClient} + * + * @author Anurag Pant + */ +public class McpSyncClientResponseHandlerTests { + + private static final McpSchema.Implementation SERVER_INFO = new McpSchema.Implementation("test-server", "1.0.0"); + + private static final McpSchema.ServerCapabilities SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .tools(true) + .resources(true, true) // Enable both resources and resource templates + .build(); + + private static MockMcpClientTransport initializationEnabledTransport() { + return initializationEnabledTransport(SERVER_CAPABILITIES, SERVER_INFO); + } + + private static MockMcpClientTransport initializationEnabledTransport( + McpSchema.ServerCapabilities mockServerCapabilities, McpSchema.Implementation mockServerInfo) { + McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, + mockServerCapabilities, mockServerInfo, "Test instructions"); + + return new MockMcpClientTransport((t, message) -> { + if (message instanceof McpSchema.JSONRPCRequest r && METHOD_INITIALIZE.equals(r.method())) { + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + r.id(), mockInitResult, null); + t.simulateIncomingMessage(initResponse); + } + }); + } + + @Test + void testStructuredOutputClientSideValidationSuccess() throws JsonProcessingException { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client with tools change consumer + McpSyncClient syncMcpClient = McpClient.sync(transport).build(); + + assertThat(syncMcpClient.initialize()).isNotNull(); + + // Create a mock tools list that the server will return + Map inputSchema = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(new ObjectMapper().writeValueAsString(inputSchema)) + .outputSchema(outputSchema) + .build(); + + // Create list tools response + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result with valid output (structured content despite no output + // schema) + CallToolResult mockInvalidCallToolResult = CallToolResult.builder() + .addTextContent("Valid calculation") + .structuredContent(Map.of("result", 5, "operation", "add")) + .build(); + + // Set up a separate thread to simulate the response + Thread responseThread = new Thread(() -> { + try { + // Wait briefly to ensure the listTools request is sent + Thread.sleep(100); + + // Simulate server response to first tools/list request + McpSchema.JSONRPCRequest toolsListRequest = transport.getLastSentMessageAsRequest(); + assertThat(toolsListRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_LIST); + McpSchema.JSONRPCResponse toolsListResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + toolsListRequest.id(), mockToolsResult, null); + transport.simulateIncomingMessage(toolsListResponse); + + // Wait briefly to ensure the callTool request is sent + Thread.sleep(100); + + // Get the request and send the response + McpSchema.JSONRPCRequest callToolRequest = transport.getLastSentMessageAsRequest(); + assertThat(callToolRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_CALL); + McpSchema.JSONRPCResponse callToolResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + callToolRequest.id(), mockInvalidCallToolResult, null); + transport.simulateIncomingMessage(callToolResponse); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + // Start the response thread + responseThread.start(); + + // Call tool with valid structured output + CallToolResult response = syncMcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + // Wait for the response thread to complete + try { + responseThread.join(1500); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + + syncMcpClient.closeGracefully(); + } + + @Test + void testStructuredOutputWhenNoOutputSchemaClientSideValidationSuccess() throws JsonProcessingException { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client with tools change consumer + McpSyncClient syncMcpClient = McpClient.sync(transport).build(); + + assertThat(syncMcpClient.initialize()).isNotNull(); + + // Create a mock tools list that the server will return + Map inputSchema = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(new ObjectMapper().writeValueAsString(inputSchema)) + .build(); + + // Create list tools response + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result with valid output (structured content despite no output + // schema) + CallToolResult mockInvalidCallToolResult = CallToolResult.builder() + .addTextContent("Valid calculation") + .structuredContent(Map.of("result", 5, "operation", "add")) + .build(); + + // Set up a separate thread to simulate the response + Thread responseThread = new Thread(() -> { + try { + // Wait briefly to ensure the listTools request is sent + Thread.sleep(100); + + // Simulate server response to first tools/list request + McpSchema.JSONRPCRequest toolsListRequest = transport.getLastSentMessageAsRequest(); + assertThat(toolsListRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_LIST); + McpSchema.JSONRPCResponse toolsListResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + toolsListRequest.id(), mockToolsResult, null); + transport.simulateIncomingMessage(toolsListResponse); + + // Wait briefly to ensure the callTool request is sent + Thread.sleep(100); + + // Get the request and send the response + McpSchema.JSONRPCRequest callToolRequest = transport.getLastSentMessageAsRequest(); + assertThat(callToolRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_CALL); + McpSchema.JSONRPCResponse callToolResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + callToolRequest.id(), mockInvalidCallToolResult, null); + transport.simulateIncomingMessage(callToolResponse); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + // Start the response thread + responseThread.start(); + + // Call tool with valid structured output + CallToolResult response = syncMcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + // Wait for the response thread to complete + try { + responseThread.join(1500); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + + syncMcpClient.closeGracefully(); + } + + @Test + void testStructuredOutputClientSideValidationFailure() throws JsonProcessingException { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client with tools change consumer + McpSyncClient syncMcpClient = McpClient.sync(transport).build(); + + assertThat(syncMcpClient.initialize()).isNotNull(); + + // Create a mock tools list that the server will return + Map inputSchema = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(new ObjectMapper().writeValueAsString(inputSchema)) + .outputSchema(outputSchema) + .build(); + + // Create list tools response + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result with invalid output + CallToolResult mockInvalidCallToolResult = CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "5", "operation", "add")) + .build(); + + // Set up a separate thread to simulate the response + Thread responseThread = new Thread(() -> { + try { + // Wait briefly to ensure the listTools request is sent + Thread.sleep(100); + + // Simulate server response to first tools/list request + McpSchema.JSONRPCRequest toolsListRequest = transport.getLastSentMessageAsRequest(); + assertThat(toolsListRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_LIST); + McpSchema.JSONRPCResponse toolsListResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + toolsListRequest.id(), mockToolsResult, null); + transport.simulateIncomingMessage(toolsListResponse); + + // Wait briefly to ensure the callTool request is sent + Thread.sleep(100); + + // Get the request and send the response + McpSchema.JSONRPCRequest callToolRequest = transport.getLastSentMessageAsRequest(); + assertThat(callToolRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_CALL); + McpSchema.JSONRPCResponse callToolResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + callToolRequest.id(), mockInvalidCallToolResult, null); + transport.simulateIncomingMessage(callToolResponse); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + // Start the response thread + responseThread.start(); + + // Make the call that should fail validation + assertThatThrownBy(() -> syncMcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .isInstanceOf(McpError.class) + .hasMessageContaining("Validation failed"); + + // Wait for the response thread to complete + try { + responseThread.join(1500); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + + syncMcpClient.closeGracefully(); + } + +}