From 24a1224b9c160d73989d100a461872259b8adda2 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Tue, 29 Oct 2024 12:08:43 +0000 Subject: [PATCH] Finalise implementation and test all possible scenarios --- .../Pipeline/DefaultResponseBuilder.cs | 43 +++-- .../TransportClient/HttpRequestInvoker.cs | 19 +- .../TransportClient/HttpWebRequestInvoker.cs | 17 +- .../Responses/Special/StreamResponse.cs | 13 +- .../Responses/TransportResponse.cs | 6 + .../Plumbing/InMemoryConnectionFactory.cs | 5 +- .../ResponseBuilderDisposeTests.cs | 163 ++++++++++++++++-- 7 files changed, 214 insertions(+), 52 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 5935da4..8e6461a 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -224,11 +224,15 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, details.ResponseBodyInBytes = bytes; } - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + if (TrySetSpecialType(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; if (details.HttpStatusCode.HasValue && requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + { + // In this scenario, we always dispose as we've explicitly skipped reading the response + responseStream.Dispose(); return null; + } var serializer = requestData.ConnectionSettings.RequestResponseSerializer; @@ -249,6 +253,9 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + if (!response.LeaveOpen) + responseStream.Dispose(); + return response; } @@ -260,11 +267,18 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, { response = new TResponse(); SetErrorOnResponse(response, error); + + if (!response.LeaveOpen) + responseStream.Dispose(); + return response; } if (!requestData.ValidateResponseContentType(mimeType)) + { + responseStream.Dispose(); return default; + } var beforeTicks = Stopwatch.GetTimestamp(); @@ -278,34 +292,36 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - if (!response.LeaveOpen) + if (response is null || !response.LeaveOpen) responseStream.Dispose(); return response; } catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) { + // Note that this is only thrown after a check if the stream length is zero. When the length is zero, + // `default` is returned by Deserialize(Async) instead. responseStream.Dispose(); return default; } } - private static bool SetSpecialTypes(string mimeType, byte[] bytes, Stream responseStream, - MemoryStreamFactory memoryStreamFactory, out TResponse cs) + private static bool TrySetSpecialType(string mimeType, byte[] bytes, Stream responseStream, + MemoryStreamFactory memoryStreamFactory, out TResponse response) where TResponse : TransportResponse, new() { - cs = null; + response = null; var responseType = typeof(TResponse); if (!SpecialTypes.Contains(responseType)) return false; if (responseType == typeof(StringResponse)) - cs = new StringResponse(bytes.Utf8String()) as TResponse; + response = new StringResponse(bytes.Utf8String()) as TResponse; else if (responseType == typeof(StreamResponse)) - cs = new StreamResponse(responseStream, mimeType) as TResponse; + response = new StreamResponse(responseStream, mimeType) as TResponse; else if (responseType == typeof(BytesResponse)) - cs = new BytesResponse(bytes) as TResponse; + response = new BytesResponse(bytes) as TResponse; else if (responseType == typeof(VoidResponse)) - cs = VoidResponse.Default as TResponse; + response = VoidResponse.Default as TResponse; else if (responseType == typeof(DynamicResponse)) { //if not json store the result under "body" @@ -315,17 +331,20 @@ private static bool SetSpecialTypes(string mimeType, byte[] bytes, St { ["body"] = new DynamicValue(bytes.Utf8String()) }; - cs = new DynamicResponse(dictionary) as TResponse; + response = new DynamicResponse(dictionary) as TResponse; } else { using var ms = memoryStreamFactory.Create(bytes); var body = LowLevelRequestResponseSerializer.Instance.Deserialize(ms); - cs = new DynamicResponse(body) as TResponse; + response = new DynamicResponse(body) as TResponse; } } - return cs != null; + if (!response.LeaveOpen) + responseStream.Dispose(); + + return response != null; } private static bool NeedsToEagerReadStream() diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index 582facf..fa215a1 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -166,9 +166,17 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receivedResponse.Dispose(); + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. In cases, where the derived response works directly on the stream, it can be left open and additional IDisposable + // resources can be linked such that their disposal is deferred. + if (response.LeaveOpen) + { + response.LinkedDisposables = [receivedResponse]; + } + else + { + receivedResponse.Dispose(); + } if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) return response; @@ -180,11 +188,6 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req foreach (var attribute in attributes) Activity.Current?.SetTag(attribute.Key, attribute.Value); - // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage - // to release the connection. - if (!response.LeaveOpen) - receivedResponse.Dispose(); - return response; } catch diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index 11e5823..265f952 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -177,9 +177,17 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receivedResponse.Dispose(); + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. In cases, where the derived response works directly on the stream, it can be left open and additional IDisposable + // resources can be linked such that their disposal is deferred. + if (response.LeaveOpen) + { + response.LinkedDisposables = [receivedResponse]; + } + else + { + receivedResponse.Dispose(); + } if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) { @@ -190,9 +198,6 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req } } - if (!response.LeaveOpen) - receivedResponse.Dispose(); - return response; } catch diff --git a/src/Elastic.Transport/Responses/Special/StreamResponse.cs b/src/Elastic.Transport/Responses/Special/StreamResponse.cs index 3cb9144..08b2de6 100644 --- a/src/Elastic.Transport/Responses/Special/StreamResponse.cs +++ b/src/Elastic.Transport/Responses/Special/StreamResponse.cs @@ -13,14 +13,10 @@ namespace Elastic.Transport; /// MUST be disposed after use to ensure the HTTP connection is freed for reuse. /// /// -public class StreamResponse : - TransportResponse, - IDisposable +public class StreamResponse : TransportResponse, IDisposable { private bool _disposed; - internal Action? Finalizer { get; set; } - /// /// The MIME type of the response, if present. /// @@ -53,7 +49,12 @@ protected virtual void Dispose(bool disposing) if (disposing) { Body.Dispose(); - Finalizer?.Invoke(); + + if (LinkedDisposables is not null) + { + foreach (var disposable in LinkedDisposables) + disposable.Dispose(); + } } _disposed = true; diff --git a/src/Elastic.Transport/Responses/TransportResponse.cs b/src/Elastic.Transport/Responses/TransportResponse.cs index 0080373..5f65983 100644 --- a/src/Elastic.Transport/Responses/TransportResponse.cs +++ b/src/Elastic.Transport/Responses/TransportResponse.cs @@ -2,6 +2,8 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System; +using System.Collections.Generic; using System.Text.Json.Serialization; namespace Elastic.Transport; @@ -35,6 +37,10 @@ public override string ToString() => ApiCallDetails?.DebugInformation // ReSharper disable once ConstantNullCoalescingCondition ?? $"{nameof(ApiCallDetails)} not set, likely a bug, reverting to default ToString(): {base.ToString()}"; + [JsonIgnore] + internal IEnumerable? LinkedDisposables { get; set; } + + [JsonIgnore] internal virtual bool LeaveOpen { get; } = false; } diff --git a/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs b/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs index c38bb05..f4eb21d 100644 --- a/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs +++ b/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs @@ -3,16 +3,17 @@ // See the LICENSE file in the project root for more information using System; +using Elastic.Transport.Products; namespace Elastic.Transport.Tests.Plumbing { public static class InMemoryConnectionFactory { - public static TransportConfiguration Create() + public static TransportConfiguration Create(ProductRegistration productRegistration = null) { var invoker = new InMemoryRequestInvoker(); var pool = new SingleNodePool(new Uri("http://localhost:9200")); - var settings = new TransportConfiguration(pool, invoker); + var settings = new TransportConfiguration(pool, invoker, productRegistration: productRegistration); return settings; } } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index dade397..c9675ec 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -5,8 +5,10 @@ using System; using System.Collections.Generic; using System.IO; +using System.Text; using System.Threading; using System.Threading.Tasks; +using Elastic.Transport.Products; using Elastic.Transport.Tests.Plumbing; using FluentAssertions; using Xunit; @@ -19,52 +21,117 @@ public class ResponseBuilderDisposeTests private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); [Fact] - public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(false, expectedDisposed: false); + public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => + await AssertResponse(false, expectedDisposed: false); [Fact] - public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => await AssertResponse(true, expectedDisposed: false); + public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => + await AssertResponse(true, expectedDisposed: false); [Fact] - public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); + public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => + await AssertResponse(false, 204); [Fact] - public async Task StreamResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); + public async Task StreamResponseForHeadRequest_StreamIsDisposed() => + await AssertResponse(false, httpMethod: HttpMethod.HEAD); [Fact] - public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => + await AssertResponse(false, contentLength: 0); [Fact] - public async Task ResponseWithPotentialBody_StreamIsDisposed() => await AssertResponse(false, expectedDisposed: true); + public async Task ResponseWithPotentialBody_StreamIsDisposed() => + await AssertResponse(false, expectedDisposed: true); [Fact] - public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true); + public async Task ResponseWithPotentialBodyButInvalidMimeType_StreamIsDisposed() => + await AssertResponse(false, mimeType: "application/not-valid", expectedDisposed: true); [Fact] - public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); + public async Task ResponseWithPotentialBodyButSkippedStatusCode_StreamIsDisposed() => + await AssertResponse(false, skipStatusCode: 200, expectedDisposed: true); [Fact] - public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); + public async Task ResponseWithPotentialBodyButEmptyJson_StreamIsDisposed() => + await AssertResponse(false, responseJson: " ", expectedDisposed: true); [Fact] - public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + // NOTE: The empty string here hits a fast path in STJ which returns default if the stream length is zero. + public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_StreamIsDisposed() => + await AssertResponse(false, responseJson: "", expectedDisposed: true); [Fact] - public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_StreamIsDisposed() => + await AssertResponse(false, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true); - private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true, int memoryStreamCreateExpected = -1) - where T : TransportResponse, new() + [Fact] + // NOTE: We expect one memory stream factory creation when handling error responses + public async Task ResponseWithPotentialBodyAndErrorResponse_StreamIsDisposed() => + await AssertResponse(false, productRegistration: new TestProductRegistration(), expectedDisposed: true, memoryStreamCreateExpected: 1); + + [Fact] + public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => + await AssertResponse(true, expectedDisposed: true); + + [Fact] + public async Task ResponseWith204StatusCode_StreamIsDisposed() => + await AssertResponse(false, 204); + + [Fact] + public async Task ResponseForHeadRequest_StreamIsDisposed() => + await AssertResponse(false, httpMethod: HttpMethod.HEAD); + + [Fact] + public async Task ResponseWithZeroContentLength_StreamIsDisposed() => + await AssertResponse(false, contentLength: 0); + + [Fact] + public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => + await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + + private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, + bool expectedDisposed = true, string mimeType = "application/json", string responseJson = "{}", int skipStatusCode = -1, + CustomResponseBuilder customResponseBuilder = null, ProductRegistration productRegistration = null, int memoryStreamCreateExpected = -1) + where T : TransportResponse, new() { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; + ITransportConfiguration config; + + if (skipStatusCode > -1 ) + { + config = InMemoryConnectionFactory.Create(productRegistration) + .DisableDirectStreaming(disableDirectStreaming) + .SkipDeserializationForStatusCodes(skipStatusCode); + } + else if (productRegistration is not null) + { + config = InMemoryConnectionFactory.Create(productRegistration) + .DisableDirectStreaming(disableDirectStreaming); + } + else + { + config = disableDirectStreaming ? _settingsDisableDirectStream : _settings; + } + var memoryStreamFactory = new TrackMemoryStreamFactory(); - var requestData = new RequestData(httpMethod, "/", null, settings, null, memoryStreamFactory, default) + var requestParams = customResponseBuilder is null ? null : + new TestRequestParameters { CustomResponseBuilder = customResponseBuilder }; + + var requestData = new RequestData(httpMethod, "/", null, config, requestParams, memoryStreamFactory, default) { Node = new Node(new Uri("http://localhost:9200")) }; var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); + if (!string.IsNullOrEmpty(responseJson)) + { + stream.Write(Encoding.UTF8.GetBytes(responseJson), 0, responseJson.Length); + stream.Position = 0; + } + + var response = config.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, mimeType, contentLength, null, null); response.Should().NotBeNull(); @@ -83,12 +150,12 @@ private async Task AssertResponse(bool disableDirectStreaming, int statusCode stream = new TrackDisposeStream(); var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + response = await config.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, cancellationToken: ct); response.Should().NotBeNull(); - memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1: disableDirectStreaming ? 2 : 0); + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1 : disableDirectStreaming ? 2 : 0); if (disableDirectStreaming) { var memoryStream = memoryStreamFactory.Created[0]; @@ -101,6 +168,66 @@ private async Task AssertResponse(bool disableDirectStreaming, int statusCode } } + private class TestProductRegistration : ProductRegistration + { + public override string DefaultMimeType => "application/json"; + public override string Name => "name"; + public override string ServiceIdentifier => "id"; + public override bool SupportsPing => false; + public override bool SupportsSniff => false; + public override HeadersList ResponseHeadersToParse => []; + public override MetaHeaderProvider MetaHeaderProvider => null; + public override string ProductAssemblyVersion => "0.0.0"; + public override IReadOnlyDictionary DefaultOpenTelemetryAttributes => new Dictionary(); + public override RequestData CreatePingRequestData(Node node, RequestConfiguration requestConfiguration, ITransportConfiguration global, MemoryStreamFactory memoryStreamFactory) => throw new NotImplementedException(); + public override RequestData CreateSniffRequestData(Node node, IRequestConfiguration requestConfiguration, ITransportConfiguration settings, MemoryStreamFactory memoryStreamFactory) => throw new NotImplementedException(); + public override IReadOnlyCollection DefaultHeadersToParse() => []; + public override bool HttpStatusCodeClassifier(HttpMethod method, int statusCode) => true; + public override bool NodePredicate(Node node) => throw new NotImplementedException(); + public override Dictionary ParseOpenTelemetryAttributesFromApiCallDetails(ApiCallDetails callDetails) => throw new NotImplementedException(); + public override TransportResponse Ping(IRequestInvoker requestInvoker, RequestData pingData) => throw new NotImplementedException(); + public override Task PingAsync(IRequestInvoker requestInvoker, RequestData pingData, CancellationToken cancellationToken) => throw new NotImplementedException(); + public override Tuple> Sniff(IRequestInvoker requestInvoker, bool forceSsl, RequestData requestData) => throw new NotImplementedException(); + public override Task>> SniffAsync(IRequestInvoker requestInvoker, bool forceSsl, RequestData requestData, CancellationToken cancellationToken) => throw new NotImplementedException(); + public override int SniffOrder(Node node) => throw new NotImplementedException(); + public override bool TryGetServerErrorReason(TResponse response, out string reason) => throw new NotImplementedException(); + public override ResponseBuilder ResponseBuilder => new TestErrorResponseBuilder(); + } + + private class TestError : ErrorResponse + { + public string MyError { get; set; } + + public override bool HasError() => true; + } + + private class TestErrorResponseBuilder : DefaultResponseBuilder + { + protected override void SetErrorOnResponse(TResponse response, TestError error) + { + // nothing to do in this scenario + } + + protected override bool TryGetError(ApiCallDetails apiCallDetails, RequestData requestData, Stream responseStream, out TestError error) + { + error = new TestError(); + return true; + } + + protected override bool RequiresErrorDeserialization(ApiCallDetails details, RequestData requestData) => true; + } + + private class TestCustomResponseBuilder : CustomResponseBuilder + { + public override object DeserializeResponse(Serializer serializer, ApiCallDetails response, Stream stream) => + new TestResponse { ApiCallDetails = response }; + + public override Task DeserializeResponseAsync(Serializer serializer, ApiCallDetails response, Stream stream, CancellationToken ctx = default) => + Task.FromResult(new TestResponse { ApiCallDetails = response }); + } + + private class TestRequestParameters : RequestParameters { } + private class TrackDisposeStream : MemoryStream { public TrackDisposeStream() { }