diff --git a/AspNetCore.slnx b/AspNetCore.slnx index cf6f14abe556..03980d7b60a4 100644 --- a/AspNetCore.slnx +++ b/AspNetCore.slnx @@ -19,6 +19,9 @@ + + + diff --git a/src/Antiforgery/Antiforgery.slnf b/src/Antiforgery/Antiforgery.slnf index 7ce6a6ecf3dd..03f3fd7f5bc3 100644 --- a/src/Antiforgery/Antiforgery.slnf +++ b/src/Antiforgery/Antiforgery.slnf @@ -1,10 +1,11 @@ -{ +{ "solution": { "path": "..\\..\\AspNetCore.slnx", "projects": [ + "src\\Antiforgery\\benchmarks\\Microsoft.AspNetCore.Antiforgery.Benchmarks\\Microsoft.AspNetCore.Antiforgery.Benchmarks.csproj", "src\\Antiforgery\\samples\\MinimalFormSample\\MinimalFormSample.csproj", "src\\Antiforgery\\src\\Microsoft.AspNetCore.Antiforgery.csproj", "src\\Antiforgery\\test\\Microsoft.AspNetCore.Antiforgery.Test.csproj" ] } -} +} \ No newline at end of file diff --git a/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryBenchmarks.cs b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryBenchmarks.cs new file mode 100644 index 000000000000..411ac3c51480 --- /dev/null +++ b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryBenchmarks.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Security.Claims; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Antiforgery.Benchmarks.Benchmarks; + +[AspNetCoreBenchmark] +public class AntiforgeryBenchmarks +{ + private IServiceProvider _serviceProvider = null!; + private IAntiforgery _antiforgery = null!; + private string _cookieName = null!; + private string _formFieldName = null!; + + // Reusable contexts - reset between iterations instead of recreating + private DefaultHttpContext _getAndStoreTokensContext = null!; + private DefaultHttpContext _validateRequestContext = null!; + private TestHttpResponseFeature _getAndStoreTokensResponseFeature = null!; + + // Pre-generated tokens for validation benchmark + private string _cookieToken = null!; + private string _requestToken = null!; + + // Pre-allocated form collection for validation benchmark + private FormCollection _validationFormCollection = null!; + + [GlobalSetup] + public void Setup() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddAntiforgery(); + serviceCollection.AddLogging(); + _serviceProvider = serviceCollection.BuildServiceProvider(); + + _antiforgery = _serviceProvider.GetRequiredService(); + + // Get the actual cookie and form field names from options + var options = _serviceProvider.GetRequiredService>().Value; + _cookieName = options.Cookie.Name!; + _formFieldName = options.FormFieldName; + + // Create reusable context for GetAndStoreTokens + _getAndStoreTokensResponseFeature = new TestHttpResponseFeature(); + _getAndStoreTokensContext = CreateHttpContext(_getAndStoreTokensResponseFeature); + + // Generate tokens for validation benchmark + var tokenContext = CreateHttpContext(new TestHttpResponseFeature()); + var tokenSet = _antiforgery.GetAndStoreTokens(tokenContext); + _cookieToken = tokenSet.CookieToken!; + _requestToken = tokenSet.RequestToken!; + + // Pre-allocate form collection for validation + _validationFormCollection = new FormCollection(new Dictionary + { + { _formFieldName, _requestToken } + }); + + // Create reusable context for ValidateRequestAsync + _validateRequestContext = CreateHttpContextWithTokens(); + } + + [IterationSetup(Target = nameof(GetAndStoreTokens))] + public void SetupGetAndStoreTokens() + { + // Reset the context instead of creating a new one + ResetHttpContextForGetAndStoreTokens(); + } + + [IterationSetup(Target = nameof(ValidateRequestAsync))] + public void SetupValidateRequest() + { + // Reset the context instead of creating a new one + ResetHttpContextForValidation(); + } + + [Benchmark] + public AntiforgeryTokenSet GetAndStoreTokens() + { + return _antiforgery.GetAndStoreTokens(_getAndStoreTokensContext); + } + + [Benchmark] + public Task ValidateRequestAsync() + { + return _antiforgery.ValidateRequestAsync(_validateRequestContext); + } + + private DefaultHttpContext CreateHttpContext(TestHttpResponseFeature responseFeature) + { + var context = new DefaultHttpContext(); + context.RequestServices = _serviceProvider; + + // Create an authenticated identity with a Name claim (required by antiforgery) + var identity = new ClaimsIdentity( + [new Claim(ClaimsIdentity.DefaultNameClaimType, "testuser@example.com")], + "TestAuth"); + context.User = new ClaimsPrincipal(identity); + + context.Request.Method = "POST"; + context.Request.ContentType = "application/x-www-form-urlencoded"; + + // Setup response features to allow cookie writing + context.Features.Set(responseFeature); + context.Features.Set(new StreamResponseBodyFeature(Stream.Null)); + + return context; + } + + private DefaultHttpContext CreateHttpContextWithTokens() + { + var context = new DefaultHttpContext(); + context.RequestServices = _serviceProvider; + + // Create an authenticated identity with a Name claim (required by antiforgery) + var identity = new ClaimsIdentity( + [new Claim(ClaimsIdentity.DefaultNameClaimType, "testuser@example.com")], + "TestAuth"); + context.User = new ClaimsPrincipal(identity); + + context.Request.Method = "POST"; + context.Request.ContentType = "application/x-www-form-urlencoded"; + + // Set the cookie token using the actual cookie name from options + context.Request.Headers.Cookie = $"{_cookieName}={_cookieToken}"; + + // Set the request token in form using the pre-allocated form collection + context.Request.Form = _validationFormCollection; + + return context; + } + + private void ResetHttpContextForGetAndStoreTokens() + { + // Clear the antiforgery feature so it generates fresh tokens + _getAndStoreTokensContext.Features.Set(null); + + // Reset response headers that antiforgery sets + _getAndStoreTokensResponseFeature.Headers.Clear(); + } + + private void ResetHttpContextForValidation() + { + // Clear the antiforgery feature so it deserializes tokens fresh + _validateRequestContext.Features.Set(null); + } + + private sealed class TestHttpResponseFeature : IHttpResponseFeature + { + public int StatusCode { get; set; } = 200; + public string? ReasonPhrase { get; set; } + public IHeaderDictionary Headers { get; set; } = new HeaderDictionary(); + public Stream Body { get; set; } = Stream.Null; + public bool HasStarted => false; + + public void OnStarting(Func callback, object state) { } + public void OnCompleted(Func callback, object state) { } + } +} diff --git a/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryTokenGeneratorBenchmarks.cs b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryTokenGeneratorBenchmarks.cs new file mode 100644 index 000000000000..70d6edb70526 --- /dev/null +++ b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryTokenGeneratorBenchmarks.cs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Security.Claims; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Antiforgery.Benchmarks.Benchmarks; + +[AspNetCoreBenchmark] +public class AntiforgeryTokenGeneratorBenchmarks +{ + private IAntiforgeryTokenGenerator _tokenGenerator = null!; + + // Anonymous user scenario + private HttpContext _anonymousHttpContext = null!; + private AntiforgeryToken _anonymousCookieToken = null!; + private AntiforgeryToken _anonymousRequestToken = null!; + + // Authenticated user with username scenario + private HttpContext _authenticatedHttpContext = null!; + private AntiforgeryToken _authenticatedCookieToken = null!; + private AntiforgeryToken _authenticatedRequestToken = null!; + + // Claims-based user scenario + private HttpContext _claimsHttpContext = null!; + private AntiforgeryToken _claimsCookieToken = null!; + private AntiforgeryToken _claimsRequestToken = null!; + + [GlobalSetup] + public void Setup() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddAntiforgery(); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + _tokenGenerator = serviceProvider.GetRequiredService(); + + // Setup anonymous user scenario + _anonymousHttpContext = new DefaultHttpContext(); + _anonymousHttpContext.User = new ClaimsPrincipal(new ClaimsIdentity()); + + _anonymousCookieToken = new AntiforgeryToken { IsCookieToken = true }; + _anonymousRequestToken = new AntiforgeryToken + { + IsCookieToken = false, + SecurityToken = _anonymousCookieToken.SecurityToken, + Username = string.Empty + }; + + // Setup authenticated user with username scenario + _authenticatedHttpContext = new DefaultHttpContext(); + var authenticatedIdentity = new ClaimsIdentity( + [new Claim(ClaimsIdentity.DefaultNameClaimType, "testuser@example.com")], + "TestAuthentication"); + _authenticatedHttpContext.User = new ClaimsPrincipal(authenticatedIdentity); + + _authenticatedCookieToken = new AntiforgeryToken { IsCookieToken = true }; + _authenticatedRequestToken = new AntiforgeryToken + { + IsCookieToken = false, + SecurityToken = _authenticatedCookieToken.SecurityToken, + Username = "testuser@example.com" + }; + + // Setup claims-based user scenario + _claimsHttpContext = new DefaultHttpContext(); + var claimsIdentity = new ClaimsIdentity( + [ + new Claim(ClaimsIdentity.DefaultNameClaimType, "claimsuser@example.com"), + new Claim("sub", "user-id-12345"), + new Claim(ClaimTypes.NameIdentifier, "unique-id") + ], + "ClaimsAuthentication"); + _claimsHttpContext.User = new ClaimsPrincipal(claimsIdentity); + + _claimsCookieToken = new AntiforgeryToken { IsCookieToken = true }; + + // For claims-based users, we need to extract the ClaimUid + var claimUid = new byte[32]; + _ = new DefaultClaimUidExtractor().TryExtractClaimUidBytes(_claimsHttpContext.User, claimUid); + _claimsRequestToken = new AntiforgeryToken + { + IsCookieToken = false, + SecurityToken = _claimsCookieToken.SecurityToken, + ClaimUid = claimUid is not null ? new BinaryBlob(256, claimUid) : null + }; + } + + [Benchmark] + public object GenerateRequestToken_Anonymous() + { + return _tokenGenerator.GenerateRequestToken(_anonymousHttpContext, _anonymousCookieToken); + } + + [Benchmark] + public object GenerateRequestToken_Authenticated() + { + return _tokenGenerator.GenerateRequestToken(_authenticatedHttpContext, _authenticatedCookieToken); + } + + [Benchmark] + public bool TryValidateTokenSet_Anonymous() + { + return _tokenGenerator.TryValidateTokenSet( + _anonymousHttpContext, + _anonymousCookieToken, + _anonymousRequestToken, + out _); + } + + [Benchmark] + public bool TryValidateTokenSet_Authenticated() + { + return _tokenGenerator.TryValidateTokenSet( + _authenticatedHttpContext, + _authenticatedCookieToken, + _authenticatedRequestToken, + out _); + } + + [Benchmark] + public bool TryValidateTokenSet_ClaimsBased() + { + return _tokenGenerator.TryValidateTokenSet( + _claimsHttpContext, + _claimsCookieToken, + _claimsRequestToken, + out _); + } +} diff --git a/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryTokenSerializerBenchmarks.cs b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryTokenSerializerBenchmarks.cs new file mode 100644 index 000000000000..3ea2e5278f88 --- /dev/null +++ b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Benchmarks/AntiforgeryTokenSerializerBenchmarks.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using BenchmarkDotNet.Attributes; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Antiforgery.Benchmarks.Benchmarks; + +[AspNetCoreBenchmark] +public class AntiforgeryTokenSerializerBenchmarks +{ +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + private IAntiforgeryTokenSerializer _tokenSerializer; + + private AntiforgeryToken _token; + private string _serializedToken; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + + [GlobalSetup] + public void Setup() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddAntiforgery(); + var serviceProvider = serviceCollection.BuildServiceProvider(); + _tokenSerializer = serviceProvider.GetRequiredService(); + + _token = new AntiforgeryToken() + { + IsCookieToken = false, + Username = "user@test.com", + ClaimUid = new BinaryBlob(AntiforgeryToken.ClaimUidBitLength), + AdditionalData = "additional-data-here" + }; + + _serializedToken = _tokenSerializer.Serialize(_token); + } + + [Benchmark] + public string Serialize() + { + return _tokenSerializer.Serialize(_token); + } + + [Benchmark] + public object Deserialize() + { + return _tokenSerializer.Deserialize(_serializedToken); + } +} diff --git a/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks.csproj b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks.csproj new file mode 100644 index 000000000000..72cdf7ac580d --- /dev/null +++ b/src/Antiforgery/benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks/Microsoft.AspNetCore.Antiforgery.Benchmarks.csproj @@ -0,0 +1,27 @@ + + + + $(DefaultNetCoreTargetFramework) + Exe + true + true + false + $(DefineConstants);IS_BENCHMARKS + true + false + enable + + + + + + + + + + + + + + + diff --git a/src/Antiforgery/src/AntiforgeryServiceCollectionExtensions.cs b/src/Antiforgery/src/AntiforgeryServiceCollectionExtensions.cs index d40bf9354798..d2cd8332f2ac 100644 --- a/src/Antiforgery/src/AntiforgeryServiceCollectionExtensions.cs +++ b/src/Antiforgery/src/AntiforgeryServiceCollectionExtensions.cs @@ -3,7 +3,6 @@ using Microsoft.AspNetCore.Antiforgery; using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.Extensions.ObjectPool; using Microsoft.Extensions.Options; namespace Microsoft.Extensions.DependencyInjection; @@ -34,14 +33,6 @@ public static IServiceCollection AddAntiforgery(this IServiceCollection services services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); - services.TryAddSingleton(); - - services.TryAddSingleton>(serviceProvider => - { - var provider = serviceProvider.GetRequiredService(); - var policy = new AntiforgerySerializationContextPooledObjectPolicy(); - return provider.Create(policy); - }); return services; } diff --git a/src/Antiforgery/src/Internal/AntiforgerySerializationContext.cs b/src/Antiforgery/src/Internal/AntiforgerySerializationContext.cs deleted file mode 100644 index a4e416155a5b..000000000000 --- a/src/Antiforgery/src/Internal/AntiforgerySerializationContext.cs +++ /dev/null @@ -1,120 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Text; - -namespace Microsoft.AspNetCore.Antiforgery; - -internal sealed class AntiforgerySerializationContext -{ - // Avoid allocating 256 bytes (the default) and using 18 (the AntiforgeryToken minimum). 64 bytes is enough for - // a short username or claim UID and some additional data. MemoryStream bumps capacity to 256 if exceeded. - private const int InitialStreamSize = 64; - - // Don't let the MemoryStream grow beyond 1 MB. - private const int MaximumStreamSize = 0x100000; - - // Start _chars off with length 256 (18 bytes is protected into 116 bytes then encoded into 156 characters). - // Double length from there if necessary. - private const int InitialCharsLength = 256; - - // Don't let _chars grow beyond 512k characters. - private const int MaximumCharsLength = 0x80000; - - private char[]? _chars; - private MemoryStream? _stream; - private BinaryReader? _reader; - private BinaryWriter? _writer; - - public MemoryStream Stream - { - get - { - if (_stream == null) - { - _stream = new MemoryStream(InitialStreamSize); - } - - return _stream; - } - private set - { - _stream = value; - } - } - - public BinaryReader Reader - { - get - { - if (_reader == null) - { - // Leave open to clean up correctly even if only one of the reader or writer has been created. - _reader = new BinaryReader(Stream, Encoding.UTF8, leaveOpen: true); - } - - return _reader; - } - private set - { - _reader = value; - } - } - - public BinaryWriter Writer - { - get - { - if (_writer == null) - { - // Leave open to clean up correctly even if only one of the reader or writer has been created. - _writer = new BinaryWriter(Stream, Encoding.UTF8, leaveOpen: true); - } - - return _writer; - } - private set - { - _writer = value; - } - } - - public char[] GetChars(int count) - { - if (_chars == null || _chars.Length < count) - { - var newLength = _chars == null ? InitialCharsLength : checked(_chars.Length * 2); - while (newLength < count) - { - newLength = checked(newLength * 2); - } - - _chars = new char[newLength]; - } - - return _chars; - } - - public void Reset() - { - if (_chars != null && _chars.Length > MaximumCharsLength) - { - _chars = null; - } - - if (_stream != null) - { - if (Stream.Capacity > MaximumStreamSize) - { - _stream = null; - _reader = null; - _writer = null; - } - else - { - Stream.Position = 0L; - Stream.SetLength(0L); - } - } - } -} diff --git a/src/Antiforgery/src/Internal/AntiforgerySerializationContextPooledObjectPolicy.cs b/src/Antiforgery/src/Internal/AntiforgerySerializationContextPooledObjectPolicy.cs deleted file mode 100644 index dad41c367754..000000000000 --- a/src/Antiforgery/src/Internal/AntiforgerySerializationContextPooledObjectPolicy.cs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.Extensions.ObjectPool; - -namespace Microsoft.AspNetCore.Antiforgery; - -internal sealed class AntiforgerySerializationContextPooledObjectPolicy : IPooledObjectPolicy -{ - public AntiforgerySerializationContext Create() - { - return new AntiforgerySerializationContext(); - } - - public bool Return(AntiforgerySerializationContext obj) - { - obj.Reset(); - - return true; - } -} diff --git a/src/Antiforgery/src/Internal/AntiforgeryToken.cs b/src/Antiforgery/src/Internal/AntiforgeryToken.cs index a7fc0072c548..5e530fe7e9ec 100644 --- a/src/Antiforgery/src/Internal/AntiforgeryToken.cs +++ b/src/Antiforgery/src/Internal/AntiforgeryToken.cs @@ -25,14 +25,11 @@ public string AdditionalData public bool IsCookieToken { get; set; } - public BinaryBlob? SecurityToken + public BinaryBlob SecurityToken { get { - if (_securityToken == null) - { - _securityToken = new BinaryBlob(SecurityTokenBitLength); - } + _securityToken ??= new BinaryBlob(SecurityTokenBitLength); return _securityToken; } set diff --git a/src/Antiforgery/src/Internal/BinaryBlob.cs b/src/Antiforgery/src/Internal/BinaryBlob.cs index 1196ffd0bdfe..719eb9f24b38 100644 --- a/src/Antiforgery/src/Internal/BinaryBlob.cs +++ b/src/Antiforgery/src/Internal/BinaryBlob.cs @@ -35,6 +35,8 @@ public BinaryBlob(int bitLength, byte[] data) _data = data; } + internal int Length => _data.Length; + public int BitLength { get diff --git a/src/Antiforgery/src/Internal/DefaultAntiforgery.cs b/src/Antiforgery/src/Internal/DefaultAntiforgery.cs index 129d1f7e3104..e4c27ccd1b73 100644 --- a/src/Antiforgery/src/Internal/DefaultAntiforgery.cs +++ b/src/Antiforgery/src/Internal/DefaultAntiforgery.cs @@ -266,7 +266,7 @@ private void CheckSSLConfig(HttpContext context) private static IAntiforgeryFeature GetAntiforgeryFeature(HttpContext httpContext) { var antiforgeryFeature = httpContext.Features.Get(); - if (antiforgeryFeature == null) + if (antiforgeryFeature is null) { antiforgeryFeature = new AntiforgeryFeature(); httpContext.Features.Set(antiforgeryFeature); diff --git a/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenGenerator.cs b/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenGenerator.cs index 06a08914924d..01342c72b49c 100644 --- a/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenGenerator.cs +++ b/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenGenerator.cs @@ -59,7 +59,10 @@ public AntiforgeryToken GenerateRequestToken( if (authenticatedIdentity != null) { isIdentityAuthenticated = true; - requestToken.ClaimUid = GetClaimUidBlob(_claimUidExtractor.ExtractClaimUid(httpContext.User)); + + var claimUidBytes = new byte[32]; + var extractClaimUidBytesResult = _claimUidExtractor.TryExtractClaimUidBytes(httpContext.User, claimUidBytes); + requestToken.ClaimUid = extractClaimUidBytesResult ? new BinaryBlob(256, claimUidBytes) : null; if (requestToken.ClaimUid == null) { @@ -137,16 +140,22 @@ public bool TryValidateTokenSet( // Is the incoming token meant for the current user? var currentUsername = string.Empty; - BinaryBlob? currentClaimUid = null; + + var extractedClaimUidBytes = false; + Span currentClaimUidBytes = stackalloc byte[32]; var authenticatedIdentity = GetAuthenticatedIdentity(httpContext.User); if (authenticatedIdentity != null) { - currentClaimUid = GetClaimUidBlob(_claimUidExtractor.ExtractClaimUid(httpContext.User)); - if (currentClaimUid == null) + var requestTokenClaimUidLength = requestToken.ClaimUid?.Length; + if (requestTokenClaimUidLength is null) { currentUsername = authenticatedIdentity.Name ?? string.Empty; } + else + { + extractedClaimUidBytes = _claimUidExtractor.TryExtractClaimUidBytes(httpContext.User, currentClaimUidBytes); + } } // OpenID and other similar authentication schemes use URIs for the username. @@ -164,14 +173,14 @@ public bool TryValidateTokenSet( return false; } - if (!object.Equals(requestToken.ClaimUid, currentClaimUid)) + if (!AreIdenticalClaimUids(requestToken, extractedClaimUidBytes, currentClaimUidBytes)) { message = Resources.AntiforgeryToken_ClaimUidMismatch; return false; } // Is the AdditionalData valid? - if (_additionalDataProvider != null && + if (_additionalDataProvider is not null && !_additionalDataProvider.ValidateAdditionalData(httpContext, requestToken.AdditionalData)) { message = Resources.AntiforgeryToken_AdditionalDataCheckFailed; @@ -180,16 +189,21 @@ public bool TryValidateTokenSet( message = null; return true; - } - private static BinaryBlob? GetClaimUidBlob(string? base64ClaimUid) - { - if (base64ClaimUid == null) + static bool AreIdenticalClaimUids(AntiforgeryToken token, bool claimUidBytesExtracted, Span claimUidBytes) { - return null; - } + if (token.ClaimUid is null) + { + return !claimUidBytesExtracted; + } - return new BinaryBlob(256, Convert.FromBase64String(base64ClaimUid)); + if (token.ClaimUid.Length != claimUidBytes.Length) + { + return false; + } + + return token.ClaimUid.GetData().SequenceEqual(claimUidBytes); + } } private static ClaimsIdentity? GetAuthenticatedIdentity(ClaimsPrincipal? claimsPrincipal) @@ -199,8 +213,7 @@ public bool TryValidateTokenSet( return null; } - var identitiesList = claimsPrincipal.Identities as List; - if (identitiesList != null) + if (claimsPrincipal.Identities is List identitiesList) { for (var i = 0; i < identitiesList.Count; i++) { diff --git a/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenSerializer.cs b/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenSerializer.cs index 0cc346b0de42..d91380beb267 100644 --- a/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenSerializer.cs +++ b/src/Antiforgery/src/Internal/DefaultAntiforgeryTokenSerializer.cs @@ -1,9 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.Buffers.Text; +using System.Text; using Microsoft.AspNetCore.DataProtection; -using Microsoft.AspNetCore.WebUtilities; -using Microsoft.Extensions.ObjectPool; +using Microsoft.AspNetCore.Shared; namespace Microsoft.AspNetCore.Antiforgery; @@ -12,47 +14,63 @@ internal sealed class DefaultAntiforgeryTokenSerializer : IAntiforgeryTokenSeria private const string Purpose = "Microsoft.AspNetCore.Antiforgery.AntiforgeryToken.v1"; private const byte TokenVersion = 0x01; - private readonly IDataProtector _cryptoSystem; - private readonly ObjectPool _pool; + private readonly IDataProtector _defaultCryptoSystem; + private readonly ISpanDataProtector? _perfCryptoSystem; - public DefaultAntiforgeryTokenSerializer( - IDataProtectionProvider provider, - ObjectPool pool) + public DefaultAntiforgeryTokenSerializer(IDataProtectionProvider provider) { ArgumentNullException.ThrowIfNull(provider); - ArgumentNullException.ThrowIfNull(pool); - _cryptoSystem = provider.CreateProtector(Purpose); - _pool = pool; + _defaultCryptoSystem = provider.CreateProtector(Purpose); + _perfCryptoSystem = _defaultCryptoSystem as ISpanDataProtector; } public AntiforgeryToken Deserialize(string serializedToken) { - var serializationContext = _pool.Get(); - + byte[]? tokenBytesRent = null; Exception? innerException = null; try { - var count = serializedToken.Length; - var charsRequired = WebEncoders.GetArraySizeRequiredToDecode(count); - var chars = serializationContext.GetChars(charsRequired); - var tokenBytes = WebEncoders.Base64UrlDecode( - serializedToken, - offset: 0, - buffer: chars, - bufferOffset: 0, - count: count); - - var unprotectedBytes = _cryptoSystem.Unprotect(tokenBytes); - var stream = serializationContext.Stream; - stream.Write(unprotectedBytes, offset: 0, count: unprotectedBytes.Length); - stream.Position = 0L; - - var reader = serializationContext.Reader; - var token = Deserialize(reader); - if (token != null) + var tokenDecodedSize = Base64Url.GetMaxDecodedLength(serializedToken.Length); + + var rent = tokenDecodedSize < 256 + ? stackalloc byte[256] + : (tokenBytesRent = ArrayPool.Shared.Rent(tokenDecodedSize)); + var tokenBytes = rent[..tokenDecodedSize]; + + var status = Base64Url.DecodeFromChars(serializedToken, tokenBytes, out int charsConsumed, out int bytesWritten); + if (status is not OperationStatus.Done) + { + throw new FormatException("Failed to decode token as Base64 char sequence."); + } + + var tokenBytesDecoded = tokenBytes[..bytesWritten]; + + if (_perfCryptoSystem is not null) + { + var protectBuffer = new RefPooledArrayBufferWriter(stackalloc byte[256]); + try + { + _perfCryptoSystem.Unprotect(tokenBytesDecoded, ref protectBuffer); + var token = Deserialize(protectBuffer.WrittenSpan); + if (token is not null) + { + return token; + } + } + finally + { + protectBuffer.Dispose(); + } + } + else { - return token; + var unprotectedBytes = _defaultCryptoSystem.Unprotect(tokenBytesDecoded.ToArray()); + var token = Deserialize(unprotectedBytes); + if (token is not null) + { + return token; + } } } catch (Exception ex) @@ -62,7 +80,10 @@ public AntiforgeryToken Deserialize(string serializedToken) } finally { - _pool.Return(serializationContext); + if (tokenBytesRent is not null) + { + ArrayPool.Shared.Return(tokenBytesRent); + } } // if we reached this point, something went wrong deserializing @@ -81,39 +102,77 @@ public AntiforgeryToken Deserialize(string serializedToken) * | `- Username: UTF-8 string with 7-bit integer length prefix * `- AdditionalData: UTF-8 string with 7-bit integer length prefix */ - private static AntiforgeryToken? Deserialize(BinaryReader reader) + private static AntiforgeryToken? Deserialize(ReadOnlySpan tokenBytes) { - // we can only consume tokens of the same serialized version that we generate - var embeddedVersion = reader.ReadByte(); + // Minimum lengths: + // - Cookie token: 1 (version) + 16 (securityToken) + 1 (isCookieToken) = 18 bytes + // - Request token (username): 18 + 1 (isClaimsBased) + 1 (username prefix) + 1 (additionalData prefix) = 21 bytes + // - Request token (claims): 18 + 1 (isClaimsBased) + 32 (claimUid) + 1 (additionalData prefix) = 52 bytes + const int minCookieTokenLength = 1 + (AntiforgeryToken.SecurityTokenBitLength / 8) + 1; // 18 bytes + const int minRequestTokenLength = minCookieTokenLength + 1 + 1 + 1; // 21 bytes (username-based) + + if (tokenBytes.Length < minCookieTokenLength) + { + return null; + } + + var offset = 0; + + var embeddedVersion = tokenBytes[offset++]; if (embeddedVersion != TokenVersion) { return null; } var deserializedToken = new AntiforgeryToken(); - var securityTokenBytes = reader.ReadBytes(AntiforgeryToken.SecurityTokenBitLength / 8); - deserializedToken.SecurityToken = - new BinaryBlob(AntiforgeryToken.SecurityTokenBitLength, securityTokenBytes); - deserializedToken.IsCookieToken = reader.ReadBoolean(); + + // Read SecurityToken (16 bytes) + const int securityTokenByteLength = AntiforgeryToken.SecurityTokenBitLength / 8; + deserializedToken.SecurityToken = new BinaryBlob( + AntiforgeryToken.SecurityTokenBitLength, + tokenBytes.Slice(offset, securityTokenByteLength).ToArray()); + offset += securityTokenByteLength; + + // Read IsCookieToken (1 byte) + deserializedToken.IsCookieToken = tokenBytes[offset++] != 0; if (!deserializedToken.IsCookieToken) { - var isClaimsBased = reader.ReadBoolean(); + // Validate minimum length for request token + if (tokenBytes.Length < minRequestTokenLength) + { + return null; + } + + // Read IsClaimsBased (1 byte) + var isClaimsBased = tokenBytes[offset++] != 0; if (isClaimsBased) { - var claimUidBytes = reader.ReadBytes(AntiforgeryToken.ClaimUidBitLength / 8); - deserializedToken.ClaimUid = new BinaryBlob(AntiforgeryToken.ClaimUidBitLength, claimUidBytes); + // Read ClaimUid (32 bytes) + const int claimUidByteLength = AntiforgeryToken.ClaimUidBitLength / 8; + if (tokenBytes.Length < offset + claimUidByteLength + 1) // +1 for additionalData prefix + { + return null; + } + + deserializedToken.ClaimUid = new BinaryBlob( + AntiforgeryToken.ClaimUidBitLength, + tokenBytes.Slice(offset, claimUidByteLength).ToArray()); + offset += claimUidByteLength; } else { - deserializedToken.Username = reader.ReadString(); + // Read Username (7-bit encoded length prefix + UTF-8 string) + offset += tokenBytes[offset..].Read7BitEncodedString(out var username); + deserializedToken.Username = username; } - deserializedToken.AdditionalData = reader.ReadString(); + offset += tokenBytes[offset..].Read7BitEncodedString(out var additionalData); + deserializedToken.AdditionalData = additionalData; } - // if there's still unconsumed data in the stream, fail - if (reader.BaseStream.ReadByte() != -1) + // if there's still unconsumed data in the span, fail + if (offset != tokenBytes.Length) { return null; } @@ -126,50 +185,87 @@ public string Serialize(AntiforgeryToken token) { ArgumentNullException.ThrowIfNull(token); - var serializationContext = _pool.Get(); + var securityTokenBytes = token.SecurityToken.GetData(); + var claimUidBytes = token.ClaimUid?.GetData(); + + var totalSize = + 1 // TokenVersion + + securityTokenBytes.Length + // SecurityToken + + 1; // IsCookieToken + if (!token.IsCookieToken) + { + totalSize += 1; // isClaimsBased + + if (token.ClaimUid is not null) + { + totalSize += claimUidBytes!.Length; + } + else + { + var usernameByteCount = Encoding.UTF8.GetByteCount(token.Username!); + totalSize += usernameByteCount.Measure7BitEncodedUIntLength() + usernameByteCount; + } + + var additionalDataByteCount = Encoding.UTF8.GetByteCount(token.AdditionalData); + totalSize += additionalDataByteCount.Measure7BitEncodedUIntLength() + additionalDataByteCount; + } + + byte[]? tokenBytesRent = null; + + var rent = totalSize < 256 + ? stackalloc byte[255] + : (tokenBytesRent = ArrayPool.Shared.Rent(totalSize)); + var tokenBytes = rent[..totalSize]; try { - var writer = serializationContext.Writer; - writer.Write(TokenVersion); - writer.Write(token.SecurityToken!.GetData()); - writer.Write(token.IsCookieToken); + var offset = 0; + tokenBytes[offset++] = TokenVersion; + securityTokenBytes.CopyTo(tokenBytes.Slice(offset, securityTokenBytes.Length)); + offset += securityTokenBytes.Length; + tokenBytes[offset++] = token.IsCookieToken ? (byte)1 : (byte)0; if (!token.IsCookieToken) { if (token.ClaimUid != null) { - writer.Write(true /* isClaimsBased */); - writer.Write(token.ClaimUid.GetData()); + tokenBytes[offset++] = 1; // isClaimsBased + claimUidBytes!.CopyTo(tokenBytes.Slice(offset, claimUidBytes!.Length)); + offset += claimUidBytes.Length; } else { - writer.Write(false /* isClaimsBased */); - writer.Write(token.Username!); + tokenBytes[offset++] = 0; // isClaimsBased + offset += tokenBytes[offset..].Write7BitEncodedString(token.Username!); } - - writer.Write(token.AdditionalData); + offset += tokenBytes[offset..].Write7BitEncodedString(token.AdditionalData); } - writer.Flush(); - var stream = serializationContext.Stream; - var bytes = _cryptoSystem.Protect(stream.ToArray()); - - var count = bytes.Length; - var charsRequired = WebEncoders.GetArraySizeRequiredToEncode(count); - var chars = serializationContext.GetChars(charsRequired); - var outputLength = WebEncoders.Base64UrlEncode( - bytes, - offset: 0, - output: chars, - outputOffset: 0, - count: count); - - return new string(chars, startIndex: 0, length: outputLength); + if (_perfCryptoSystem is not null) + { + var protectBuffer = new RefPooledArrayBufferWriter(stackalloc byte[255]); + try + { + _perfCryptoSystem.Protect(tokenBytes, ref protectBuffer); + return Base64Url.EncodeToString(protectBuffer.WrittenSpan); + } + finally + { + protectBuffer.Dispose(); + } + } + else + { + var protectedBytes = _defaultCryptoSystem.Protect(tokenBytes.ToArray()); + return Base64Url.EncodeToString(protectedBytes); + } } finally { - _pool.Return(serializationContext); + if (tokenBytesRent is not null) + { + ArrayPool.Shared.Return(tokenBytesRent); + } } } } diff --git a/src/Antiforgery/src/Internal/DefaultClaimUidExtractor.cs b/src/Antiforgery/src/Internal/DefaultClaimUidExtractor.cs index 31ffcd37f0f9..114b3c8bef1a 100644 --- a/src/Antiforgery/src/Internal/DefaultClaimUidExtractor.cs +++ b/src/Antiforgery/src/Internal/DefaultClaimUidExtractor.cs @@ -1,10 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Diagnostics; using System.Security.Claims; using System.Security.Cryptography; -using Microsoft.Extensions.ObjectPool; +using Microsoft.AspNetCore.Shared; namespace Microsoft.AspNetCore.Antiforgery; @@ -13,35 +14,26 @@ namespace Microsoft.AspNetCore.Antiforgery; /// internal sealed class DefaultClaimUidExtractor : IClaimUidExtractor { - private readonly ObjectPool _pool; - - public DefaultClaimUidExtractor(ObjectPool pool) - { - _pool = pool; - } - - /// - public string? ExtractClaimUid(ClaimsPrincipal claimsPrincipal) + public bool TryExtractClaimUidBytes(ClaimsPrincipal claimsPrincipal, Span destination) { Debug.Assert(claimsPrincipal != null); var uniqueIdentifierParameters = GetUniqueIdentifierParameters(claimsPrincipal.Identities); - if (uniqueIdentifierParameters == null) + if (uniqueIdentifierParameters is null) { - // No authenticated identities containing claims found. - return null; + return false; } - var claimUidBytes = ComputeSha256(uniqueIdentifierParameters); - return Convert.ToBase64String(claimUidBytes); + ComputeSha256(uniqueIdentifierParameters, destination); + return true; } - public static IList? GetUniqueIdentifierParameters(IEnumerable claimsIdentities) + public static List? GetUniqueIdentifierParameters(IEnumerable claimsIdentities) { var identitiesList = claimsIdentities as List; if (identitiesList == null) { - identitiesList = new List(claimsIdentities); + identitiesList = [.. claimsIdentities]; } for (var i = 0; i < identitiesList.Count; i++) @@ -56,36 +48,36 @@ public DefaultClaimUidExtractor(ObjectPool pool claim => string.Equals("sub", claim.Type, StringComparison.Ordinal)); if (subClaim != null && !string.IsNullOrEmpty(subClaim.Value)) { - return new string[] - { - subClaim.Type, - subClaim.Value, - subClaim.Issuer - }; + return + [ + subClaim.Type, + subClaim.Value, + subClaim.Issuer + ]; } var nameIdentifierClaim = identity.FindFirst( claim => string.Equals(ClaimTypes.NameIdentifier, claim.Type, StringComparison.Ordinal)); if (nameIdentifierClaim != null && !string.IsNullOrEmpty(nameIdentifierClaim.Value)) { - return new string[] - { - nameIdentifierClaim.Type, - nameIdentifierClaim.Value, - nameIdentifierClaim.Issuer - }; + return + [ + nameIdentifierClaim.Type, + nameIdentifierClaim.Value, + nameIdentifierClaim.Issuer + ]; } var upnClaim = identity.FindFirst( claim => string.Equals(ClaimTypes.Upn, claim.Type, StringComparison.Ordinal)); if (upnClaim != null && !string.IsNullOrEmpty(upnClaim.Value)) { - return new string[] - { - upnClaim.Type, - upnClaim.Value, - upnClaim.Issuer - }; + return + [ + upnClaim.Type, + upnClaim.Value, + upnClaim.Issuer + ]; } } @@ -119,33 +111,42 @@ public DefaultClaimUidExtractor(ObjectPool pool return identifierParameters; } - private byte[] ComputeSha256(IEnumerable parameters) + private static void ComputeSha256(List parameters, Span destination) { - var serializationContext = _pool.Get(); + Debug.Assert(destination.Length >= SHA256.HashSizeInBytes); - try + // Calculate total size needed for serialization + var totalSize = 0; + for (var i = 0; i < parameters.Count; i++) { - var writer = serializationContext.Writer; - foreach (string parameter in parameters) - { - writer.Write(parameter); // also writes the length as a prefix; unambiguous - } + var byteCount = System.Text.Encoding.UTF8.GetByteCount(parameters[i]); + totalSize += byteCount.Measure7BitEncodedUIntLength() + byteCount; + } - writer.Flush(); + // Use stackalloc for small buffers, otherwise rent + byte[]? rentedBuffer = null; + var buffer = totalSize <= 256 + ? stackalloc byte[256] + : (rentedBuffer = ArrayPool.Shared.Rent(totalSize)); - bool success = serializationContext.Stream.TryGetBuffer(out ArraySegment buffer); - if (!success) + try + { + var span = buffer[..totalSize]; + var offset = 0; + for (var i = 0; i < parameters.Count; i++) { - throw new InvalidOperationException(); + offset += span.Slice(offset).Write7BitEncodedString(parameters[i]); } - var bytes = SHA256.HashData(buffer); - - return bytes; + // Hash directly into destination (SHA256 output is always 32 bytes) + SHA256.HashData(span.Slice(0, offset), destination); } finally { - _pool.Return(serializationContext); + if (rentedBuffer is not null) + { + ArrayPool.Shared.Return(rentedBuffer); + } } } } diff --git a/src/Antiforgery/src/Internal/IClaimUidExtractor.cs b/src/Antiforgery/src/Internal/IClaimUidExtractor.cs index de8c536c34b5..f16ed16a2cf8 100644 --- a/src/Antiforgery/src/Internal/IClaimUidExtractor.cs +++ b/src/Antiforgery/src/Internal/IClaimUidExtractor.cs @@ -11,9 +11,7 @@ namespace Microsoft.AspNetCore.Antiforgery; internal interface IClaimUidExtractor { /// - /// Extracts claims identifier. + /// Extracts claims identifier, and writes into buffer. /// - /// The . - /// The claims identifier. - string? ExtractClaimUid(ClaimsPrincipal claimsPrincipal); + bool TryExtractClaimUidBytes(ClaimsPrincipal claimsPrincipal, Span destination); } diff --git a/src/Antiforgery/src/Microsoft.AspNetCore.Antiforgery.csproj b/src/Antiforgery/src/Microsoft.AspNetCore.Antiforgery.csproj index a70a9e37c045..369284a835b8 100644 --- a/src/Antiforgery/src/Microsoft.AspNetCore.Antiforgery.csproj +++ b/src/Antiforgery/src/Microsoft.AspNetCore.Antiforgery.csproj @@ -1,4 +1,4 @@ - + An antiforgery system for ASP.NET Core designed to generate and validate tokens to prevent Cross-Site Request Forgery attacks. @@ -25,6 +25,12 @@ - + + + + + + + diff --git a/src/Antiforgery/test/DefaultAntiforgeryTokenGeneratorTest.cs b/src/Antiforgery/test/DefaultAntiforgeryTokenGeneratorTest.cs index 3691b24aa3b0..0e1fea0d2bf7 100644 --- a/src/Antiforgery/test/DefaultAntiforgeryTokenGeneratorTest.cs +++ b/src/Antiforgery/test/DefaultAntiforgeryTokenGeneratorTest.cs @@ -4,6 +4,7 @@ #nullable disable using System.Security.Claims; using System.Security.Cryptography; +using System.Security.Principal; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.InternalTesting; using Moq; @@ -85,7 +86,7 @@ public void GenerateRequestToken_AuthenticatedWithoutUsernameAndNoAdditionalData httpContext.User = new ClaimsPrincipal(new MyAuthenticatedIdentityWithoutUsername()); var options = new AntiforgeryOptions(); - var claimUidExtractor = new Mock().Object; + var claimUidExtractor = new DummyClaimUidExtractor(httpContext.User.Identity!, blob: null, failsExtraction: true); var tokenProvider = new DefaultAntiforgeryTokenGenerator( claimUidExtractor: claimUidExtractor, @@ -119,7 +120,7 @@ public void GenerateRequestToken_AuthenticatedWithoutUsername_WithAdditionalData mockAdditionalDataProvider.Setup(o => o.GetAdditionalData(httpContext)) .Returns("additional-data"); - var claimUidExtractor = new Mock().Object; + var claimUidExtractor = new DummyClaimUidExtractor(httpContext.User.Identity!, blob: null, failsExtraction: true); var tokenProvider = new DefaultAntiforgeryTokenGenerator( claimUidExtractor: claimUidExtractor, @@ -152,13 +153,8 @@ public void GenerateRequestToken_ClaimsBasedIdentity() var base64ClaimUId = Convert.ToBase64String(data); var expectedClaimUid = new BinaryBlob(256, data); - var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) - .Returns(base64ClaimUId); - - var tokenProvider = new DefaultAntiforgeryTokenGenerator( - claimUidExtractor: mockClaimUidExtractor.Object, - additionalDataProvider: null); + var claimUidExtractor = new DummyClaimUidExtractor(identity, expectedClaimUid); + var tokenProvider = new DefaultAntiforgeryTokenGenerator(claimUidExtractor, additionalDataProvider: null); // Act var fieldToken = tokenProvider.GenerateRequestToken(httpContext, cookieToken); @@ -187,7 +183,7 @@ public void GenerateRequestToken_RegularUserWithUsername() httpContext.User = new ClaimsPrincipal(mockIdentity.Object); - var claimUidExtractor = new Mock().Object; + var claimUidExtractor = new DummyClaimUidExtractor(httpContext.User.Identity!, blob: null, failsExtraction: true); var tokenProvider = new DefaultAntiforgeryTokenGenerator( claimUidExtractor: claimUidExtractor, @@ -401,21 +397,15 @@ public void TryValidateTokenSet_UsernameMismatch(string identityUsername, string IsCookieToken = false }; - var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) - .Returns((string)null); + var claimUidExtractor = new DummyClaimUidExtractor(identity, null, failsExtraction: true); + var tokenProvider = new DefaultAntiforgeryTokenGenerator(claimUidExtractor, additionalDataProvider: null); - var tokenProvider = new DefaultAntiforgeryTokenGenerator( - claimUidExtractor: mockClaimUidExtractor.Object, - additionalDataProvider: null); - - string expectedMessage = + var expectedMessage = $"The provided antiforgery token was meant for user \"{embeddedUsername}\", " + $"but the current user is \"{identityUsername}\"."; // Act - string message; - var result = tokenProvider.TryValidateTokenSet(httpContext, cookieToken, fieldtoken, out message); + var result = tokenProvider.TryValidateTokenSet(httpContext, cookieToken, fieldtoken, out var message); // Assert Assert.False(result); @@ -439,21 +429,15 @@ public void TryValidateTokenSet_ClaimUidMismatch() }; var differentToken = new BinaryBlob(256); - var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) - .Returns(Convert.ToBase64String(differentToken.GetData())); + var dummyClaimUidExtractor = new DummyClaimUidExtractor(identity, differentToken); + var tokenProvider = new DefaultAntiforgeryTokenGenerator(claimUidExtractor: dummyClaimUidExtractor, additionalDataProvider: null); - var tokenProvider = new DefaultAntiforgeryTokenGenerator( - claimUidExtractor: mockClaimUidExtractor.Object, - additionalDataProvider: null); - - string expectedMessage = + var expectedMessage = "The provided antiforgery token was meant for a different " + "claims-based user than the current user."; // Act - string message; - var result = tokenProvider.TryValidateTokenSet(httpContext, cookieToken, fieldtoken, out message); + var result = tokenProvider.TryValidateTokenSet(httpContext, cookieToken, fieldtoken, out var message); // Assert Assert.False(result); @@ -581,17 +565,13 @@ public void TryValidateTokenSet_Success_ClaimsBasedUser() ClaimUid = new BinaryBlob(256) }; - var mockClaimUidExtractor = new Mock(); - mockClaimUidExtractor.Setup(o => o.ExtractClaimUid(It.Is(c => c.Identity == identity))) - .Returns(Convert.ToBase64String(fieldtoken.ClaimUid.GetData())); - + var dummyClaimUidExtractor = new DummyClaimUidExtractor(identity, fieldtoken.ClaimUid); var tokenProvider = new DefaultAntiforgeryTokenGenerator( - claimUidExtractor: mockClaimUidExtractor.Object, + claimUidExtractor: dummyClaimUidExtractor, additionalDataProvider: null); // Act - string message; - var result = tokenProvider.TryValidateTokenSet(httpContext, cookieToken, fieldtoken, out message); + var result = tokenProvider.TryValidateTokenSet(httpContext, cookieToken, fieldtoken, out var message); // Assert Assert.True(result); @@ -616,5 +596,35 @@ public override string Name get { return String.Empty; } } } + + private class DummyClaimUidExtractor : IClaimUidExtractor + { + private readonly IIdentity _identity; + private readonly BinaryBlob _differentToken; + private readonly bool _failsExtraction; + + public DummyClaimUidExtractor(IIdentity identity, BinaryBlob blob, bool failsExtraction = false) + { + _identity = identity; + _differentToken = blob; + _failsExtraction = failsExtraction; + } + + public bool TryExtractClaimUidBytes(ClaimsPrincipal claimsPrincipal, Span destination) + { + if (_failsExtraction) + { + return false; + } + + if (claimsPrincipal.Identity == _identity) + { + _differentToken.GetData().CopyTo(destination); + return true; + } + + return false; + } + } } #nullable restore diff --git a/src/Antiforgery/test/DefaultAntiforgeryTokenSerializerTest.cs b/src/Antiforgery/test/DefaultAntiforgeryTokenSerializerTest.cs index ac1a32675186..4d6bbb6c7d74 100644 --- a/src/Antiforgery/test/DefaultAntiforgeryTokenSerializerTest.cs +++ b/src/Antiforgery/test/DefaultAntiforgeryTokenSerializerTest.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Buffers; using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.ObjectPool; using Moq; @@ -10,10 +12,8 @@ namespace Microsoft.AspNetCore.Antiforgery.Internal; public class DefaultAntiforgeryTokenSerializerTest { private static readonly Mock _dataProtector = GetDataProtector(); - private static readonly BinaryBlob _claimUid = new BinaryBlob(256, new byte[] { 0x6F, 0x16, 0x48, 0xE9, 0x72, 0x49, 0xAA, 0x58, 0x75, 0x40, 0x36, 0xA6, 0x7E, 0x24, 0x8C, 0xF0, 0x44, 0xF0, 0x7E, 0xCF, 0xB0, 0xED, 0x38, 0x75, 0x56, 0xCE, 0x02, 0x9A, 0x4F, 0x9A, 0x40, 0xE0 }); - private static readonly BinaryBlob _securityToken = new BinaryBlob(128, new byte[] { 0x70, 0x5E, 0xED, 0xCC, 0x7D, 0x42, 0xF1, 0xD6, 0xB3, 0xB9, 0x8A, 0x59, 0x36, 0x25, 0xBB, 0x4C }); - private static readonly ObjectPool _pool = - new DefaultObjectPoolProvider().Create(new AntiforgerySerializationContextPooledObjectPolicy()); + private static readonly BinaryBlob _claimUid = new BinaryBlob(256, [0x6F, 0x16, 0x48, 0xE9, 0x72, 0x49, 0xAA, 0x58, 0x75, 0x40, 0x36, 0xA6, 0x7E, 0x24, 0x8C, 0xF0, 0x44, 0xF0, 0x7E, 0xCF, 0xB0, 0xED, 0x38, 0x75, 0x56, 0xCE, 0x02, 0x9A, 0x4F, 0x9A, 0x40, 0xE0]); + private static readonly BinaryBlob _securityToken = new BinaryBlob(128, [0x70, 0x5E, 0xED, 0xCC, 0x7D, 0x42, 0xF1, 0xD6, 0xB3, 0xB9, 0x8A, 0x59, 0x36, 0x25, 0xBB, 0x4C]); private const byte _salt = 0x05; [Theory] @@ -44,7 +44,7 @@ public class DefaultAntiforgeryTokenSerializerTest public void Deserialize_BadToken_Throws(string serializedToken) { // Arrange - var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object, _pool); + var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object); // Act & assert var ex = Assert.Throws(() => testSerializer.Deserialize(serializedToken)); @@ -55,7 +55,7 @@ public void Deserialize_BadToken_Throws(string serializedToken) public void Serialize_FieldToken_WithClaimUid_TokenRoundTripSuccessful() { // Arrange - var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object, _pool); + var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object); //"01" // Version //+ "705EEDCC7D42F1D6B3B98A593625BB4C" // SecurityToken @@ -85,7 +85,7 @@ public void Serialize_FieldToken_WithClaimUid_TokenRoundTripSuccessful() public void Serialize_FieldToken_WithUsername_TokenRoundTripSuccessful() { // Arrange - var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object, _pool); + var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object); //"01" // Version //+ "705EEDCC7D42F1D6B3B98A593625BB4C" // SecurityToken @@ -116,7 +116,7 @@ public void Serialize_FieldToken_WithUsername_TokenRoundTripSuccessful() public void Serialize_CookieToken_TokenRoundTripSuccessful() { // Arrange - var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object, _pool); + var testSerializer = new DefaultAntiforgeryTokenSerializer(_dataProtector.Object); //"01" // Version //+ "705EEDCC7D42F1D6B3B98A593625BB4C" // SecurityToken @@ -128,7 +128,7 @@ public void Serialize_CookieToken_TokenRoundTripSuccessful() }; // Act - string actualSerializedData = testSerializer.Serialize(token); + var actualSerializedData = testSerializer.Serialize(token); var deserializedToken = testSerializer.Deserialize(actualSerializedData); // Assert @@ -138,39 +138,15 @@ public void Serialize_CookieToken_TokenRoundTripSuccessful() private static Mock GetDataProtector() { - var mockCryptoSystem = new Mock(); - mockCryptoSystem.Setup(o => o.Protect(It.IsAny())) - .Returns(Protect) - .Verifiable(); - mockCryptoSystem.Setup(o => o.Unprotect(It.IsAny())) - .Returns(UnProtect) - .Verifiable(); + var testSpanDataProtector = new TestSpanDataProtector(); var provider = new Mock(); provider .Setup(p => p.CreateProtector(It.IsAny())) - .Returns(mockCryptoSystem.Object); + .Returns(testSpanDataProtector); return provider; } - private static byte[] Protect(byte[] data) - { - var input = new List(data); - input.Add(_salt); - return input.ToArray(); - } - - private static byte[] UnProtect(byte[] data) - { - var salt = data[data.Length - 1]; - if (salt != _salt) - { - throw new ArgumentException("Invalid salt value in data"); - } - - return data.Take(data.Length - 1).ToArray(); - } - private static void AssertTokensEqual(AntiforgeryToken expected, AntiforgeryToken actual) { Assert.NotNull(expected); @@ -181,4 +157,47 @@ private static void AssertTokensEqual(AntiforgeryToken expected, AntiforgeryToke Assert.Equal(expected.SecurityToken, actual.SecurityToken); Assert.Equal(expected.Username, actual.Username); } + + private sealed class TestSpanDataProtector : ISpanDataProtector + { + public IDataProtector CreateProtector(string purpose) => this; + + public void Protect(ReadOnlySpan plaintext, ref TWriter destination) where TWriter : IBufferWriter, allows ref struct + { + var result = ProtectImpl(plaintext.ToArray()); + var destinationSpan = destination.GetSpan(result.Length); + result.CopyTo(destinationSpan); + destination.Advance(result.Length); + } + + public void Unprotect(ReadOnlySpan protectedData, ref TWriter destination) + where TWriter : IBufferWriter, allows ref struct + { + var result = UnprotectImpl(protectedData.ToArray()); + var destinationSpan = destination.GetSpan(result.Length); + result.CopyTo(destinationSpan); + destination.Advance(result.Length); + } + + public byte[] Protect(byte[] plaintext) => ProtectImpl(plaintext); + + public byte[] Unprotect(byte[] protectedData) => UnprotectImpl(protectedData); + + private static byte[] ProtectImpl(byte[] data) + { + var input = new List(data); + input.Add(_salt); + return input.ToArray(); + } + + private static byte[] UnprotectImpl(byte[] data) + { + var salt = data[data.Length - 1]; + if (salt != _salt) + { + throw new ArgumentException("Invalid salt value in data"); + } + return data.Take(data.Length - 1).ToArray(); + } + } } diff --git a/src/Antiforgery/test/DefaultClaimUidExtractorTest.cs b/src/Antiforgery/test/DefaultClaimUidExtractorTest.cs index 8c8dd66334c9..cbc1f9aa9121 100644 --- a/src/Antiforgery/test/DefaultClaimUidExtractorTest.cs +++ b/src/Antiforgery/test/DefaultClaimUidExtractorTest.cs @@ -9,24 +9,26 @@ namespace Microsoft.AspNetCore.Antiforgery.Internal; public class DefaultClaimUidExtractorTest { - private static readonly ObjectPool _pool = - new DefaultObjectPoolProvider().Create(new AntiforgerySerializationContextPooledObjectPolicy()); + private readonly DefaultClaimUidExtractor _claimUidExtractor; + + public DefaultClaimUidExtractorTest() + { + _claimUidExtractor = new DefaultClaimUidExtractor(); + } [Fact] public void ExtractClaimUid_Unauthenticated() { - // Arrange - var extractor = new DefaultClaimUidExtractor(_pool); - var mockIdentity = new Mock(); mockIdentity.Setup(o => o.IsAuthenticated) .Returns(false); // Act - var claimUid = extractor.ExtractClaimUid(new ClaimsPrincipal(mockIdentity.Object)); + var claimUid = new byte[32]; + var result = _claimUidExtractor.TryExtractClaimUidBytes(new ClaimsPrincipal(mockIdentity.Object), claimUid); // Assert - Assert.Null(claimUid); + Assert.False(result); } [Fact] @@ -36,16 +38,15 @@ public void ExtractClaimUid_ClaimsIdentity() var mockIdentity = new Mock(); mockIdentity.Setup(o => o.IsAuthenticated) .Returns(true); - mockIdentity.Setup(o => o.Claims).Returns(new Claim[] { new Claim(ClaimTypes.Name, "someName") }); - - var extractor = new DefaultClaimUidExtractor(_pool); + mockIdentity.Setup(o => o.Claims).Returns([new Claim(ClaimTypes.Name, "someName")]); // Act - var claimUid = extractor.ExtractClaimUid(new ClaimsPrincipal(mockIdentity.Object)); + var claimUid = new byte[32]; + var result = _claimUidExtractor.TryExtractClaimUidBytes(new ClaimsPrincipal(mockIdentity.Object), claimUid); // Assert - Assert.NotNull(claimUid); - Assert.Equal("yhXE+2v4zSXHtRHmzm4cmrhZca2J0g7yTUwtUerdeF4=", claimUid); + Assert.True(result); + Assert.Equal("yhXE+2v4zSXHtRHmzm4cmrhZca2J0g7yTUwtUerdeF4=", Convert.ToBase64String(claimUid)); } [Fact] @@ -58,11 +59,10 @@ public void DefaultUniqueClaimTypes_NotPresent_SerializesAllClaimTypes() identity.AddClaim(new Claim(ClaimTypes.NameIdentifier, string.Empty)); // Arrange - var claimsIdentity = (ClaimsIdentity)identity; + var claimsIdentity = identity; // Act - var identiferParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { claimsIdentity })! - .ToArray(); + var identiferParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([claimsIdentity])!.ToArray(); var claims = claimsIdentity.Claims.ToList(); claims.Sort((a, b) => string.Compare(a.Type, b.Type, StringComparison.Ordinal)); @@ -85,7 +85,7 @@ public void DefaultUniqueClaimTypes_Present() identity.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity]); // Assert Assert.Equal(new string[] @@ -106,7 +106,7 @@ public void GetUniqueIdentifierParameters_PrefersSubClaimOverNameIdentifierAndUp identity.AddClaim(new Claim(ClaimTypes.Upn, "upnClaimValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity]); // Assert Assert.Equal(new string[] @@ -126,7 +126,7 @@ public void GetUniqueIdentifierParameters_PrefersNameIdentifierOverUpn() identity.AddClaim(new Claim(ClaimTypes.Upn, "upnClaimValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity]); // Assert Assert.Equal(new string[] @@ -146,7 +146,7 @@ public void GetUniqueIdentifierParameters_UsesUpnIfPresent() identity.AddClaim(new Claim(ClaimTypes.Upn, "upnClaimValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity]); // Assert Assert.Equal(new string[] @@ -167,7 +167,7 @@ public void GetUniqueIdentifierParameters_MultipleIdentities_UsesOnlyAuthenticat identity2.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters(new ClaimsIdentity[] { identity1, identity2 }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity1, identity2]); // Assert Assert.Equal(new string[] @@ -192,8 +192,7 @@ public void GetUniqueIdentifierParameters_NoKnownClaimTypesFound_SortsAndReturns identity4.AddClaim(new Claim(ClaimTypes.Name, "claimName")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters( - new ClaimsIdentity[] { identity1, identity2, identity3, identity4 }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity1, identity2, identity3, identity4]); // Assert Assert.Equal(new List @@ -220,8 +219,7 @@ public void GetUniqueIdentifierParameters_PrefersNameFromFirstIdentity_OverSubFr identity2.AddClaim(new Claim("sub", "subClaimValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters( - new ClaimsIdentity[] { identity1, identity2 }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity1, identity2]); // Assert Assert.Equal(new string[] @@ -242,8 +240,7 @@ public void GetUniqueIdentifierParameters_PrefersUpnFromFirstIdentity_OverNameFr identity2.AddClaim(new Claim(ClaimTypes.NameIdentifier, "nameIdentifierValue")); // Act - var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters( - new ClaimsIdentity[] { identity1, identity2 }); + var uniqueIdentifierParameters = DefaultClaimUidExtractor.GetUniqueIdentifierParameters([identity1, identity2]); // Assert Assert.Equal(new string[] diff --git a/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/Internal/Int7BitEncodingUtilsTests.cs b/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/Internal/Int7BitEncodingUtilsTests.cs deleted file mode 100644 index f4ba8ca0db9e..000000000000 --- a/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/Internal/Int7BitEncodingUtilsTests.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections.Generic; -using System.Text; -using Microsoft.AspNetCore.DataProtection.Internal; -using Microsoft.AspNetCore.Shared; - -namespace Microsoft.AspNetCore.DataProtection.Tests.Internal; - -public class Int7BitEncodingUtilsTests -{ - [Theory] - [InlineData(0, 1)] - [InlineData(1, 1)] - [InlineData(0b0_1111111, 1)] - [InlineData(0b1_0000000, 2)] - [InlineData(0b1111111_1111111, 2)] - [InlineData(0b1_0000000_0000000, 3)] - [InlineData(0b1111111_1111111_1111111, 3)] - [InlineData(0b1_0000000_0000000_0000000, 4)] - [InlineData(0b1111111_1111111_1111111_1111111, 4)] - [InlineData(0b1_0000000_0000000_0000000_0000000, 5)] - [InlineData(uint.MaxValue, 5)] - public void Measure7BitEncodedUIntLength_ReturnsExceptedLength(uint value, int expectedSize) - { - var actualSize = value.Measure7BitEncodedUIntLength(); - Assert.Equal(expectedSize, actualSize); - } -} diff --git a/src/Shared/Encoding/Int7BitEncodingUtils.cs b/src/Shared/Encoding/Int7BitEncodingUtils.cs index ca97096d1c17..06733395f096 100644 --- a/src/Shared/Encoding/Int7BitEncodingUtils.cs +++ b/src/Shared/Encoding/Int7BitEncodingUtils.cs @@ -1,11 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace Microsoft.AspNetCore.Shared; @@ -49,4 +45,108 @@ public static int Write7BitEncodedInt(this Span target, uint uValue) target[index++] = (byte)uValue; return index; } + + /// + /// Reads a 7-bit encoded unsigned integer from the source span. + /// + /// The source span to read from. + /// The decoded value. + /// The number of bytes consumed from the source span. + /// Thrown when the encoded value is malformed or exceeds 32 bits. + public static int Read7BitEncodedInt(this ReadOnlySpan source, out int value) + { + // Read out an int 7 bits at a time. The high bit of the byte, + // when on, indicates more bytes to read. + // A 32-bit unsigned integer can be encoded in at most 5 bytes. + + value = 0; + var shift = 0; + var index = 0; + + byte b; + do + { + // Check if we've exceeded the maximum number of bytes for a 32-bit integer + // or if we've run out of data. + if (shift == 35 || index >= source.Length) + { + throw new FormatException("Bad 7-bit encoded integer."); + } + + b = source[index++]; + value |= (b & 0x7F) << shift; + shift += 7; + } + while ((b & 0x80) != 0); + + return index; + } + + /// + /// Writes a 7-bit length-prefixed UTF-8 encoded string to the target span. + /// + /// The target span to write to. + /// The string to encode. + /// The number of bytes written to the target span. + internal static int Write7BitEncodedString(this Span target, string value) + { + if (string.IsNullOrEmpty(value)) + { + target[0] = 0; + return 1; + } + + var stringByteCount = Encoding.UTF8.GetByteCount(value); + var lengthPrefixSize = target.Write7BitEncodedInt(stringByteCount); +#if NETCOREAPP + Encoding.UTF8.GetBytes(value.AsSpan(), target[lengthPrefixSize..]); +#else + var stringBytes = Encoding.UTF8.GetBytes(value); + stringBytes.CopyTo(target.Slice(lengthPrefixSize)); +#endif + + return lengthPrefixSize + stringByteCount; + } + + /// + /// Calculates the number of bytes required to write a 7-bit length-prefixed UTF-8 encoded string. + /// + /// The string to measure. + /// The number of bytes required. + internal static int Measure7BitEncodedStringLength(string value) + { + if (string.IsNullOrEmpty(value)) + { + return 1; + } + + var stringByteCount = Encoding.UTF8.GetByteCount(value); + return stringByteCount.Measure7BitEncodedUIntLength() + stringByteCount; + } + + /// + /// Reads a 7-bit length-prefixed UTF-8 encoded string from the specified byte span and returns the number of bytes consumed. + /// + /// The span of bytes containing the 7-bit encoded length and UTF-8 encoded string data. + /// When this method returns, contains the decoded string value, or if the length is zero. + /// The total number of bytes consumed from to read the string. + /// Thrown if the encoded length is greater than the available bytes in . + internal static int Read7BitEncodedString(this ReadOnlySpan bytes, out string value) + { + value = string.Empty; + var consumed = Read7BitEncodedInt(bytes, out var length); + if (length == 0) + { + return consumed; + } + + if (bytes.Length < consumed + length) + { + throw new FormatException("Bad 7-bit encoded string."); + } + + value = Encoding.UTF8.GetString(bytes.Slice(consumed, length)); + consumed += length; + return consumed; + } } diff --git a/src/Shared/WebEncoders/WebEncoders.cs b/src/Shared/WebEncoders/WebEncoders.cs index 8d92baa94c51..20518a00ce23 100644 --- a/src/Shared/WebEncoders/WebEncoders.cs +++ b/src/Shared/WebEncoders/WebEncoders.cs @@ -204,6 +204,38 @@ public static byte[] Base64UrlDecode(string input, int offset, char[] buffer, in return Convert.FromBase64CharArray(buffer, bufferOffset, arraySizeRequired); } +#if NET9_0_OR_GREATER + /// + /// Decodes a base64url-encoded into . + /// + /// The base64url-encoded input to decode. + /// The buffer to receive the decoded bytes. + /// The number of bytes written to . + /// + /// The input must not contain any whitespace or padding characters. + /// Throws if the input is malformed. + /// + internal static int Base64UrlDecode(string input, Span output) + { + if (string.IsNullOrEmpty(input)) + { + return 0; + } + + var status = Base64Url.DecodeFromChars(input, output, out _, out int bytesWritten); + if (status != OperationStatus.Done) + { + throw new FormatException( + string.Format( + CultureInfo.CurrentCulture, + EncoderResources.WebEncoders_MalformedInput, + input.Length)); + } + + return bytesWritten; + } +#endif + /// /// Gets the minimum char[] size required for decoding of characters /// with the method. diff --git a/src/Shared/test/Shared.Tests/Encoding/Int7BitEncodingUtilsTests.cs b/src/Shared/test/Shared.Tests/Encoding/Int7BitEncodingUtilsTests.cs new file mode 100644 index 000000000000..6ef62c0625b3 --- /dev/null +++ b/src/Shared/test/Shared.Tests/Encoding/Int7BitEncodingUtilsTests.cs @@ -0,0 +1,342 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; +using Microsoft.AspNetCore.Shared; + +namespace Microsoft.AspNetCore.Shared.Tests.Encoding; + +public class Int7BitEncodingUtilsTests +{ + [Theory] + [InlineData(0, 1)] + [InlineData(1, 1)] + [InlineData(0b0_1111111, 1)] + [InlineData(0b1_0000000, 2)] + [InlineData(0b1111111_1111111, 2)] + [InlineData(0b1_0000000_0000000, 3)] + [InlineData(0b1111111_1111111_1111111, 3)] + [InlineData(0b1_0000000_0000000_0000000, 4)] + [InlineData(0b1111111_1111111_1111111_1111111, 4)] + [InlineData(0b1_0000000_0000000_0000000_0000000, 5)] + [InlineData(uint.MaxValue, 5)] + public void Measure7BitEncodedUIntLength_ReturnsExceptedLength(uint value, int expectedSize) + { + var actualSize = value.Measure7BitEncodedUIntLength(); + Assert.Equal(expectedSize, actualSize); + } + + [Theory] + [InlineData(0, new byte[] { 0x00 })] + [InlineData(1, new byte[] { 0x01 })] + [InlineData(127, new byte[] { 0x7F })] + [InlineData(128, new byte[] { 0x80, 0x01 })] + [InlineData(255, new byte[] { 0xFF, 0x01 })] + [InlineData(256, new byte[] { 0x80, 0x02 })] + [InlineData(16383, new byte[] { 0xFF, 0x7F })] + [InlineData(16384, new byte[] { 0x80, 0x80, 0x01 })] + [InlineData(2097151, new byte[] { 0xFF, 0xFF, 0x7F })] + [InlineData(2097152, new byte[] { 0x80, 0x80, 0x80, 0x01 })] + [InlineData(268435455, new byte[] { 0xFF, 0xFF, 0xFF, 0x7F })] + [InlineData(268435456, new byte[] { 0x80, 0x80, 0x80, 0x80, 0x01 })] + [InlineData(int.MaxValue, new byte[] { 0xFF, 0xFF, 0xFF, 0xFF, 0x07 })] + public void Read7BitEncodedInt_DecodesCorrectly(int expected, byte[] encoded) + { + ReadOnlySpan source = encoded; + + var bytesConsumed = source.Read7BitEncodedInt(out var value); + + Assert.Equal(expected, value); + Assert.Equal(encoded.Length, bytesConsumed); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(127)] + [InlineData(128)] + [InlineData(255)] + [InlineData(16383)] + [InlineData(16384)] + [InlineData(2097151)] + [InlineData(2097152)] + [InlineData(268435455)] + [InlineData(268435456)] + [InlineData(int.MaxValue)] + public void Read7BitEncodedInt_RoundTripsWithWrite(int value) + { + Span buffer = stackalloc byte[5]; + + var bytesWritten = buffer.Write7BitEncodedInt(value); + + ReadOnlySpan source = buffer.Slice(0, bytesWritten); + var bytesConsumed = source.Read7BitEncodedInt(out var decoded); + + Assert.Equal(value, decoded); + Assert.Equal(bytesWritten, bytesConsumed); + } + + [Fact] + public void Read7BitEncodedInt_WithEmptySpan_ThrowsFormatException() + { + Assert.Throws(() => + { + var source = ReadOnlySpan.Empty; + return source.Read7BitEncodedInt(out _); + }); + } + + [Fact] + public void Read7BitEncodedInt_WithTruncatedData_ThrowsFormatException() + { + Assert.Throws(() => + { + // This represents the start of a multi-byte encoded value but is incomplete + // 0x80 has continuation bit set, meaning more bytes should follow + ReadOnlySpan source = [0x80]; + + return source.Read7BitEncodedInt(out _); + }); + } + + [Fact] + public void Read7BitEncodedInt_WithOverflow_ThrowsFormatException() + { + Assert.Throws(() => + { + // 6 bytes with continuation bits set would overflow a 32-bit integer + ReadOnlySpan source = [0x80, 0x80, 0x80, 0x80, 0x80, 0x01]; + + return source.Read7BitEncodedInt(out _); + }); + } + + [Fact] + public void Read7BitEncodedInt_WithExtraDataAfterValue_ConsumesOnlyNeededBytes() + { + // Value 127 followed by extra bytes + ReadOnlySpan source = [0x7F, 0xFF, 0xFF]; + + var bytesConsumed = source.Read7BitEncodedInt(out var value); + + Assert.Equal(127, value); + Assert.Equal(1, bytesConsumed); + } + + [Theory] + [InlineData("")] + [InlineData("Hello")] + [InlineData("Hello, World!")] + [InlineData("UTF-8: \u00e9\u00e8\u00ea")] + public void Read7BitEncodedString_DecodesCorrectly(string expected) + { + var stringBytes = System.Text.Encoding.UTF8.GetBytes(expected); + var lengthBytes = new byte[5]; + Span lengthSpan = lengthBytes; + var lengthSize = lengthSpan.Write7BitEncodedInt(stringBytes.Length); + + var encodedBytes = new byte[lengthSize + stringBytes.Length]; + Array.Copy(lengthBytes, 0, encodedBytes, 0, lengthSize); + Array.Copy(stringBytes, 0, encodedBytes, lengthSize, stringBytes.Length); + + ReadOnlySpan source = encodedBytes; + + var bytesConsumed = source.Read7BitEncodedString(out var value); + + Assert.Equal(expected, value); + Assert.Equal(encodedBytes.Length, bytesConsumed); + } + + [Fact] + public void Read7BitEncodedString_WithEmptyString_ReturnsEmptyAndConsumesLengthByte() + { + // Length of 0 + ReadOnlySpan source = new byte[] { 0x00 }; + + var bytesConsumed = source.Read7BitEncodedString(out var value); + + Assert.Equal(string.Empty, value); + Assert.Equal(1, bytesConsumed); + } + + [Fact] + public void Read7BitEncodedString_WithTruncatedStringData_ThrowsFormatException() + { + Assert.Throws(() => + { + // Length says 10 bytes, but only 3 bytes of data follow + ReadOnlySpan source = [0x0A, 0x41, 0x42, 0x43]; + + return source.Read7BitEncodedString(out _); + }); + } + + [Fact] + public void Read7BitEncodedString_WithMultiByteLengthPrefixAndTruncatedData_ThrowsFormatException() + { + Assert.Throws(() => + { + // Length prefix 0xC8 0x01 = 200 (multi-byte), but only 2 bytes of string data follow + // Total: 4 bytes, but need 2 (prefix) + 200 (data) = 202 bytes + ReadOnlySpan source = [0xC8, 0x01, 0x41, 0x42]; + + return source.Read7BitEncodedString(out _); + }); + } + + [Fact] + public void Read7BitEncodedString_WithTruncatedLengthPrefix_ThrowsFormatException() + { + Assert.Throws(() => + { + // Continuation bit set but no more data + ReadOnlySpan source = [0x80]; + + return source.Read7BitEncodedString(out _); + }); + } + + [Fact] + public void Read7BitEncodedString_WithExtraDataAfterString_ConsumesOnlyNeededBytes() + { + // "Hi" (length 2) followed by extra bytes + ReadOnlySpan source = new byte[] { 0x02, 0x48, 0x69, 0xFF, 0xFF }; + + var bytesConsumed = source.Read7BitEncodedString(out var value); + + Assert.Equal("Hi", value); + Assert.Equal(3, bytesConsumed); // 1 byte length + 2 bytes string + } + + [Fact] + public void Read7BitEncodedString_WithMultiByteLengthPrefix_DecodesCorrectly() + { + // Create a string that requires a multi-byte length prefix (> 127 bytes) + var longString = new string('A', 200); + var stringBytes = System.Text.Encoding.UTF8.GetBytes(longString); + var lengthBytes = new byte[5]; + Span lengthSpan = lengthBytes; + var lengthSize = lengthSpan.Write7BitEncodedInt(stringBytes.Length); + + Assert.True(lengthSize > 1); // Verify we're testing multi-byte length + + var encodedBytes = new byte[lengthSize + stringBytes.Length]; + Array.Copy(lengthBytes, 0, encodedBytes, 0, lengthSize); + Array.Copy(stringBytes, 0, encodedBytes, lengthSize, stringBytes.Length); + + ReadOnlySpan source = encodedBytes; + + var bytesConsumed = source.Read7BitEncodedString(out var value); + + Assert.Equal(longString, value); + Assert.Equal(encodedBytes.Length, bytesConsumed); + } + + [Theory] + [InlineData("")] + [InlineData("Hello")] + [InlineData("Hello, World!")] + [InlineData("UTF-8: \u00e9\u00e8\u00ea")] + public void Write7BitEncodedString_EncodesCorrectly(string value) + { + var expectedByteCount = Int7BitEncodingUtils.Measure7BitEncodedStringLength(value); + Span buffer = stackalloc byte[expectedByteCount]; + + var bytesWritten = buffer.Write7BitEncodedString(value); + + Assert.Equal(expectedByteCount, bytesWritten); + + // Verify by reading back + ReadOnlySpan source = buffer.Slice(0, bytesWritten); + var bytesConsumed = source.Read7BitEncodedString(out var decoded); + + Assert.Equal(value, decoded); + Assert.Equal(bytesWritten, bytesConsumed); + } + + [Fact] + public void Write7BitEncodedString_WithNullString_WritesZeroLength() + { + Span buffer = stackalloc byte[10]; + + var bytesWritten = buffer.Write7BitEncodedString(null!); + + Assert.Equal(1, bytesWritten); + Assert.Equal(0, buffer[0]); + } + + [Fact] + public void Write7BitEncodedString_WithEmptyString_WritesZeroLength() + { + Span buffer = stackalloc byte[10]; + + var bytesWritten = buffer.Write7BitEncodedString(string.Empty); + + Assert.Equal(1, bytesWritten); + Assert.Equal(0, buffer[0]); + } + + [Fact] + public void Write7BitEncodedString_WithLongString_UsesMultiByteLengthPrefix() + { + var longString = new string('A', 200); + var expectedByteCount = Int7BitEncodingUtils.Measure7BitEncodedStringLength(longString); + var buffer = new byte[expectedByteCount]; + + var bytesWritten = buffer.AsSpan().Write7BitEncodedString(longString); + + Assert.Equal(expectedByteCount, bytesWritten); + + // Verify the length prefix is multi-byte (200 > 127) + Assert.True((buffer[0] & 0x80) != 0); // Continuation bit set + + // Verify by reading back + ReadOnlySpan source = buffer; + var bytesConsumed = source.Read7BitEncodedString(out var decoded); + + Assert.Equal(longString, decoded); + Assert.Equal(bytesWritten, bytesConsumed); + } + + [Theory] + [InlineData("", 1)] + [InlineData("A", 2)] + [InlineData("Hello", 6)] + public void Measure7BitEncodedStringLength_ReturnsCorrectLength(string value, int expectedLength) + { + var actualLength = Int7BitEncodingUtils.Measure7BitEncodedStringLength(value); + + Assert.Equal(expectedLength, actualLength); + } + + [Fact] + public void Measure7BitEncodedStringLength_WithNullString_ReturnsOne() + { + var length = Int7BitEncodingUtils.Measure7BitEncodedStringLength(null!); + + Assert.Equal(1, length); + } + + [Fact] + public void Measure7BitEncodedStringLength_WithLongString_IncludesMultiByteLengthPrefix() + { + var longString = new string('A', 200); + + var length = Int7BitEncodingUtils.Measure7BitEncodedStringLength(longString); + + // 200 bytes for string + 2 bytes for length prefix (200 requires 2 bytes in 7-bit encoding) + Assert.Equal(202, length); + } + + [Fact] + public void Measure7BitEncodedStringLength_WithUtf8String_CountsUtf8Bytes() + { + // Each of these characters is 2 bytes in UTF-8 + var utf8String = "\u00e9\u00e8\u00ea"; // 3 chars, 6 bytes + + var length = Int7BitEncodingUtils.Measure7BitEncodedStringLength(utf8String); + + // 6 bytes for string + 1 byte for length prefix + Assert.Equal(7, length); + } +} diff --git a/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj b/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj index e017c1501705..ea1335a8167a 100644 --- a/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj +++ b/src/Shared/test/Shared.Tests/Microsoft.AspNetCore.Shared.Tests.csproj @@ -1,4 +1,4 @@ - + $(DefaultNetCoreTargetFramework) @@ -43,8 +43,9 @@ - - + + +