diff --git a/src/DefaultBuilder/src/PublicAPI.Unshipped.txt b/src/DefaultBuilder/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..151a0f960cd2 100644 --- a/src/DefaultBuilder/src/PublicAPI.Unshipped.txt +++ b/src/DefaultBuilder/src/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ #nullable enable +Microsoft.AspNetCore.Builder.WebApplication.EndpointConventions.get -> Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! diff --git a/src/DefaultBuilder/src/WebApplication.cs b/src/DefaultBuilder/src/WebApplication.cs index ed700260a0e4..d90c8e3bb127 100644 --- a/src/DefaultBuilder/src/WebApplication.cs +++ b/src/DefaultBuilder/src/WebApplication.cs @@ -25,17 +25,24 @@ namespace Microsoft.AspNetCore.Builder; public sealed class WebApplication : IHost, IApplicationBuilder, IEndpointRouteBuilder, IAsyncDisposable { internal const string GlobalEndpointRouteBuilderKey = "__GlobalEndpointRouteBuilder"; + internal const string GlobalRouteGroupBuilderKey = "__GlobalRouteGroupBuilder"; private readonly IHost _host; - private readonly List _dataSources = new(); + private readonly GlobalEndpointRouteBuilder _globalEndpointRouteBuilder; + private readonly RouteGroupBuilder _globalRouteGroupBuilder; internal WebApplication(IHost host) { _host = host; + + _globalEndpointRouteBuilder = new(this); + _globalRouteGroupBuilder = _globalEndpointRouteBuilder.MapGroup(string.Empty); + ApplicationBuilder = new ApplicationBuilder(host.Services, ServerFeatures); Logger = host.Services.GetRequiredService().CreateLogger(Environment.ApplicationName ?? nameof(WebApplication)); - Properties[GlobalEndpointRouteBuilderKey] = this; + Properties[GlobalEndpointRouteBuilderKey] = _globalEndpointRouteBuilder; + Properties[GlobalRouteGroupBuilderKey] = _globalRouteGroupBuilder; } /// @@ -80,9 +87,14 @@ IServiceProvider IApplicationBuilder.ApplicationServices internal IDictionary Properties => ApplicationBuilder.Properties; IDictionary IApplicationBuilder.Properties => Properties; - internal ICollection DataSources => _dataSources; + internal ICollection DataSources => ((IEndpointRouteBuilder)_globalRouteGroupBuilder).DataSources; ICollection IEndpointRouteBuilder.DataSources => DataSources; + /// + /// Gets the for the application. + /// + public IEndpointConventionBuilder EndpointConventions => _globalRouteGroupBuilder; + internal ApplicationBuilder ApplicationBuilder { get; } IServiceProvider IEndpointRouteBuilder.ServiceProvider => Services; @@ -213,6 +225,7 @@ IApplicationBuilder IApplicationBuilder.New() var newBuilder = ApplicationBuilder.New(); // Remove the route builder so branched pipelines have their own routing world newBuilder.Properties.Remove(GlobalEndpointRouteBuilderKey); + newBuilder.Properties.Remove(GlobalRouteGroupBuilderKey); return newBuilder; } @@ -307,4 +320,13 @@ public IList? Middleware } } } + + private sealed class GlobalEndpointRouteBuilder(WebApplication application) : IEndpointRouteBuilder + { + public IServiceProvider ServiceProvider => application.Services; + + public ICollection DataSources { get; } = []; + + public IApplicationBuilder CreateApplicationBuilder() => ((IApplicationBuilder)application).New(); + } } diff --git a/src/DefaultBuilder/src/WebApplicationBuilder.cs b/src/DefaultBuilder/src/WebApplicationBuilder.cs index 2cff4ae4ffb9..20d7080b6880 100644 --- a/src/DefaultBuilder/src/WebApplicationBuilder.cs +++ b/src/DefaultBuilder/src/WebApplicationBuilder.cs @@ -407,8 +407,11 @@ private void ConfigureApplication(WebHostBuilderContext context, IApplicationBui // destination.Run(source) // destination.UseEndpoints() - // Set the route builder so that UseRouting will use the WebApplication as the IEndpointRouteBuilder for route matching - app.Properties.Add(WebApplication.GlobalEndpointRouteBuilderKey, _builtApplication); + // Set the route builder so that UseRouting will use the RouteGroupBuilder managed + // by the WebApplication as the IEndpointRouteBuilder instance + var globalRouteGroupBuilder = _builtApplication.Properties[WebApplication.GlobalEndpointRouteBuilderKey]; + app.Properties.Add(WebApplication.GlobalEndpointRouteBuilderKey, globalRouteGroupBuilder); + app.Properties.Add(WebApplication.GlobalRouteGroupBuilderKey, _builtApplication.Properties[WebApplication.GlobalRouteGroupBuilderKey]); // Only call UseRouting() if there are endpoints configured and UseRouting() wasn't called on the global route builder already if (_builtApplication.DataSources.Count > 0) diff --git a/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/Microsoft.AspNetCore.Tests.csproj b/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/Microsoft.AspNetCore.Tests.csproj index 2bb59ab8befe..fd0e2842fa3a 100644 --- a/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/Microsoft.AspNetCore.Tests.csproj +++ b/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/Microsoft.AspNetCore.Tests.csproj @@ -6,6 +6,10 @@ + + + + diff --git a/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/WebApplicationGlobalConventionTests.cs b/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/WebApplicationGlobalConventionTests.cs new file mode 100644 index 000000000000..7ad512e83641 --- /dev/null +++ b/src/DefaultBuilder/test/Microsoft.AspNetCore.Tests/WebApplicationGlobalConventionTests.cs @@ -0,0 +1,341 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ApplicationParts; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.FileProviders; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Tests; + +public class WebApplicationGlobalConventionTests +{ + [Fact] + public async Task SupportsApplyingConventionsOnAllEndpoints() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.EndpointConventions.WithMetadata(new EndpointGroupNameAttribute("global")); + + app.MapGet("/1", () => "Hello, world!").WithName("One"); + app.MapGet("/2", () => "Hello, world!").WithName("Two"); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.Collection(endpointDataSource.Endpoints, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + var nameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + Assert.Equal("One", nameMetadata.EndpointName); + }, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + var nameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + Assert.Equal("Two", nameMetadata.EndpointName); + } + ); + + await app.StopAsync(); + } + + [Fact] + public async Task LocalConventionsOverrideGlobalConventions() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.EndpointConventions.WithMetadata(new EndpointGroupNameAttribute("one")); + + var group = app.MapGroup("/hello") + .WithMetadata(new EndpointGroupNameAttribute("two")); + + group.MapGet("/", () => "Hello world!") + .WithMetadata(new EndpointGroupNameAttribute("three")); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.Collection(endpointDataSource.Endpoints, + endpoint => + { + var metadata = endpoint.Metadata.OfType(); + Assert.Collection(metadata, + metadata => Assert.Equal("one", metadata.EndpointGroupName), + metadata => Assert.Equal("two", metadata.EndpointGroupName), + metadata => Assert.Equal("three", metadata.EndpointGroupName)); + var targetMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("three", targetMetadata.EndpointGroupName); + } + ); + + await app.StopAsync(); + } + + [Fact] + public async Task CanAccessCorrectBuilderInConvention() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.EndpointConventions.Add(builder => + { + if (builder is RouteEndpointBuilder { RoutePattern.RawText: "/1" }) + { + builder.Metadata.Add(new EndpointGroupNameAttribute("global")); + } + }); + + app.MapGet("/1", () => "One"); + app.MapGet("/2", () => " Two"); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.Collection(endpointDataSource.Endpoints, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + }, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Null(groupNameMetadata); + + } + ); + + await app.StopAsync(); + } + + [Fact] + public async Task CanAccessCorrectServiceProviderInConvention() + { + IServiceProvider globalConventionServiceProvider = null; + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.EndpointConventions.Add(builder => + { + globalConventionServiceProvider = builder.ApplicationServices; + }); + + app.MapGet("/", () => "Hello, world!"); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.NotEmpty(endpointDataSource.Endpoints); + Assert.Equal(app.Services, globalConventionServiceProvider); + + await app.StopAsync(); + } + + [Fact] + public async Task BranchedPipelinesSupportGlobalConventions() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.EndpointConventions.WithMetadata(new EndpointGroupNameAttribute("global")); + + app.UseRouting(); + + app.MapGet("/1", () => "Hello, world!").WithName("One"); + + app.UseEndpoints(e => + { + e.MapGet("/2", () => "Hello, world!").WithName("Two"); + }); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.Collection(endpointDataSource.Endpoints, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + var nameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + Assert.Equal("One", nameMetadata.EndpointName); + }, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + var nameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + Assert.Equal("Two", nameMetadata.EndpointName); + } + ); + + await app.StopAsync(); + } + + [Fact] + public async Task SupportsGlobalConventionsOnRouteEndpoints() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + + builder.Services.AddSignalR(); + builder.Services.AddRazorComponents(); + builder.Services.AddControllers() + .ConfigureApplicationPartManager(apm => + { + apm.FeatureProviders.Clear(); + apm.FeatureProviders.Add(new TestControllerFeatureProvider()); + }); + + var app = builder.Build(); + + app.EndpointConventions.WithMetadata(new EndpointGroupNameAttribute("global")); + + app.MapGet("/", () => "Hello, world!"); + app.MapHub("/test-hub"); + app.MapRazorComponents(); + app.MapControllers(); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.Collection(endpointDataSource.Endpoints, + // Route handler endpoints + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + }, + // SignalR produces two endpoints per hub + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + }, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + }, + // Razor component endpoint + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + }, + // MapController endpoint + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + }, + // Controller-based endpoints + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Equal("global", groupNameMetadata.EndpointGroupName); + } + ); + + await app.StopAsync(); + } + + [Fact] + public async Task DoesNotThrowExceptionOnNonRouteEndpointsAtTopLevelWhenConventionNotUsed() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.DataSources.Add(new CustomEndpointDataSource()); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + Assert.Collection(endpointDataSource.Endpoints, + endpoint => + { + var groupNameMetadata = endpoint.Metadata.GetMetadata(); + Assert.Null(groupNameMetadata); + } + ); + + await app.StopAsync(); + } + + [Fact] + public async Task ThrowsExceptionOnNonRouteEndpointsAtTopLevelWhenConventionUsed() + { + var builder = WebApplication.CreateBuilder(); + builder.WebHost.UseTestServer(); + var app = builder.Build(); + + app.EndpointConventions.WithMetadata(new EndpointGroupNameAttribute("global")); + + app.DataSources.Add(new CustomEndpointDataSource()); + + await app.StartAsync(); + + var endpointDataSource = app.Services.GetRequiredService(); + var ex = Assert.Throws(() => endpointDataSource.Endpoints); + Assert.Equal("MapGroup does not support custom Endpoint type 'Microsoft.AspNetCore.Tests.WebApplicationGlobalConventionTests+TestCustomEndpoint'. Only RouteEndpoints can be grouped.", ex.Message); + + await app.StopAsync(); + } + + private class TestHub : Hub { } + private class TestComponent { } + + private class TestController : Controller + { + [HttpGet("/")] + public void Index() { } + } + + [ApiController] + private class MyApiController : ControllerBase + { + [HttpGet("other")] + public void Index() { } + } + + private class TestControllerFeatureProvider : IApplicationFeatureProvider + { + public void PopulateFeature(IEnumerable parts, ControllerFeature feature) + { + feature.Controllers.Clear(); + feature.Controllers.Add(typeof(TestController).GetTypeInfo()); + feature.Controllers.Add(typeof(MyApiController).GetTypeInfo()); + } + } + + private sealed class TestCustomEndpoint : Endpoint + { + public TestCustomEndpoint() : base(null, null, null) { } + } + + private sealed class CustomEndpointDataSource : EndpointDataSource + { + public override IReadOnlyList Endpoints => [new TestCustomEndpoint()]; + public override IChangeToken GetChangeToken() => NullChangeToken.Singleton; + } +} diff --git a/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs index f8fbd3864f89..c3a546b6f684 100644 --- a/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs @@ -15,6 +15,7 @@ public static class EndpointRoutingApplicationBuilderExtensions { private const string EndpointRouteBuilder = "__EndpointRouteBuilder"; private const string GlobalEndpointRouteBuilderKey = "__GlobalEndpointRouteBuilder"; + private const string GlobalRouteGroupBuilderKey = "__GlobalRouteGroupBuilder"; private const string UseRoutingKey = "__UseRouting"; /// @@ -92,9 +93,13 @@ public static IApplicationBuilder UseEndpoints(this IApplicationBuilder builder, VerifyRoutingServicesAreRegistered(builder); - VerifyEndpointRoutingMiddlewareIsRegistered(builder, out var endpointRouteBuilder); + // When registering endpoints, we want to target the underlying RouteGroupBuilder + // that is created by the WebApplication for targeting map action calls. However, we want + // to target the GlobalEndpointRouteBuilder which contains the GroupedEndpointDataSource + // for any conventions and EndpointDataSource resolution. + VerifyEndpointRoutingMiddlewareIsRegistered(builder, out var routeGroupBuilder, out var endpointRouteBuilder); - configure(endpointRouteBuilder); + configure(routeGroupBuilder); // Yes, this mutates an IOptions. We're registering data sources in a global collection which // can be used for discovery of endpoints or URL generation. @@ -126,7 +131,7 @@ private static void VerifyRoutingServicesAreRegistered(IApplicationBuilder app) } } - private static void VerifyEndpointRoutingMiddlewareIsRegistered(IApplicationBuilder app, out IEndpointRouteBuilder endpointRouteBuilder) + private static void VerifyEndpointRoutingMiddlewareIsRegistered(IApplicationBuilder app, out IEndpointRouteBuilder routeGroupBuilder, out IEndpointRouteBuilder endpointRouteBuilder) { if (!app.Properties.TryGetValue(EndpointRouteBuilder, out var obj)) { @@ -138,6 +143,10 @@ private static void VerifyEndpointRoutingMiddlewareIsRegistered(IApplicationBuil throw new InvalidOperationException(message); } + routeGroupBuilder = app.Properties.TryGetValue(GlobalRouteGroupBuilderKey, out var globalRouteGroupBuilder) + ? (IEndpointRouteBuilder)globalRouteGroupBuilder! + : (IEndpointRouteBuilder)obj!; + endpointRouteBuilder = (IEndpointRouteBuilder)obj!; // This check handles the case where Map or something else that forks the pipeline is called between the two diff --git a/src/Http/Routing/src/EndpointDataSource.cs b/src/Http/Routing/src/EndpointDataSource.cs index 83b65d054a52..14a46268a334 100644 --- a/src/Http/Routing/src/EndpointDataSource.cs +++ b/src/Http/Routing/src/EndpointDataSource.cs @@ -36,12 +36,23 @@ public virtual IReadOnlyList GetGroupedEndpoints(RouteGroupContext con { // Only evaluate Endpoints once per call. var endpoints = Endpoints; - var wrappedEndpoints = new RouteEndpoint[endpoints.Count]; + var wrappedEndpoints = new Endpoint[endpoints.Count]; for (int i = 0; i < endpoints.Count; i++) { var endpoint = endpoints[i]; + if (context.Conventions.Count == 0 + && context.FinallyConventions.Count == 0 + && endpoint is not RouteEndpoint) + { + // No conventions to apply, so just return the endpoints as-is. This supports + // scenarios where the endpoint is registered as part of the global route group + // handler on the WebApplication. + wrappedEndpoints[i] = endpoint; + continue; + } + // Endpoint does not provide a RoutePattern but RouteEndpoint does. So it's impossible to apply a prefix for custom Endpoints. // Supporting arbitrary Endpoints just to add group metadata would require changing the Endpoint type breaking any real scenario. if (endpoint is not RouteEndpoint routeEndpoint) diff --git a/src/Http/Routing/test/UnitTests/Builder/GroupTest.cs b/src/Http/Routing/test/UnitTests/Builder/GroupTest.cs index 39c0539f3b4d..7609173538b5 100644 --- a/src/Http/Routing/test/UnitTests/Builder/GroupTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/GroupTest.cs @@ -298,12 +298,13 @@ public async Task ChangingMostEndpointBuilderPropertiesInConvention_Works() } [Fact] - public void GivenNonRouteEndpoint_ThrowsNotSupportedException() + public void GivenNonRouteEndpoint_WithConventions_ThrowsNotSupportedException() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(EmptyServiceProvider.Instance)); var group = builder.MapGroup("/group"); - ((IEndpointRouteBuilder)group).DataSources.Add(new TestCustomEndpintDataSource()); + group.WithMetadata(new EndpointGroupNameAttribute("group")); + ((IEndpointRouteBuilder)group).DataSources.Add(new TestCustomEndpointDataSource()); var dataSource = GetEndpointDataSource(builder); var ex = Assert.Throws(() => dataSource.Endpoints); @@ -313,6 +314,20 @@ public void GivenNonRouteEndpoint_ThrowsNotSupportedException() ex.Message); } + [Fact] + public void GivenNonRouteEndpoint_WithNoConventions_ReturnsEndpointAsIs() + { + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(EmptyServiceProvider.Instance)); + + var group = builder.MapGroup("/group"); + ((IEndpointRouteBuilder)group).DataSources.Add(new TestCustomEndpointDataSource()); + + var dataSource = GetEndpointDataSource(builder); + var endpoint = Assert.Single(dataSource.Endpoints); + + Assert.IsType(endpoint); + } + [Fact] public void OuterGroupMetadata_AddedFirst() { @@ -388,7 +403,7 @@ private sealed class TestCustomEndpoint : Endpoint public TestCustomEndpoint() : base(null, null, null) { } } - private sealed class TestCustomEndpintDataSource : EndpointDataSource + private sealed class TestCustomEndpointDataSource : EndpointDataSource { public override IReadOnlyList Endpoints => new[] { new TestCustomEndpoint() }; public override IChangeToken GetChangeToken() => throw new NotImplementedException();