diff --git a/Directory.Packages.props b/Directory.Packages.props index 656f5fe9..f86a7c69 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -48,12 +48,19 @@ + + + + - + + + - + + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index a70e3e31..1f6dce1e 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -62,11 +62,13 @@ + + diff --git a/README.md b/README.md index 3099dfcd..19672bfe 100644 --- a/README.md +++ b/README.md @@ -225,6 +225,11 @@ await using McpServer server = McpServer.Create(new StdioServerTransport("MyServ await server.RunAsync(); ``` +Descriptions can be added to tools, prompts, and resources in a variety of ways, including via the `[Description]` attribute from `System.ComponentModel`. +This attribute may be placed on a method to provide for the tool, prompt, or resource, or on individual parameters to describe each's purpose. +XML comments may also be used; if an `[McpServerTool]`, `[McpServerPrompt]`, or `[McpServerResource]`-attributed method is marked as `partial`, +XML comments placed on the method will be used automatically to generate `[Description]` attributes for the method and its parameters. + ## Acknowledgements The starting point for this library was a project called [mcpdotnet](https://github.com/PederHP/mcpdotnet), initiated by [Peder Holdgaard Pedersen](https://github.com/PederHP). We are grateful for the work done by Peder and other contributors to that repository, which created a solid foundation for this library. diff --git a/src/ModelContextProtocol.Analyzers/Diagnostics.cs b/src/ModelContextProtocol.Analyzers/Diagnostics.cs new file mode 100644 index 00000000..e2c70412 --- /dev/null +++ b/src/ModelContextProtocol.Analyzers/Diagnostics.cs @@ -0,0 +1,31 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using System.Collections.Immutable; +using System.Text; +using System.Xml.Linq; + +namespace ModelContextProtocol.Analyzers; + +/// Provides the diagnostic descriptors used by the assembly. +internal static class Diagnostics +{ + public static DiagnosticDescriptor InvalidXmlDocumentation { get; } = new( + id: "MCP001", + title: "Invalid XML documentation for MCP method", + messageFormat: "XML comment for method '{0}' is invalid and cannot be processed to generate [Description] attributes.", + category: "mcp", + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + description: "The XML documentation comment contains invalid XML and cannot be processed to generate Description attributes."); + + public static DiagnosticDescriptor McpMethodMustBePartial { get; } = new( + id: "MCP002", + title: "MCP method must be partial to generate [Description] attributes", + messageFormat: "Method '{0}' has XML documentation that could be used to generate [Description] attributes, but the method is not declared as partial.", + category: "mcp", + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + description: "Methods with MCP attributes should be declared as partial to allow the source generator to emit Description attributes from XML documentation comments."); +} diff --git a/src/ModelContextProtocol.Analyzers/ModelContextProtocol.Analyzers.csproj b/src/ModelContextProtocol.Analyzers/ModelContextProtocol.Analyzers.csproj new file mode 100644 index 00000000..5338bbb8 --- /dev/null +++ b/src/ModelContextProtocol.Analyzers/ModelContextProtocol.Analyzers.csproj @@ -0,0 +1,20 @@ + + + + netstandard2.0 + true + false + true + + + + + + + + + + + + + diff --git a/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs b/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs new file mode 100644 index 00000000..a5dff0c7 --- /dev/null +++ b/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs @@ -0,0 +1,414 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using System.CodeDom.Compiler; +using System.Collections.Immutable; +using System.Text; +using System.Xml.Linq; + +namespace ModelContextProtocol.Analyzers; + +/// +/// Source generator that creates [Description] attributes from XML comments +/// for partial methods tagged with MCP attributes. +/// +[Generator] +public sealed class XmlToDescriptionGenerator : IIncrementalGenerator +{ + private const string GeneratedFileName = "ModelContextProtocol.Descriptions.g.cs"; + private const string McpServerToolAttributeName = "ModelContextProtocol.Server.McpServerToolAttribute"; + private const string McpServerPromptAttributeName = "ModelContextProtocol.Server.McpServerPromptAttribute"; + private const string McpServerResourceAttributeName = "ModelContextProtocol.Server.McpServerResourceAttribute"; + private const string DescriptionAttributeName = "System.ComponentModel.DescriptionAttribute"; + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // Use ForAttributeWithMetadataName for each MCP attribute type + var toolMethods = CreateProviderForAttribute(context, McpServerToolAttributeName); + var promptMethods = CreateProviderForAttribute(context, McpServerPromptAttributeName); + var resourceMethods = CreateProviderForAttribute(context, McpServerResourceAttributeName); + + // Combine all three providers + var allMethods = toolMethods + .Collect() + .Combine(promptMethods.Collect()) + .Combine(resourceMethods.Collect()) + .Select(static (tuple, _) => + { + var ((tool, prompt), resource) = tuple; + return tool.AddRange(prompt).AddRange(resource); + }); + + // Combine with compilation to get well-known type symbols. + var compilationAndMethods = context.CompilationProvider.Combine(allMethods); + + // Write out the source for all methods. + context.RegisterSourceOutput(compilationAndMethods, static (spc, source) => Execute(source.Left, source.Right, spc)); + } + + private static IncrementalValuesProvider CreateProviderForAttribute( + IncrementalGeneratorInitializationContext context, + string attributeMetadataName) => + context.SyntaxProvider.ForAttributeWithMetadataName( + attributeMetadataName, + static (node, _) => node is MethodDeclarationSyntax, + static (ctx, ct) => + { + var methodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode; + var methodSymbol = (IMethodSymbol)ctx.TargetSymbol; + return new MethodToGenerate(methodDeclaration, methodSymbol); + }); + + private static void Execute(Compilation compilation, ImmutableArray methods, SourceProductionContext context) + { + if (methods.IsDefaultOrEmpty || + compilation.GetTypeByMetadataName(DescriptionAttributeName) is not { } descriptionAttribute) + { + return; + } + + // Gather a list of all methods needing generation. + List<(IMethodSymbol MethodSymbol, MethodDeclarationSyntax MethodDeclaration, XmlDocumentation? XmlDocs)> methodsToGenerate = new(methods.Length); + foreach (var methodModel in methods) + { + var xmlDocs = ExtractXmlDocumentation(methodModel.MethodSymbol, context); + + // Generate implementation for partial methods. + if (methodModel.MethodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + methodsToGenerate.Add((methodModel.MethodSymbol, methodModel.MethodDeclaration, xmlDocs)); + } + else if (xmlDocs is not null && HasGeneratableContent(xmlDocs, methodModel.MethodSymbol, descriptionAttribute)) + { + // The method is not partial but has XML docs that would generate attributes; issue a diagnostic. + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.McpMethodMustBePartial, + methodModel.MethodDeclaration.Identifier.GetLocation(), + methodModel.MethodSymbol.Name)); + } + } + + // Generate a single file with all partial declarations. + if (methodsToGenerate.Count > 0) + { + string source = GenerateSourceFile(compilation, methodsToGenerate, descriptionAttribute); + context.AddSource(GeneratedFileName, SourceText.From(source, Encoding.UTF8)); + } + } + + private static XmlDocumentation? ExtractXmlDocumentation(IMethodSymbol methodSymbol, SourceProductionContext context) + { + string? xmlDoc = methodSymbol.GetDocumentationCommentXml(); + if (string.IsNullOrWhiteSpace(xmlDoc)) + { + return null; + } + + try + { + if (XDocument.Parse(xmlDoc).Element("member") is not { } memberElement) + { + return null; + } + + var summary = CleanXmlDocText(memberElement.Element("summary")?.Value); + var remarks = CleanXmlDocText(memberElement.Element("remarks")?.Value); + var returns = CleanXmlDocText(memberElement.Element("returns")?.Value); + + // Combine summary and remarks for method description. + var methodDescription = + string.IsNullOrWhiteSpace(remarks) ? summary : + string.IsNullOrWhiteSpace(summary) ? remarks : + $"{summary}\n{remarks}"; + + Dictionary paramDocs = new(StringComparer.Ordinal); + foreach (var paramElement in memberElement.Elements("param")) + { + var name = paramElement.Attribute("name")?.Value; + var value = CleanXmlDocText(paramElement.Value); + if (!string.IsNullOrWhiteSpace(name) && !string.IsNullOrWhiteSpace(value)) + { + paramDocs[name!] = value; + } + } + + // Return documentation even if empty - we'll still generate the partial implementation + return new(methodDescription ?? string.Empty, returns ?? string.Empty, paramDocs); + } + catch (System.Xml.XmlException) + { + // Emit warning for invalid XML + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.InvalidXmlDocumentation, + methodSymbol.Locations.FirstOrDefault(), + methodSymbol.Name)); + return null; + } + } + + private static string CleanXmlDocText(string? text) + { + if (string.IsNullOrWhiteSpace(text)) + { + return string.Empty; + } + + // Remove leading/trailing whitespace and normalize line breaks + var lines = text!.Split('\n') + .Select(line => line.Trim()) + .Where(line => !string.IsNullOrEmpty(line)); + + return string.Join(" ", lines).Trim(); + } + + private static string GenerateSourceFile( + Compilation compilation, + List<(IMethodSymbol MethodSymbol, MethodDeclarationSyntax MethodDeclaration, XmlDocumentation? XmlDocs)> methods, + INamedTypeSymbol descriptionAttribute) + { + StringWriter sw = new(); + IndentedTextWriter writer = new(sw); + + writer.WriteLine("// "); + writer.WriteLine($"// ModelContextProtocol.Analyzers {typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}"); + writer.WriteLine(); + writer.WriteLine("#pragma warning disable"); + writer.WriteLine(); + writer.WriteLine("using System.ComponentModel;"); + writer.WriteLine("using ModelContextProtocol.Server;"); + writer.WriteLine(); + + // Group methods by namespace and containing type + var groupedMethods = methods.GroupBy(m => + m.MethodSymbol.ContainingNamespace.Name == compilation.GlobalNamespace.Name ? "" : + m.MethodSymbol.ContainingNamespace?.ToDisplayString() ?? + ""); + + bool firstNamespace = true; + foreach (var namespaceGroup in groupedMethods) + { + if (!firstNamespace) + { + writer.WriteLine(); + } + firstNamespace = false; + + // Check if this is the global namespace (methods with null ContainingNamespace) + bool isGlobalNamespace = string.IsNullOrEmpty(namespaceGroup.Key); + if (!isGlobalNamespace) + { + writer.WriteLine($"namespace {namespaceGroup.Key}"); + writer.WriteLine("{"); + writer.Indent++; + } + + // Group by containing type within namespace + bool isFirstTypeInNamespace = true; + foreach (var typeGroup in namespaceGroup.GroupBy(m => m.MethodSymbol.ContainingType, SymbolEqualityComparer.Default)) + { + if (typeGroup.Key is not INamedTypeSymbol containingType) + { + continue; + } + + if (!isFirstTypeInNamespace) + { + writer.WriteLine(); + } + isFirstTypeInNamespace = false; + + // Write out the type, which could include parent types. + AppendNestedTypeDeclarations(writer, containingType, typeGroup, descriptionAttribute); + } + + if (!isGlobalNamespace) + { + writer.Indent--; + writer.WriteLine("}"); + } + } + + return sw.ToString(); + } + + private static void AppendNestedTypeDeclarations( + IndentedTextWriter writer, + INamedTypeSymbol typeSymbol, + IGrouping typeGroup, + INamedTypeSymbol descriptionAttribute) + { + // Build stack of nested types from innermost to outermost + Stack types = []; + for (var current = typeSymbol; current is not null; current = current.ContainingType) + { + types.Push(current); + } + + // Generate type declarations from outermost to innermost + int nestingCount = types.Count; + while (types.Count > 0) + { + // Get the type keyword and handle records + var type = types.Pop(); + var typeDecl = type.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as TypeDeclarationSyntax; + string typeKeyword; + if (typeDecl is RecordDeclarationSyntax rds) + { + string classOrStruct = rds.ClassOrStructKeyword.ValueText; + if (string.IsNullOrEmpty(classOrStruct)) + { + classOrStruct = "class"; + } + + typeKeyword = $"{typeDecl.Keyword.ValueText} {classOrStruct}"; + } + else + { + typeKeyword = typeDecl?.Keyword.ValueText ?? "class"; + } + + writer.WriteLine($"partial {typeKeyword} {type.Name}"); + writer.WriteLine("{"); + writer.Indent++; + } + + // Generate methods for this type. + bool firstMethodInType = true; + foreach (var (methodSymbol, methodDeclaration, xmlDocs) in typeGroup) + { + AppendMethodDeclaration(writer, methodSymbol, methodDeclaration, xmlDocs, descriptionAttribute, firstMethodInType); + firstMethodInType = false; + } + + // Close all type declarations. + for (int i = 0; i < nestingCount; i++) + { + writer.Indent--; + writer.WriteLine("}"); + } + } + + private static void AppendMethodDeclaration( + IndentedTextWriter writer, + IMethodSymbol methodSymbol, + MethodDeclarationSyntax methodDeclaration, + XmlDocumentation? xmlDocs, + INamedTypeSymbol descriptionAttribute, + bool firstMethodInType) + { + if (!firstMethodInType) + { + writer.WriteLine(); + } + + // Add the Description attribute for method if needed and documentation exists + if (xmlDocs is not null && + !string.IsNullOrWhiteSpace(xmlDocs.MethodDescription) && + !HasAttribute(methodSymbol, descriptionAttribute)) + { + writer.WriteLine($"[Description(\"{EscapeString(xmlDocs.MethodDescription)}\")]"); + } + + // Add return: Description attribute if needed and documentation exists + if (xmlDocs is not null && + !string.IsNullOrWhiteSpace(xmlDocs.Returns) && + methodSymbol.GetReturnTypeAttributes().All(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, descriptionAttribute))) + { + writer.WriteLine($"[return: Description(\"{EscapeString(xmlDocs.Returns)}\")]"); + } + + // Copy modifiers from original method syntax. + // Add return type (without nullable annotations). + // Add method name. + writer.Write(string.Join(" ", methodDeclaration.Modifiers.Select(m => m.Text))); + writer.Write(' '); + writer.Write(methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)); + writer.Write(' '); + writer.Write(methodSymbol.Name); + + // Add parameters with their Description attributes. + writer.Write("("); + for (int i = 0; i < methodSymbol.Parameters.Length; i++) + { + IParameterSymbol param = methodSymbol.Parameters[i]; + + if (i > 0) + { + writer.Write(", "); + } + + if (xmlDocs is not null && + !HasAttribute(param, descriptionAttribute) && + xmlDocs.Parameters.TryGetValue(param.Name, out var paramDoc) && + !string.IsNullOrWhiteSpace(paramDoc)) + { + writer.Write($"[Description(\"{EscapeString(paramDoc)}\")] "); + } + + writer.Write(param.Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)); + writer.Write(' '); + writer.Write(param.Name); + } + writer.WriteLine(");"); + } + + /// Checks if a symbol has a specific attribute applied. + private static bool HasAttribute(ISymbol symbol, INamedTypeSymbol attributeType) + { + foreach (var attr in symbol.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeType)) + { + return true; + } + } + + return false; + } + + /// Escape special characters for C# string literals. + private static string EscapeString(string text) => + string.IsNullOrEmpty(text) ? text : + text.Replace("\\", "\\\\") + .Replace("\"", "\\\"") + .Replace("\r", "\\r") + .Replace("\n", "\\n") + .Replace("\t", "\\t"); + + /// Checks if XML documentation would generate any Description attributes for a method. + private static bool HasGeneratableContent(XmlDocumentation xmlDocs, IMethodSymbol methodSymbol, INamedTypeSymbol descriptionAttribute) + { + // Check if method description would be generated + if (!string.IsNullOrWhiteSpace(xmlDocs.MethodDescription) && !HasAttribute(methodSymbol, descriptionAttribute)) + { + return true; + } + + // Check if return description would be generated + if (!string.IsNullOrWhiteSpace(xmlDocs.Returns) && + methodSymbol.GetReturnTypeAttributes().All(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, descriptionAttribute))) + { + return true; + } + + // Check if any parameter descriptions would be generated + foreach (var param in methodSymbol.Parameters) + { + if (!HasAttribute(param, descriptionAttribute) && + xmlDocs.Parameters.TryGetValue(param.Name, out var paramDoc) && + !string.IsNullOrWhiteSpace(paramDoc)) + { + return true; + } + } + + return false; + } + + /// Represents a method that may need Description attributes generated. + private readonly record struct MethodToGenerate(MethodDeclarationSyntax MethodDeclaration, IMethodSymbol MethodSymbol); + + /// Holds extracted XML documentation for a method. + private sealed record XmlDocumentation(string MethodDescription, string Returns, Dictionary Parameters); +} diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index d39c008e..cdbe25a2 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -43,6 +43,23 @@ + + + + + + + + + + + diff --git a/tests/ModelContextProtocol.Analyzers.Tests/ModelContextProtocol.Analyzers.Tests.csproj b/tests/ModelContextProtocol.Analyzers.Tests/ModelContextProtocol.Analyzers.Tests.csproj new file mode 100644 index 00000000..430db1b0 --- /dev/null +++ b/tests/ModelContextProtocol.Analyzers.Tests/ModelContextProtocol.Analyzers.Tests.csproj @@ -0,0 +1,36 @@ + + + + net9.0 + enable + true + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + diff --git a/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs b/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs new file mode 100644 index 00000000..2439f894 --- /dev/null +++ b/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs @@ -0,0 +1,1545 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using System.Diagnostics.CodeAnalysis; +using Xunit; + +namespace ModelContextProtocol.Analyzers.Tests; + +public partial class XmlToDescriptionGeneratorTests +{ + [Fact] + public void Generator_WithSummaryOnly_GeneratesMethodDescription() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test tool description + /// + [McpServerTool] + public static partial string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test tool description")] + public static partial string TestMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithSummaryAndRemarks_CombinesInMethodDescription() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test tool summary + /// + /// + /// Additional remarks + /// + [McpServerTool] + public static partial string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test tool summary\nAdditional remarks")] + public static partial string TestMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithParameterDocs_GeneratesParameterDescriptions() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test tool + /// + /// Input parameter description + /// Count parameter description + [McpServerTool] + public static partial string TestMethod(string input, int count) + { + return input; + } + } + """); + + Assert.True(result.Success); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test tool")] + public static partial string TestMethod([Description("Input parameter description")] string input, [Description("Count parameter description")] int count); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithReturnDocs_GeneratesReturnDescription() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test tool + /// + /// The result of the operation + [McpServerTool] + public static partial string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test tool")] + [return: Description("The result of the operation")] + public static partial string TestMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithExistingMethodDescription_DoesNotGenerateMethodDescription() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test tool summary + /// + /// Result + [McpServerTool] + [Description("Already has description")] + public static partial string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [return: Description("Result")] + public static partial string TestMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithExistingParameterDescription_SkipsThatParameter() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test tool + /// + /// Input description + /// Count description + [McpServerTool] + public static partial string TestMethod(string input, [Description("Already has")] int count) + { + return input; + } + } + """); + + Assert.True(result.Success); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test tool")] + public static partial string TestMethod([Description("Input description")] string input, int count); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithoutMcpServerToolAttribute_DoesNotGenerate() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + public partial class TestTools + { + /// + /// Test tool + /// + public static partial string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + } + + [Fact] + public void Generator_WithoutPartialKeyword_DoesNotGenerate() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// + /// Test tool + /// + [McpServerTool] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + } + + [Fact] + public void Generator_NonPartialMethodWithXmlDocs_ReportsMCP002Diagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// + /// Test tool with documentation + /// + [McpServerTool] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should report MCP002 diagnostic + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP002"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + Assert.Contains("TestMethod", diagnostic.GetMessage()); + Assert.Contains("partial", diagnostic.GetMessage()); + } + + [Fact] + public void Generator_NonPartialMethodWithParameterDocs_ReportsMCP002Diagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// Input parameter + [McpServerTool] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should report MCP002 diagnostic because parameter has documentation + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP002"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + } + + [Fact] + public void Generator_NonPartialMethodWithReturnDocs_ReportsMCP002Diagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// Return value + [McpServerTool] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should report MCP002 diagnostic because return has documentation + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP002"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + } + + [Fact] + public void Generator_NonPartialMethodWithoutXmlDocs_DoesNotReportDiagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + [McpServerTool] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should NOT report MCP002 diagnostic because there's no XML documentation + Assert.DoesNotContain(result.Diagnostics, d => d.Id == "MCP002"); + } + + [Fact] + public void Generator_NonPartialMethodWithEmptyXmlDocs_DoesNotReportDiagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// + [McpServerTool] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should NOT report MCP002 diagnostic because XML documentation is empty + Assert.DoesNotContain(result.Diagnostics, d => d.Id == "MCP002"); + } + + [Fact] + public void Generator_NonPartialMethodWithExistingDescriptions_DoesNotReportDiagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// Test tool + /// Input param + /// Return value + [McpServerTool] + [Description("Already has method description")] + [return: Description("Already has return description")] + public static string TestMethod([Description("Already has param description")] string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should NOT report MCP002 diagnostic because all descriptions already exist + Assert.DoesNotContain(result.Diagnostics, d => d.Id == "MCP002"); + } + + [Fact] + public void Generator_NonPartialMethodWithPartialExistingDescriptions_ReportsMCP002Diagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public class TestTools + { + /// Test tool + /// Input param + [McpServerTool] + [Description("Already has method description")] + public static string TestMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should report MCP002 diagnostic because parameter description would be generated + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP002"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + } + + [Fact] + public void Generator_NonPartialPromptWithXmlDocs_ReportsMCP002Diagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerPromptType] + public class TestPrompts + { + /// + /// Test prompt + /// + [McpServerPrompt] + public static string TestPrompt(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should report MCP002 diagnostic for prompts too + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP002"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + } + + [Fact] + public void Generator_NonPartialResourceWithXmlDocs_ReportsMCP002Diagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerResourceType] + public class TestResources + { + /// + /// Test resource + /// + [McpServerResource("test://resource")] + public static string TestResource(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Empty(result.GeneratedSources); + + // Should report MCP002 diagnostic for resources too + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP002"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + } + + [Fact] + public void Generator_WithSpecialCharacters_EscapesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test with "quotes", \backslash, newline + /// and tab characters. + /// + /// Parameter with "quotes" + [McpServerTool] + public static partial string TestEscaping(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test with \"quotes\", \\backslash, newline and tab characters.")] + public static partial string TestEscaping([Description("Parameter with \"quotes\"")] string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithInvalidXml_GeneratesPartialAndReportsDiagnostic() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test with + [McpServerTool] + public static partial string TestInvalidXml(string input) + { + return input; + } + } + """); + + // Should not throw, generates partial implementation without Description attributes + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + public static partial string TestInvalidXml(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + + // Should report a warning diagnostic + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == "MCP001"); + Assert.Equal(DiagnosticSeverity.Warning, diagnostic.Severity); + Assert.Contains("invalid", diagnostic.GetMessage(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public void Generator_WithGenericType_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Test generic + /// + [McpServerTool] + public static partial string TestGeneric(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Test generic")] + public static partial string TestGeneric(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithEmptyXmlComments_GeneratesPartialWithoutDescription() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// + [McpServerTool] + public static partial string TestEmpty(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + public static partial string TestEmpty(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithMultilineComments_CombinesIntoSingleLine() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// First line + /// Second line + /// Third line + /// + [McpServerTool] + public static partial string TestMultiline(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("First line Second line Third line")] + public static partial string TestMultiline(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithParametersOnly_GeneratesParameterDescriptions() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// Input parameter + /// Count parameter + [McpServerTool] + public static partial string TestMethod(string input, int count) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + public static partial string TestMethod([Description("Input parameter")] string input, [Description("Count parameter")] int count); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithNestedType_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + public partial class OuterClass + { + [McpServerToolType] + public partial class InnerClass + { + /// + /// Nested tool + /// + [McpServerTool] + public static partial string NestedMethod(string input) + { + return input; + } + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class OuterClass + { + partial class InnerClass + { + [Description("Nested tool")] + public static partial string NestedMethod(string input); + } + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithManyToolsAcrossMultipleNestedTypes_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test.Outer; + + [McpServerToolType] + public partial class RootTools + { + /// Root level tool 1 + [McpServerTool] + public static partial string RootTool1(string input) => input; + + /// Root level tool 2 + /// The count + [McpServerTool] + public static partial string RootTool2(string input, int count) => input; + } + + public partial class OuterContainer + { + [McpServerToolType] + public partial class Level2A + { + /// Level 2A tool + [McpServerTool] + public static partial string Level2ATool(string input) => input; + + public partial class Level3 + { + [McpServerToolType] + public partial class Level4 + { + /// Deep nested tool + /// The result + [McpServerTool] + public static partial string DeepTool(string input) => input; + } + } + } + + [McpServerToolType] + public partial class Level2B + { + /// Level 2B tool 1 + [McpServerTool] + public static partial string Level2BTool1(string input) => input; + + /// Level 2B tool 2 + [McpServerTool] + public static partial string Level2BTool2(string input) => input; + } + } + + namespace Test.Resources; + + [McpServerResourceType] + public partial class ResourceProviders + { + /// Test resource 1 + /// The path + [McpServerResource("test:///{path}")] + public static partial string Resource1(string path) => path; + + /// Test resource 2 + [McpServerResource("test2:///{id}")] + public static partial string Resource2(string id) => id; + } + + [McpServerPromptType] + public partial class GlobalPrompts + { + /// Global prompt + [McpServerPrompt] + public static partial string GlobalPrompt(string input) => input; + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test.Outer + { + partial class RootTools + { + [Description("Root level tool 1")] + public static partial string RootTool1(string input); + + [Description("Root level tool 2")] + public static partial string RootTool2(string input, [Description("The count")] int count); + } + + partial class OuterContainer + { + partial class Level2A + { + [Description("Level 2A tool")] + public static partial string Level2ATool(string input); + } + } + + partial class OuterContainer + { + partial class Level2A + { + partial class Level3 + { + partial class Level4 + { + [Description("Deep nested tool")] + [return: Description("The result")] + public static partial string DeepTool(string input); + } + } + } + } + + partial class OuterContainer + { + partial class Level2B + { + [Description("Level 2B tool 1")] + public static partial string Level2BTool1(string input); + + [Description("Level 2B tool 2")] + public static partial string Level2BTool2(string input); + } + } + } + + namespace Test.Outer.Test.Resources + { + partial class GlobalPrompts + { + [Description("Global prompt")] + public static partial string GlobalPrompt(string input); + } + + partial class ResourceProviders + { + [Description("Test resource 1")] + public static partial string Resource1([Description("The path")] string path); + + [Description("Test resource 2")] + public static partial string Resource2(string id); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithRecordClass_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial record TestTools + { + /// + /// Record tool + /// + [McpServerTool] + public static partial string RecordMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial record class TestTools + { + [Description("Record tool")] + public static partial string RecordMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithRecordStruct_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial record struct TestTools + { + /// + /// Record struct tool + /// + [McpServerTool] + public static partial string RecordStructMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial record struct TestTools + { + [Description("Record struct tool")] + public static partial string RecordStructMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithVirtualMethod_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// + /// Virtual tool + /// + [McpServerTool] + public virtual partial string VirtualMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Virtual tool")] + public virtual partial string VirtualMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithAbstractMethod_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerToolType] + public abstract partial class TestTools + { + /// + /// Abstract tool + /// + [McpServerTool] + public abstract partial string AbstractMethod(string input); + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestTools + { + [Description("Abstract tool")] + public abstract partial string AbstractMethod(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithMcpServerPrompt_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerPromptType] + public partial class TestPrompts + { + /// + /// Test prompt + /// + [McpServerPrompt] + public static partial string TestPrompt(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestPrompts + { + [Description("Test prompt")] + public static partial string TestPrompt(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithMcpServerResource_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + namespace Test; + + [McpServerResourceType] + public partial class TestResources + { + /// + /// Test resource + /// + [McpServerResource("test://resource")] + public static partial string TestResource(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + namespace Test + { + partial class TestResources + { + [Description("Test resource")] + public static partial string TestResource(string input); + } + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Generator_WithGlobalNamespace_GeneratesCorrectly() + { + var result = RunGenerator(""" + using ModelContextProtocol.Server; + using System.ComponentModel; + + [McpServerToolType] + public partial class GlobalTools + { + /// + /// Tool in global namespace + /// + [McpServerTool] + public static partial string GlobalMethod(string input) + { + return input; + } + } + """); + + Assert.True(result.Success); + Assert.Single(result.GeneratedSources); + + var expected = $$""" + // + // ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}} + + #pragma warning disable + + using System.ComponentModel; + using ModelContextProtocol.Server; + + partial class GlobalTools + { + [Description("Tool in global namespace")] + public static partial string GlobalMethod(string input); + } + """; + + AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString()); + } + + private GeneratorRunResult RunGenerator([StringSyntax("C#-test")] string source) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + + // Get reference assemblies - we need to include all the basic runtime types + List referenceList = + [ + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(typeof(System.ComponentModel.DescriptionAttribute).Assembly.Location), + ]; + + // Add all necessary runtime assemblies + var runtimePath = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + referenceList.Add(MetadataReference.CreateFromFile(Path.Combine(runtimePath, "System.Runtime.dll"))); + referenceList.Add(MetadataReference.CreateFromFile(Path.Combine(runtimePath, "netstandard.dll"))); + + // Try to find and add ModelContextProtocol.Core + try + { + var coreAssemblyPath = Path.Combine(AppContext.BaseDirectory, "ModelContextProtocol.Core.dll"); + if (File.Exists(coreAssemblyPath)) + { + referenceList.Add(MetadataReference.CreateFromFile(coreAssemblyPath)); + } + } + catch + { + // If we can't find it, the compilation will fail with appropriate errors + } + + var compilation = CSharpCompilation.Create( + "TestAssembly", + [syntaxTree], + referenceList, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var driver = (CSharpGeneratorDriver)CSharpGeneratorDriver + .Create(new XmlToDescriptionGenerator()) + .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics); + + var runResult = driver.GetRunResult(); + + return new GeneratorRunResult + { + Success = !diagnostics.Any(d => d.Severity == DiagnosticSeverity.Error), + GeneratedSources = runResult.GeneratedTrees.Select(t => (t.FilePath, t.GetText())).ToList(), + Diagnostics = diagnostics.ToList(), + Compilation = outputCompilation + }; + } + + private static void AssertGeneratedSourceEquals( + [StringSyntax("C#-test")] string expected, + [StringSyntax("C#-test")] string actual) + { + // Normalize line endings to \n, remove trailing whitespace from each line, and trim the end + static string Normalize(string s) + { + var lines = s.Replace("\r\n", "\n").Replace("\r", "\n").Split('\n'); + for (int i = 0; i < lines.Length; i++) + { + lines[i] = lines[i].TrimEnd(); + } + return string.Join('\n', lines).TrimEnd(); + } + + var normalizedExpected = Normalize(expected); + var normalizedActual = Normalize(actual); + + Assert.Equal(normalizedExpected, normalizedActual); + } + + private class GeneratorRunResult + { + public bool Success { get; set; } + public List<(string FilePath, Microsoft.CodeAnalysis.Text.SourceText SourceText)> GeneratedSources { get; set; } = []; + public List Diagnostics { get; set; } = []; + public Compilation? Compilation { get; set; } + } +}