From 8281b04aba068f4b75d821f7e64dacb04c5c2cef Mon Sep 17 00:00:00 2001 From: zhongyi51 Date: Sat, 19 Jul 2025 11:10:59 +0800 Subject: [PATCH] feat(validation): support validation for subtypes of `ServerRequest` enum. --- src/generated_schema/2025_06_18/mcp_schema.rs | 114 ++++++++++++++++++ .../2025_06_18/schema_utils.rs | 7 ++ 2 files changed, 121 insertions(+) diff --git a/src/generated_schema/2025_06_18/mcp_schema.rs b/src/generated_schema/2025_06_18/mcp_schema.rs index 552a5bc..1e3a956 100644 --- a/src/generated_schema/2025_06_18/mcp_schema.rs +++ b/src/generated_schema/2025_06_18/mcp_schema.rs @@ -1312,6 +1312,8 @@ impl ::std::convert::From for ContentBlock { /// #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)] pub struct CreateMessageRequest { + // This field requires custom deserialization for validation. + #[serde(deserialize_with = "server_request_method_validation::deserialize_CreateMessageRequest_method")] method: ::std::string::String, pub params: CreateMessageRequestParams, } @@ -1625,6 +1627,8 @@ pub struct Cursor(pub ::std::string::String); /// #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)] pub struct ElicitRequest { + // This field requires custom deserialization for validation. + #[serde(deserialize_with = "server_request_method_validation::deserialize_ElicitRequest_method")] method: ::std::string::String, pub params: ElicitRequestParams, } @@ -3302,6 +3306,8 @@ structure or access specific locations that the client has permission to read fr /// #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)] pub struct ListRootsRequest { + // This field requires custom deserialization for validation. + #[serde(deserialize_with = "server_request_method_validation::deserialize_ListRootsRequest_method")] method: ::std::string::String, #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] pub params: ::std::option::Option, @@ -4040,6 +4046,8 @@ pub struct PaginatedResult { /// #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)] pub struct PingRequest { + // This field requires custom deserialization for validation. + #[serde(deserialize_with = "server_request_method_validation::deserialize_PingRequest_method")] method: ::std::string::String, #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] pub params: ::std::option::Option, @@ -7167,5 +7175,111 @@ impl ServerNotification { } } } + +// Custom module for deserialization function to prevent name conflicts. +mod server_request_method_validation{ + + // Custom deserialization function, following the `deserialize_#StructName_#FieldName` format. + #[allow(non_snake_case)] + pub(super) fn deserialize_PingRequest_method<'de, D>( + deserializer: D, + ) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + let value = serde::Deserialize::deserialize(deserializer)?; + // The expected constant value. + let expected = "ping"; + + // Validate the deserialized value. + if value == expected { + Ok(value) + } else { + // The error message with format + // "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'" + Err(serde::de::Error::custom(format!( + "Expected field `method` in struct `PingRequest` as const value '{}', but got '{}'", + expected, value + ))) + } + } + + // Custom deserialization function, following the `deserialize_#StructName_#FieldName` format. + #[allow(non_snake_case)] + pub(super) fn deserialize_CreateMessageRequest_method<'de, D>( + deserializer: D, + ) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + let value = serde::Deserialize::deserialize(deserializer)?; + // The expected constant value. + let expected = "sampling/createMessage"; + + // Validate the deserialized value. + if value == expected { + Ok(value) + } else { + // The error message with format + // "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'" + Err(serde::de::Error::custom(format!( + "Expected field `method` in struct `CreateMessageRequest` as const value '{}', but got '{}'", + expected, value + ))) + } + } + + // Custom deserialization function, following the `deserialize_#StructName_#FieldName` format. + #[allow(non_snake_case)] + pub(super) fn deserialize_ListRootsRequest_method<'de, D>( + deserializer: D, + ) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + let value = serde::Deserialize::deserialize(deserializer)?; + // The expected constant value. + let expected = "roots/list"; + + // Validate the deserialized value. + if value == expected { + Ok(value) + } else { + // The error message with format + // "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'" + Err(serde::de::Error::custom(format!( + "Expected field `method` in struct `ListRootsRequest` as const value '{}', but got '{}'", + expected, value + ))) + } + } + + // Custom deserialization function, following the `deserialize_#StructName_#FieldName` format. + #[allow(non_snake_case)] + pub(super) fn deserialize_ElicitRequest_method<'de, D>( + deserializer: D, + ) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + let value = serde::Deserialize::deserialize(deserializer)?; + // The expected constant value. + let expected = "elicitation/create"; + + // Validate the deserialized value. + if value == expected { + Ok(value) + } else { + // The error message with format + // "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'" + Err(serde::de::Error::custom(format!( + "Expected field `method` in struct `ElicitRequest` as const value '{}', but got '{}'", + expected, value + ))) + } + } + +} + #[deprecated(since = "0.3.0", note = "Use `RpcError` instead.")] pub type JsonrpcErrorError = RpcError; diff --git a/src/generated_schema/2025_06_18/schema_utils.rs b/src/generated_schema/2025_06_18/schema_utils.rs index 0635d3c..a0cf3f3 100644 --- a/src/generated_schema/2025_06_18/schema_utils.rs +++ b/src/generated_schema/2025_06_18/schema_utils.rs @@ -3922,5 +3922,12 @@ mod tests { // default let result = detect_message_type(&json!({})); assert!(matches!(result, MessageTypes::Request)); + + // assert method type validation + let should_err:std::result::Result = serde_json::from_value(json!({ + "method":"wrong_method", + "params":null + })); + assert!(should_err.is_err()); } }