diff --git a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs index f13540fa579c..b74fefdb8b96 100644 --- a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs +++ b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Net.Security; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -96,6 +97,12 @@ public void AllowAnyClientCertificate() /// public Action? OnAuthenticate { get; set; } + /// + /// A callback to be invoked to get the TLS client hello bytes. + /// Null by default. + /// + public Action>? TlsClientHelloBytesCallback { get; set; } + /// /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive /// or . Defaults to 10 seconds. diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs index 32bd1dd59889..42d7ac8f0476 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -5,6 +5,7 @@ using System.Security.Cryptography.X509Certificates; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; using Microsoft.Extensions.DependencyInjection; @@ -197,6 +198,15 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn listenOptions.IsTls = true; listenOptions.HttpsOptions = httpsOptions; + if (httpsOptions.TlsClientHelloBytesCallback is not null) + { + listenOptions.Use(next => + { + var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback); + return middleware.OnTlsClientHelloAsync; + }); + } + listenOptions.Use(next => { var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics); diff --git a/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs new file mode 100644 index 000000000000..01bc75553a09 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; + +internal sealed class TlsListenerMiddleware +{ + private readonly ConnectionDelegate _next; + private readonly Action> _tlsClientHelloBytesCallback; + + public TlsListenerMiddleware(ConnectionDelegate next, Action> tlsClientHelloBytesCallback) + { + _next = next; + _tlsClientHelloBytesCallback = tlsClientHelloBytesCallback; + } + + /// + /// Sniffs the TLS Client Hello message, and invokes a callback if found. + /// + internal async Task OnTlsClientHelloAsync(ConnectionContext connection) + { + var input = connection.Transport.Input; + ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData; + + while (true) + { + var result = await input.ReadAsync(); + var buffer = result.Buffer; + + try + { + // If the buffer length is less than 6 bytes (handshake + version + length + client-hello byte) + // and no more data is coming, we can't block in a loop here because we will not get more data + if (result.IsCompleted && buffer.Length < 6) + { + break; + } + + parseState = TryParseClientHello(buffer, out var clientHelloBytes); + if (parseState == ClientHelloParseState.NotEnoughData) + { + // if no data will be added, and we still lack enough bytes + // we can't block in a loop, so just exit + if (result.IsCompleted) + { + break; + } + + continue; + } + + if (parseState == ClientHelloParseState.ValidTlsClientHello) + { + _tlsClientHelloBytesCallback(connection, clientHelloBytes); + } + + Debug.Assert(parseState is ClientHelloParseState.ValidTlsClientHello or ClientHelloParseState.NotTlsClientHello); + break; // We can continue with the middleware pipeline + } + finally + { + if (parseState is ClientHelloParseState.NotEnoughData) + { + input.AdvanceTo(buffer.Start, buffer.End); + } + else + { + // ready to continue middleware pipeline, reset the buffer to initial state + input.AdvanceTo(buffer.Start); + } + } + } + + await _next(connection); + } + + /// + /// RFCs + /// ---- + /// TLS 1.1: https://datatracker.ietf.org/doc/html/rfc4346#section-6.2 + /// TLS 1.2: https://datatracker.ietf.org/doc/html/rfc5246#section-6.2 + /// TLS 1.3: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1 + /// + private static ClientHelloParseState TryParseClientHello(ReadOnlySequence buffer, out ReadOnlySequence clientHelloBytes) + { + clientHelloBytes = default; + + if (buffer.Length < 6) + { + return ClientHelloParseState.NotEnoughData; + } + + var reader = new SequenceReader(buffer); + + // Content type must be 0x16 for TLS Handshake + if (!reader.TryRead(out byte contentType) || contentType != 0x16) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // Protocol version + if (!reader.TryReadBigEndian(out short version) || !IsValidProtocolVersion(version)) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // Record length + if (!reader.TryReadBigEndian(out short recordLength)) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // byte 6: handshake message type (must be 0x01 for ClientHello) + if (!reader.TryRead(out byte handshakeType) || handshakeType != 0x01) + { + return ClientHelloParseState.NotTlsClientHello; + } + + // 5 bytes are + // 1) Handshake (1 byte) + // 2) Protocol version (2 bytes) + // 3) Record length (2 bytes) + if (buffer.Length < 5 + recordLength) + { + return ClientHelloParseState.NotEnoughData; + } + + clientHelloBytes = buffer.Slice(0, 5 + recordLength); + return ClientHelloParseState.ValidTlsClientHello; + } + + private static bool IsValidProtocolVersion(short version) + => version is 0x0300 // SSL 3.0 (0x0300) + or 0x0301 // TLS 1.0 (0x0301) + or 0x0302 // TLS 1.1 (0x0302) + or 0x0303 // TLS 1.2 (0x0303) + or 0x0304; // TLS 1.3 (0x0304) + + private enum ClientHelloParseState : byte + { + NotEnoughData, + NotTlsClientHello, + ValidTlsClientHello + } +} diff --git a/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt b/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..69a983823915 100644 --- a/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt +++ b/src/Servers/Kestrel/Core/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions.TlsClientHelloBytesCallback.get -> System.Action>? +Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions.TlsClientHelloBytesCallback.set -> void diff --git a/src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs b/src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs new file mode 100644 index 000000000000..299d6491b46b --- /dev/null +++ b/src/Servers/Kestrel/Core/test/TestHelpers/ObservablePipeReader.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests.TestHelpers; + +internal class ObservablePipeReader : PipeReader +{ + private readonly PipeReader _inner; + + public ObservablePipeReader(PipeReader reader) + { + _inner = reader; + } + + /// + /// Number of times was called. + /// + public int ReadAsyncCounter { get; private set; } + + public override void AdvanceTo(SequencePosition consumed) + => _inner.AdvanceTo(consumed); + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + => _inner.AdvanceTo(consumed, examined); + + public override void CancelPendingRead() + => _inner.CancelPendingRead(); + + public override void Complete(Exception exception = null) + => _inner.Complete(exception); + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + ReadAsyncCounter++; + return _inner.ReadAsync(cancellationToken); + } + + public override bool TryRead(out ReadResult result) + { + return _inner.TryRead(out result); + } +} diff --git a/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs new file mode 100644 index 000000000000..590555e3c22e --- /dev/null +++ b/src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs @@ -0,0 +1,569 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware; +using Microsoft.AspNetCore.Server.Kestrel.Core.Tests.TestHelpers; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Moq; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; + +public partial class TlsListenerMiddlewareTests +{ + [Theory] + [MemberData(nameof(ValidClientHelloData))] + public Task OnTlsClientHelloAsync_ValidData(int id, byte[] packetBytes, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: true); + + [Theory] + [MemberData(nameof(InvalidClientHelloData))] + public Task OnTlsClientHelloAsync_InvalidData(int id, byte[] packetBytes, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest(id, packetBytes, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + + [Theory] + [MemberData(nameof(ValidClientHelloData_Segmented))] + public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: true); + + [Theory] + [MemberData(nameof(InvalidClientHelloData_Segmented))] + public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List packets, bool nextMiddlewareInvoked) + => RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false); + + [Fact] + public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads() + { + var serviceContext = new TestServiceContext(); + + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); + + var nextMiddlewareInvoked = false; + var tlsClientHelloCallbackInvoked = false; + + var middleware = new TlsListenerMiddleware( + next: ctx => + { + nextMiddlewareInvoked = true; + var readResult = ctx.Transport.Input.ReadAsync(); + Assert.Equal(5, readResult.Result.Buffer.Length); + + return Task.CompletedTask; + }, + tlsClientHelloBytesCallback: (ctx, data) => + { + tlsClientHelloCallbackInvoked = true; + } + ); + + await writer.WriteAsync(new byte[1] { 0x16 }); + var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); + await writer.WriteAsync(new byte[2] { 0x03, 0x01 }); + await writer.WriteAsync(new byte[2] { 0x00, 0x20 }); + await writer.CompleteAsync(); + + await middlewareTask; + Assert.True(nextMiddlewareInvoked); + Assert.False(tlsClientHelloCallbackInvoked); + + // ensuring that we have read limited number of times + Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5, + $"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times."); + } + + private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( + int id, + List packets, + bool nextMiddlewareInvokedExpected, + bool tlsClientHelloCallbackExpected) + { + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); + + var nextMiddlewareInvokedActual = false; + var tlsClientHelloCallbackActual = false; + + var fullLength = packets.Sum(p => p.Length); + + var middleware = new TlsListenerMiddleware( + next: ctx => + { + nextMiddlewareInvokedActual = true; + return Task.CompletedTask; + }, + tlsClientHelloBytesCallback: (ctx, data) => + { + tlsClientHelloCallbackActual = true; + + Assert.NotNull(ctx); + Assert.False(data.IsEmpty); + Assert.Equal(fullLength, data.Length); + } + ); + + // write first packet + await writer.WriteAsync(packets[0]); + var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection)); + + var random = new Random(); + await Task.Delay(millisecondsDelay: random.Next(25, 75)); + + // write all next packets + foreach (var packet in packets.Skip(1)) + { + await writer.WriteAsync(packet); + await Task.Delay(millisecondsDelay: random.Next(25, 75)); + } + await writer.CompleteAsync(); + await middlewareTask; + + Assert.Equal(nextMiddlewareInvokedExpected, nextMiddlewareInvokedActual); + Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); + } + + private async Task RunTlsClientHelloCallbackTest( + int id, + byte[] packetBytes, + bool nextMiddlewareExpected, + bool tlsClientHelloCallbackExpected) + { + var pipe = new Pipe(); + var writer = pipe.Writer; + var reader = new ObservablePipeReader(pipe.Reader); + + var transport = new DuplexPipe(reader, writer); + var transportConnection = new DefaultConnectionContext("test", transport, transport); + + var nextMiddlewareInvokedActual = false; + var tlsClientHelloCallbackActual = false; + + var middleware = new TlsListenerMiddleware( + next: ctx => + { + nextMiddlewareInvokedActual = true; + var readResult = ctx.Transport.Input.ReadAsync(); + Assert.Equal(packetBytes.Length, readResult.Result.Buffer.Length); + + return Task.CompletedTask; + }, + tlsClientHelloBytesCallback: (ctx, data) => + { + tlsClientHelloCallbackActual = true; + + Assert.NotNull(ctx); + Assert.False(data.IsEmpty); + Assert.Equal(packetBytes.Length, data.Length); + } + ); + + await writer.WriteAsync(packetBytes); + await writer.CompleteAsync(); + + // call middleware and expect a callback + await middleware.OnTlsClientHelloAsync(transportConnection); + + Assert.Equal(nextMiddlewareExpected, nextMiddlewareInvokedActual); + Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual); + } + + public static IEnumerable ValidClientHelloData() + { + int id = 0; + foreach (var clientHello in valid_collection) + { + yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; + } + } + + public static IEnumerable InvalidClientHelloData() + { + int id = 0; + foreach (byte[] clientHello in invalid_collection) + { + yield return new object[] { id++, clientHello, true /* invokes next middleware */ }; + } + } + + public static IEnumerable ValidClientHelloData_Segmented() + { + int id = 0; + foreach (var clientHello in valid_collection) + { + var clientHelloSegments = new List + { + clientHello.Take(1).ToArray(), + clientHello.Skip(1).Take(2).ToArray(), + clientHello.Skip(3).Take(2).ToArray(), + clientHello.Skip(5).Take(1).ToArray(), + clientHello.Skip(6).Take(clientHello.Length - 6).ToArray() + }; + + yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; + } + } + + public static IEnumerable InvalidClientHelloData_Segmented() + { + int id = 0; + foreach (var clientHello in invalid_collection) + { + var clientHelloSegments = new List(); + if (clientHello.Length >= 1) + { + clientHelloSegments.Add(clientHello.Take(1).ToArray()); + } + if (clientHello.Length >= 3) + { + clientHelloSegments.Add(clientHello.Skip(1).Take(2).ToArray()); + } + if (clientHello.Length >= 5) + { + clientHelloSegments.Add(clientHello.Skip(3).Take(2).ToArray()); + } + if (clientHello.Length >= 6) + { + clientHelloSegments.Add(clientHello.Skip(5).Take(1).ToArray()); + } + if (clientHello.Length >= 7) + { + clientHelloSegments.Add(clientHello.Skip(6).Take(clientHello.Length - 6).ToArray()); + } + + yield return new object[] { id++, clientHelloSegments, true /* invokes next middleware */ }; + } + } + + private static byte[] valid_clientHelloHeader = + { + // 0x16 = Handshake + 0x16, + // 0x0301 = TLS 1.0 + 0x03, 0x01, + // length = 0x0020 (32 bytes) + 0x00, 0x20, + // Handshake.msg_type (client hello) + 0x01, + // 31 bytes (zeros for simplicity) + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0 + }; + + private static byte[] valid_Ssl3ClientHello = + { + 0x16, 0x03, 0x00, // ContentType: Handshake, Version: SSL 3.0 + 0x00, 0x2F, // Length: 47 bytes + 0x01, // Handshake Type: ClientHello + 0x00, 0x00, 0x2B, // Length: 43 bytes + 0x03, 0x00, // Client Version: SSL 3.0 + // Random (32 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + 0x00, // Session ID Length + 0x00, 0x04, // Cipher Suites Length + 0x00, 0x2F, 0x00, 0x35, // Cipher Suites + 0x01, 0x00 // Compression Methods: null + }; + + private static byte[] valid_Tls10ClientHello = + { + 0x16, 0x03, 0x01, // ContentType: Handshake, Version: TLS 1.0 + 0x00, 0x2F, // Length: 47 bytes + 0x01, // Handshake Type: ClientHello + 0x00, 0x00, 0x2B, // Length: 43 bytes + 0x03, 0x01, // Client Version: TLS 1.0 + // Random (32 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + 0x00, // Session ID Length + 0x00, 0x04, // Cipher Suites Length + 0x00, 0x2F, 0x00, 0x35, // Cipher Suites + 0x01, 0x00 // Compression Methods: null + }; + + private static byte[] valid_Tls11ClientHello = + { + 0x16, 0x03, 0x02, // ContentType: Handshake, Version: TLS 1.1 + 0x00, 0x2F, // Length: 47 bytes + 0x01, // Handshake Type: ClientHello + 0x00, 0x00, 0x2B, // Length: 43 bytes + 0x03, 0x02, // Client Version: TLS 1.1 + // Random (32 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, + 0x00, // Session ID Length + 0x00, 0x04, // Cipher Suites Length + 0x00, 0x2F, 0x00, 0x35, // Cipher Suites: TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA + 0x01, 0x00 // Compression Methods: null + }; + + private static byte[] valid_Tls12ClientHello = + { + // SslPlainText.(ContentType+ProtocolVersion) + 0x16, 0x03, 0x03, + // SslPlainText.length + 0x00, 0xD1, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xCD, + // ClientHello.client_version + 0x03, 0x03, + // ClientHello.random + 0x0C, 0x3C, 0x85, 0x78, 0xCA, + 0x67, 0x70, 0xAA, 0x38, 0xCB, + 0x28, 0xBC, 0xDC, 0x3E, 0x30, + 0xBF, 0x11, 0x96, 0x95, 0x1A, + 0xB9, 0xF0, 0x99, 0xA4, 0x91, + 0x09, 0x13, 0xB4, 0x89, 0x94, + 0x27, 0x2E, + // ClientHello.SessionId + 0x00, + // ClientHello.cipher_suites_length + 0x00, 0x5C, + // ClientHello.cipher_suites + 0xC0, 0x30, 0xC0, 0x2C, 0xC0, 0x28, 0xC0, 0x24, + 0xC0, 0x14, 0xC0, 0x0A, 0x00, 0x9f, 0x00, 0x6B, + 0x00, 0x39, 0xCC, 0xA9, 0xCC, 0xA8, 0xCC, 0xAA, + 0xFF, 0x85, 0x00, 0xC4, 0x00, 0x88, 0x00, 0x81, + 0x00, 0x9D, 0x00, 0x3D, 0x00, 0x35, 0x00, 0xC0, + 0x00, 0x84, 0xC0, 0x2f, 0xC0, 0x2B, 0xC0, 0x27, + 0xC0, 0x23, 0xC0, 0x13, 0xC0, 0x09, 0x00, 0x9E, + 0x00, 0x67, 0x00, 0x33, 0x00, 0xBE, 0x00, 0x45, + 0x00, 0x9C, 0x00, 0x3C, 0x00, 0x2F, 0x00, 0xBA, + 0x00, 0x41, 0xC0, 0x11, 0xC0, 0x07, 0x00, 0x05, + 0x00, 0x04, 0xC0, 0x12, 0xC0, 0x08, 0x00, 0x16, + 0x00, 0x0a, 0x00, 0xff, + // ClientHello.compression_methods + 0x01, 0x01, + // ClientHello.extension_list_length + 0x00, 0x48, + // Extension.extension_type (ec_point_formats) + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, + // Extension.extension_type (supported_groups) + 0x00, 0x0A, 0x00, 0x08, 0x00, 0x06, 0x00, 0x1D, + 0x00, 0x17, 0x00, 0x18, + // Extension.extension_type (session_ticket) + 0x00, 0x23, 0x00, 0x00, + // Extension.extension_type (signature_algorithms) + 0x00, 0x0D, 0x00, 0x1C, 0x00, 0x1A, 0x06, 0x01, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03, + // Extension.extension_type (application_level_Protocol) + 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0C, 0x02, 0x68, + 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2F, 0x31, + 0x2E, 0x31 + }; + + private static byte[] valid_Tls13ClientHello = + { + // SslPlainText.(ContentType+ProtocolVersion) + 0x16, 0x03, 0x04, + // SslPlainText.length + 0x01, 0x08, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x01, 0x04, + // ClientHello.client_version + 0x03, 0x03, + // ClientHello.random + 0x0C, 0x3C, 0x85, 0x78, 0xCA, 0x67, 0x70, 0xAA, + 0x38, 0xCB, 0x28, 0xBC, 0xDC, 0x3E, 0x30, 0xBF, + 0x11, 0x96, 0x95, 0x1A, 0xB9, 0xF0, 0x99, 0xA4, + 0x91, 0x09, 0x13, 0xB4, 0x89, 0x94, 0x27, 0x2E, + // ClientHello.SessionId_Length + 0x20, + // ClientHello.SessionId + 0x0C, 0x3C, 0x85, 0x78, 0xCA, 0x67, 0x70, 0xAA, + 0x38, 0xCB, 0x28, 0xBC, 0xDC, 0x3E, 0x30, 0xBF, + 0x11, 0x96, 0x95, 0x1A, 0xB9, 0xF0, 0x99, 0xA4, + 0x91, 0x09, 0x13, 0xB4, 0x89, 0x94, 0x27, 0x2E, + // ClientHello.cipher_suites_length + 0x00, 0x0C, + // ClientHello.cipher_suites + 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0xC0, 0x14, + 0xc0, 0x30, 0x00, 0xFF, + // ClientHello.compression_methods + 0x01, 0x00, + // ClientHello.extension_list_length + 0x00, 0xAF, + // Extension.extension_type (server_name) (10.211.55.2) + 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, + 0x0B, 0x31, 0x30, 0x2E, 0x32, 0x31, 0x31, 0x2E, + 0x35, 0x35, 0x2E, 0x32, + // Extension.extension_type (ec_point_formats) + 0x00, 0x0B, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + // Extension.extension_type (supported_groups) + 0x00, 0x0A, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x1D, + 0x00, 0x17, 0x00, 0x1E, 0x00, 0x19, 0x00, 0x18, + // Extension.extension_type (application_level_Protocol) (boo) + 0x00, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, 0x62, + 0x6f, 0x6f, + // Extension.extension_type (encrypt_then_mac) + 0x00, 0x16, 0x00, 0x00, + // Extension.extension_type (extended_master_key_secret) + 0x00, 0x17, 0x00, 0x00, + // Extension.extension_type (signature_algorithms) + 0x00, 0x0D, 0x00, 0x30, 0x00, 0x2E, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x06, 0x03, 0xEF, 0xEF, 0x05, 0x01, 0x05, 0x03, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, + // Extension.extension_type (supported_versions) + 0x00, 0x2B, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, + 0x03, 0x03, 0x02, 0x03, 0x01, + // Extension.extension_type (psk_key_exchange_modes) + 0x00, 0x2D, 0x00, 0x02, 0x01, 0x01, + // Extension.extension_type (key_share) + 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1D, + 0x00, 0x20, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03, + 0x04, 0x01, 0x04, 0x03, 0xEE, 0xEE, 0xED, 0xED, + 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03 + }; + + private static byte[] valid_TlsClientHelloNoExtensions = + { + 0x16, 0x03, 0x03, 0x00, 0x39, 0x01, 0x00, 0x00, + 0x35, 0x03, 0x03, 0x62, 0x5d, 0x50, 0x2a, 0x41, + 0x2f, 0xd8, 0xc3, 0x65, 0x35, 0xea, 0x01, 0x70, + 0x03, 0x7e, 0x7e, 0x2d, 0xd4, 0xfe, 0x93, 0x39, + 0xa4, 0x04, 0x66, 0xbb, 0x46, 0x91, 0x41, 0xc3, + 0x48, 0x87, 0x3d, 0x00, 0x00, 0x0e, 0x00, 0x3d, + 0x00, 0x3c, 0x00, 0x0a, 0x00, 0x35, 0x00, 0x2f, + 0x00, 0x05, 0x00, 0x04, 0x01, 0x00 + }; + + private static byte[] invalid_TlsClientHelloHeader = + { + // Handshake - incorrect + 0x01, + // ProtocolVersion + 0x03, 0x04, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static byte[] invalid_3BytesMessage = + { + // Handshake + 0x016, + // Protocol Version + 0x03, 0x01, + // not enough data - so incorrect + }; + + private static byte[] invalid_9BytesMessage = + { + // 0x16 = Handshake + 0x16, + // 0x0301 = TLS 1.0 + 0x03, 0x01, + // length = 0x0020 (32 bytes) + 0x00, 0x20, + // Handshake.msg_type (client hello) + 0x01, + // should have 31 bytes (zeros for simplicity) + 0, 0, 0 + // no other data here - incorrect + }; + + private static byte[] invalid_UnknownProtocolVersion1 = + { + // Handshake + 0x016, + // ProtocolVersion - incorrect + 0x02, 0x05, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static byte[] invalid_UnknownProtocolVersion2 = + { + // Handshake + 0x016, + // ProtocolVersion - incorrect + 0x02, 0x01, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) + 0x01, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static byte[] invalid_IncorrectHandshakeMessageType = + { + // Handshake + 0x016, + // ProtocolVersion + 0x02, 0x00, + // SslPlainText.length + 0x00, 0xCB, + // Handshake.msg_type (client hello) - incorrect + 0x02, + // Handshake.length + 0x00, 0x00, 0xC7, + }; + + private static List valid_collection = new List() + { + valid_clientHelloHeader, valid_Ssl3ClientHello, valid_Tls10ClientHello, + valid_Tls11ClientHello, valid_Tls12ClientHello, valid_Tls13ClientHello, + valid_TlsClientHelloNoExtensions + }; + + private static List invalid_collection = new List() + { + invalid_TlsClientHelloHeader, invalid_3BytesMessage, invalid_9BytesMessage, + invalid_UnknownProtocolVersion1, invalid_UnknownProtocolVersion2, invalid_IncorrectHandshakeMessageType + }; +} diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs index 086694908360..245027b33330 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs @@ -91,7 +91,7 @@ public override async ValueTask DisposeAsync() // This piece of code allows us to wait until the PipeReader has been awaited on. // We need to wrap lots of layers (including the ValueTask) to gain visiblity into when // the machinery for the await happens - private class ObservableDuplexPipe : IDuplexPipe + internal class ObservableDuplexPipe : IDuplexPipe { private readonly ObservablePipeReader _reader; @@ -110,11 +110,14 @@ public ObservableDuplexPipe(IDuplexPipe duplexPipe) public PipeWriter Output { get; } + public int ReadAsyncCounter => _reader.ReadAsyncCounter; + private class ObservablePipeReader : PipeReader { private readonly PipeReader _reader; private readonly TaskCompletionSource _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public int ReadAsyncCounter { get; private set; } = 0; public Task WaitForReadTask => _tcs.Task; public ObservablePipeReader(PipeReader reader) @@ -144,6 +147,7 @@ public override void Complete(Exception exception = null) public override ValueTask ReadAsync(CancellationToken cancellationToken = default) { + ReadAsyncCounter++; var task = _reader.ReadAsync(cancellationToken); if (_tcs.Task.IsCompleted) @@ -152,7 +156,7 @@ public override ValueTask ReadAsync(CancellationToken cancellationTo } return new ValueTask(new ObservableValueTask(task, _tcs), 0); - } + } public override bool TryRead(out ReadResult result) { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs new file mode 100644 index 000000000000..b57ca2405ba4 --- /dev/null +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.cs @@ -0,0 +1,69 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace InMemory.FunctionalTests; + +public class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest +{ + private static readonly X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate(); + + [Fact] + public async Task TlsClientHelloBytesCallback_InvokedAndHasTlsMessageBytes() + { + var tlsClientHelloCallbackInvoked = false; + + var testContext = new TestServiceContext(LoggerFactory); + await using (var server = new TestServer(context => Task.CompletedTask, + testContext, + listenOptions => + { + listenOptions.UseHttps(_x509Certificate2, httpsOptions => + { + httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) => + { + Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length); + tlsClientHelloCallbackInvoked = true; + Assert.True(clientHelloBytes.Length > 32); + Assert.NotNull(connection); + }; + }); + })) + { + using (var connection = server.CreateConnection()) + { + using (var sslStream = new SslStream(connection.Stream, false, (sender, cert, chain, errors) => true, null)) + { + await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions + { + TargetHost = "localhost", + EnabledSslProtocols = SslProtocols.None + }, CancellationToken.None); + + var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); + await sslStream.WriteAsync(request, 0, request.Length); + await sslStream.ReadAsync(new Memory(new byte[1024])); + } + } + } + + Assert.True(tlsClientHelloCallbackInvoked); + } +}