From 1f7aa5b01d4206aa4a4a1df0e19536a46a71d02c Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Fri, 11 Oct 2024 14:04:08 +0100 Subject: [PATCH 1/9] Enhance StreamResponse handling and update dependencies Updated `ResponseBuilderDefaults` to include `StreamResponse` in `SpecialTypes`. Refactored `SetBodyCoreAsync` in `DefaultResponseBuilder.cs` for readability and removed unnecessary `using` statements. Modified `RequestCoreAsync` in `HttpWebRequestInvoker.cs` and `BuildResponseAsync` in `InMemoryRequestInvoker.cs` to handle `StreamResponse` types with proper disposal. Updated `Elastic.Transport.csproj` to reference `System.Text.Json` version `8.0.5`. --- .../Pipeline/DefaultResponseBuilder.cs | 95 +++++++++---------- .../TransportClient/HttpWebRequestInvoker.cs | 36 +++---- .../TransportClient/InMemoryRequestInvoker.cs | 14 ++- 3 files changed, 76 insertions(+), 69 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 892bd64..414bbd5 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -26,7 +26,7 @@ internal static class ResponseBuilderDefaults public static readonly Type[] SpecialTypes = { - typeof(StringResponse), typeof(BytesResponse), typeof(VoidResponse), typeof(DynamicResponse) + typeof(StringResponse), typeof(BytesResponse), typeof(VoidResponse), typeof(DynamicResponse), typeof(StreamResponse) }; } @@ -224,68 +224,65 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, details.ResponseBodyInBytes = bytes; } - using (responseStream) - { - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; - if (details.HttpStatusCode.HasValue && - requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) - return null; + if (details.HttpStatusCode.HasValue && + requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + return null; - var serializer = requestData.ConnectionSettings.RequestResponseSerializer; + var serializer = requestData.ConnectionSettings.RequestResponseSerializer; - TResponse response; - if (requestData.CustomResponseBuilder != null) - { - var beforeTicks = Stopwatch.GetTimestamp(); + TResponse response; + if (requestData.CustomResponseBuilder != null) + { + var beforeTicks = Stopwatch.GetTimestamp(); - if (isAsync) - response = await requestData.CustomResponseBuilder - .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) - .ConfigureAwait(false) as TResponse; - else - response = requestData.CustomResponseBuilder - .DeserializeResponse(serializer, details, responseStream) as TResponse; + if (isAsync) + response = await requestData.CustomResponseBuilder + .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) + .ConfigureAwait(false) as TResponse; + else + response = requestData.CustomResponseBuilder + .DeserializeResponse(serializer, details, responseStream) as TResponse; - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - return response; - } + return response; + } - // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! - // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. - try + // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! + // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. + try + { + if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) { - if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) - { - response = new TResponse(); - SetErrorOnResponse(response, error); - return response; - } + response = new TResponse(); + SetErrorOnResponse(response, error); + return response; + } - if (!requestData.ValidateResponseContentType(mimeType)) - return default; + if (!requestData.ValidateResponseContentType(mimeType)) + return default; - var beforeTicks = Stopwatch.GetTimestamp(); + var beforeTicks = Stopwatch.GetTimestamp(); - if (isAsync) - response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); - else - response = serializer.Deserialize(responseStream); + if (isAsync) + response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); + else + response = serializer.Deserialize(responseStream); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - return response; - } - catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) - { - return default; - } + return response; + } + catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) + { + return default; } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index e1c8ae7..fb138c9 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -161,28 +161,32 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req { unregisterWaitHandle?.Invoke(); } - responseStream ??= Stream.Null; - TResponse response; + var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - - if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) { - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - foreach (var attribute in attributes) + TResponse response; + + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + + if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) { - Activity.Current?.SetTag(attribute.Key, attribute.Value); + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + foreach (var attribute in attributes) + { + Activity.Current?.SetTag(attribute.Key, attribute.Value); + } } - } - return response; + return response; + } } private static Dictionary> ParseHeaders(RequestData requestData, HttpWebResponse responseMessage, Dictionary> responseHeaders) diff --git a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs index 12e1f9b..37f8eb0 100644 --- a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs @@ -109,10 +109,16 @@ public async Task BuildResponseAsync(RequestData requestDa requestData.MadeItToResponse = true; var sc = statusCode ?? _statusCode; - Stream s = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponseAsync(requestData, _exception, sc, _headers, s, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) + + Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); + + var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + + using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) + { + return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) .ConfigureAwait(false); + } } - } From f90b97466b38d1feb4cc16748e7a92e1fad62244 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Tue, 22 Oct 2024 14:01:28 +0200 Subject: [PATCH 2/9] Update tests --- .../TransportClient/InMemoryRequestInvoker.cs | 54 ++++--- .../Http/StreamResponseTests.cs | 38 +++++ .../ResponseBuilderDisposeTests.cs | 152 +++++------------- 3 files changed, 108 insertions(+), 136 deletions(-) create mode 100644 tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs diff --git a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs index 37f8eb0..d3fa8fd 100644 --- a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs @@ -66,24 +66,32 @@ public TResponse BuildResponse(RequestData requestData, byte[] respon { var body = responseBody ?? _responseBody; var data = requestData.PostData; - if (data != null) + + if (data is not null) { - using (var stream = requestData.MemoryStreamFactory.Create()) + using var stream = requestData.MemoryStreamFactory.Create(); + if (requestData.HttpCompression) + { + using var zipStream = new GZipStream(stream, CompressionMode.Compress); + data.Write(zipStream, requestData.ConnectionSettings); + } + else { - if (requestData.HttpCompression) - { - using var zipStream = new GZipStream(stream, CompressionMode.Compress); - data.Write(zipStream, requestData.ConnectionSettings); - } - else - data.Write(stream, requestData.ConnectionSettings); + data.Write(stream, requestData.ConnectionSettings); } } requestData.MadeItToResponse = true; var sc = statusCode ?? _statusCode; - Stream s = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse(requestData, _exception, sc, _headers, s, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); + Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); + + var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + + using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) + { + return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); + } } /// > @@ -93,17 +101,19 @@ public async Task BuildResponseAsync(RequestData requestDa { var body = responseBody ?? _responseBody; var data = requestData.PostData; - if (data != null) + + if (data is not null) { - using (var stream = requestData.MemoryStreamFactory.Create()) + using var stream = requestData.MemoryStreamFactory.Create(); + + if (requestData.HttpCompression) + { + using var zipStream = new GZipStream(stream, CompressionMode.Compress); + await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); + } + else { - if (requestData.HttpCompression) - { - using var zipStream = new GZipStream(stream, CompressionMode.Compress); - await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); - } - else - await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); + await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); } } requestData.MadeItToResponse = true; @@ -117,8 +127,8 @@ public async Task BuildResponseAsync(RequestData requestDa using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) { return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) - .ConfigureAwait(false); + .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) + .ConfigureAwait(false); } } } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs new file mode 100644 index 0000000..57333bd --- /dev/null +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -0,0 +1,38 @@ +// Licensed to Elasticsearch B.V under one or more agreements. +// Elasticsearch B.V licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information + +using System.IO; +using System.Text.Json; +using System.Threading.Tasks; +using Elastic.Transport.IntegrationTests.Plumbing; +using Elastic.Transport.Products.Elasticsearch; +using Microsoft.AspNetCore.Mvc; +using Xunit; + +namespace Elastic.Transport.IntegrationTests.Http; + +public class StreamResponseTests(TransportTestServer instance) : AssemblyServerTestsBase(instance) +{ + private const string Path = "/streamresponse"; + + [Fact] + public async Task StreamResponse_ShouldNotBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))); + var transport = new DistributedTransport(config); + + var response = await transport.PostAsync(Path, PostData.String("{}")); + + var sr = new StreamReader(response.Body); + var responseString = sr.ReadToEndAsync(); + } +} + +[ApiController, Route("[controller]")] +public class StreamResponseController : ControllerBase +{ + [HttpPost] + public Task Post([FromBody] JsonElement body) => Task.FromResult(body); +} diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index db5f3e4..551b460 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -3,146 +3,70 @@ // See the LICENSE file in the project root for more information using System; -using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading; using System.Threading.Tasks; using Elastic.Transport.Tests.Plumbing; using FluentAssertions; using Xunit; -namespace Elastic.Transport.Tests +namespace Elastic.Transport.Tests; + +public class ResponseBuilderDisposeTests { - public class ResponseBuilderDisposeTests - { - private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); - private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); + private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); - [Fact] public async Task ResponseWithHttpStatusCode() => await AssertRegularResponse(false, 1); + [Fact] + public async Task ResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(expectedDisposed: false); - [Fact] public async Task ResponseBuilderWithNoHttpStatusCode() => await AssertRegularResponse(false); + [Fact] + public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(204); - [Fact] public async Task ResponseWithHttpStatusCodeDisableDirectStreaming() => - await AssertRegularResponse(true, 1); + [Fact] + public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(httpMethod: HttpMethod.HEAD); - [Fact] public async Task ResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() => - await AssertRegularResponse(true); + [Fact] + public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(contentLength: 0); - private async Task AssertRegularResponse(bool disableDirectStreaming, int? statusCode = null) + private async Task AssertResponse(int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true) + { + var settings = _settings; + var requestData = new RequestData(httpMethod, "/", null, settings, null, null, null, default) { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); - var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, null, memoryStreamFactory, default) - { - Node = new Node(new Uri("http://localhost:9200")) - }; - - var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, -1, null, null); - response.Should().NotBeNull(); - - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0); - if (disableDirectStreaming) - { - var memoryStream = memoryStreamFactory.Created[0]; - memoryStream.IsDisposed.Should().BeTrue(); - } - stream.IsDisposed.Should().BeTrue(); - - - stream = new TrackDisposeStream(); - var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, -1, null, null, - cancellationToken: ct); - response.Should().NotBeNull(); - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0); - if (disableDirectStreaming) - { - var memoryStream = memoryStreamFactory.Created[1]; - memoryStream.IsDisposed.Should().BeTrue(); - } - stream.IsDisposed.Should().BeTrue(); - } + Node = new Node(new Uri("http://localhost:9200")) + }; - [Fact] public async Task StreamResponseWithHttpStatusCode() => await AssertStreamResponse(false, 200); + var stream = new TrackDisposeStream(); - [Fact] public async Task StreamResponseBuilderWithNoHttpStatusCode() => await AssertStreamResponse(false); + var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); - [Fact] public async Task StreamResponseWithHttpStatusCodeDisableDirectStreaming() => - await AssertStreamResponse(true, 1); + response.Should().NotBeNull(); + stream.IsDisposed.Should().Be(expectedDisposed); - [Fact] public async Task StreamResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() => - await AssertStreamResponse(true); + stream = new TrackDisposeStream(); + var ct = new CancellationToken(); - private async Task AssertStreamResponse(bool disableDirectStreaming, int? statusCode = null) - { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); - - var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, null, memoryStreamFactory, default) - { - Node = new Node(new Uri("http://localhost:9200")) - }; - - var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, -1, null, null); - response.Should().NotBeNull(); - - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0); - stream.IsDisposed.Should().Be(true); - - stream = new TrackDisposeStream(); - var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, -1, null, null, - cancellationToken: ct); - response.Should().NotBeNull(); - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0); - stream.IsDisposed.Should().Be(true); - } + response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + cancellationToken: ct); + response.Should().NotBeNull(); + stream.IsDisposed.Should().Be(expectedDisposed); + } - private class TrackDisposeStream : MemoryStream - { - public TrackDisposeStream() { } - - public TrackDisposeStream(byte[] bytes) : base(bytes) { } + private class TrackDisposeStream : MemoryStream + { + public TrackDisposeStream() { } - public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } + public TrackDisposeStream(byte[] bytes) : base(bytes) { } - public bool IsDisposed { get; private set; } + public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } - protected override void Dispose(bool disposing) - { - IsDisposed = true; - base.Dispose(disposing); - } - } + public bool IsDisposed { get; private set; } - private class TrackMemoryStreamFactory : MemoryStreamFactory + protected override void Dispose(bool disposing) { - public IList Created { get; } = new List(); - - public override MemoryStream Create() - { - var stream = new TrackDisposeStream(); - Created.Add(stream); - return stream; - } - - public override MemoryStream Create(byte[] bytes) - { - var stream = new TrackDisposeStream(bytes); - Created.Add(stream); - return stream; - } - - public override MemoryStream Create(byte[] bytes, int index, int count) - { - var stream = new TrackDisposeStream(bytes, index, count); - Created.Add(stream); - return stream; - } + IsDisposed = true; + base.Dispose(disposing); } } } From e04578a594a3486d5c3a2fd67fe2da87a6f23d7c Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Mon, 28 Oct 2024 14:33:46 +0000 Subject: [PATCH 3/9] Improve stream handling and disposal in transport layer --- .../Pipeline/DefaultResponseBuilder.cs | 93 +++++++++--------- .../TransportClient/HttpRequestInvoker.cs | 1 - .../TransportClient/HttpWebRequestInvoker.cs | 35 +++---- .../TransportClient/InMemoryRequestInvoker.cs | 20 +--- .../Responses/Special/StreamResponse.cs | 32 ++++++- .../Http/StreamResponseTests.cs | 87 ++++++++++++++++- .../ResponseBuilderDisposeTests.cs | 96 ++++++++++++++++--- 7 files changed, 266 insertions(+), 98 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 414bbd5..86352e9 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -224,65 +224,70 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, details.ResponseBodyInBytes = bytes; } - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - if (details.HttpStatusCode.HasValue && - requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) - return null; + using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) + { + if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; - var serializer = requestData.ConnectionSettings.RequestResponseSerializer; + if (details.HttpStatusCode.HasValue && + requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + return null; - TResponse response; - if (requestData.CustomResponseBuilder != null) - { - var beforeTicks = Stopwatch.GetTimestamp(); + var serializer = requestData.ConnectionSettings.RequestResponseSerializer; - if (isAsync) - response = await requestData.CustomResponseBuilder - .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) - .ConfigureAwait(false) as TResponse; - else - response = requestData.CustomResponseBuilder - .DeserializeResponse(serializer, details, responseStream) as TResponse; + TResponse response; + if (requestData.CustomResponseBuilder != null) + { + var beforeTicks = Stopwatch.GetTimestamp(); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + if (isAsync) + response = await requestData.CustomResponseBuilder + .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) + .ConfigureAwait(false) as TResponse; + else + response = requestData.CustomResponseBuilder + .DeserializeResponse(serializer, details, responseStream) as TResponse; - return response; - } + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! - // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. - try - { - if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) - { - response = new TResponse(); - SetErrorOnResponse(response, error); return response; } - if (!requestData.ValidateResponseContentType(mimeType)) - return default; + // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! + // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. + try + { + if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) + { + response = new TResponse(); + SetErrorOnResponse(response, error); + return response; + } - var beforeTicks = Stopwatch.GetTimestamp(); + if (!requestData.ValidateResponseContentType(mimeType)) + return default; - if (isAsync) - response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); - else - response = serializer.Deserialize(responseStream); + var beforeTicks = Stopwatch.GetTimestamp(); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (isAsync) + response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); + else + response = serializer.Deserialize(responseStream); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - return response; - } - catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) - { - return default; + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + + return response; + } + catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) + { + return default; + } } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index a466b78..e58df63 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -157,7 +157,6 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receive) - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) { TResponse response; diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index fb138c9..4e5b859 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -162,31 +162,26 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req unregisterWaitHandle?.Invoke(); } - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + TResponse response; - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - TResponse response; - - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + { + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + foreach (var attribute in attributes) { - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - foreach (var attribute in attributes) - { - Activity.Current?.SetTag(attribute.Key, attribute.Value); - } + Activity.Current?.SetTag(attribute.Key, attribute.Value); } - - return response; } + + return response; } private static Dictionary> ParseHeaders(RequestData requestData, HttpWebResponse responseMessage, Dictionary> responseHeaders) diff --git a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs index d3fa8fd..aead20a 100644 --- a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs @@ -85,13 +85,8 @@ public TResponse BuildResponse(RequestData requestData, byte[] respon var sc = statusCode ?? _statusCode; Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); - } + return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); } /// > @@ -122,13 +117,8 @@ public async Task BuildResponseAsync(RequestData requestDa Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) - .ConfigureAwait(false); - } + return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) + .ConfigureAwait(false); } } diff --git a/src/Elastic.Transport/Responses/Special/StreamResponse.cs b/src/Elastic.Transport/Responses/Special/StreamResponse.cs index 53dd22a..d920a06 100644 --- a/src/Elastic.Transport/Responses/Special/StreamResponse.cs +++ b/src/Elastic.Transport/Responses/Special/StreamResponse.cs @@ -10,13 +10,15 @@ namespace Elastic.Transport; /// /// A response that exposes the response as . /// -/// Must be disposed after use. +/// MUST be disposed after use to ensure the HTTP connection is freed for reuse. /// /// -public sealed class StreamResponse : +public class StreamResponse : TransportResponse, IDisposable { + private bool _disposed; + internal Action? Finalizer { get; set; } /// @@ -38,10 +40,30 @@ public StreamResponse(Stream body, string? mimeType) MimeType = mimeType ?? string.Empty; } - /// + /// + /// Disposes the underlying stream. + /// + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Body.Dispose(); + Finalizer?.Invoke(); + } + + _disposed = true; + } + } + + /// + /// Disposes the underlying stream. + /// public void Dispose() { - Body.Dispose(); - Finalizer?.Invoke(); + Dispose(disposing: true); + GC.SuppressFinalize(this); } } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs index 57333bd..e006567 100644 --- a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -2,11 +2,14 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; using Elastic.Transport.IntegrationTests.Plumbing; using Elastic.Transport.Products.Elasticsearch; +using FluentAssertions; using Microsoft.AspNetCore.Mvc; using Xunit; @@ -25,8 +28,88 @@ public async Task StreamResponse_ShouldNotBeDisposed() var response = await transport.PostAsync(Path, PostData.String("{}")); - var sr = new StreamReader(response.Body); - var responseString = sr.ReadToEndAsync(); + // Ensure the stream is readable + using var sr = new StreamReader(response.Body); + _ = sr.ReadToEndAsync(); + } + + [Fact] + public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + var memoryStream = memoryStreamFactory.Created.Last(); + + memoryStream.IsDisposed.Should().BeFalse(); + } + + [Fact] + public async Task StringResponse_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + var memoryStream = memoryStreamFactory.Created.Last(); + + memoryStream.IsDisposed.Should().BeTrue(); + } + + private class TrackDisposeStream : MemoryStream + { + public TrackDisposeStream() { } + + public TrackDisposeStream(byte[] bytes) : base(bytes) { } + + public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } + + public bool IsDisposed { get; private set; } + + protected override void Dispose(bool disposing) + { + IsDisposed = true; + base.Dispose(disposing); + } + } + + private class TrackMemoryStreamFactory : MemoryStreamFactory + { + public IList Created { get; } = []; + + public override MemoryStream Create() + { + var stream = new TrackDisposeStream(); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes) + { + var stream = new TrackDisposeStream(bytes); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes, int index, int count) + { + var stream = new TrackDisposeStream(bytes, index, count); + Created.Add(stream); + return stream; + } } } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index 551b460..de9fa2e 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information using System; +using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -15,42 +16,89 @@ namespace Elastic.Transport.Tests; public class ResponseBuilderDisposeTests { private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); + private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); [Fact] - public async Task ResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(expectedDisposed: false); + public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(false, expectedDisposed: false); [Fact] - public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(204); + public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => await AssertResponse(true, expectedDisposed: false); [Fact] - public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(httpMethod: HttpMethod.HEAD); + public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); [Fact] - public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(contentLength: 0); + public async Task StreamResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); - private async Task AssertResponse(int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true) + [Fact] + public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + + [Fact] + public async Task ResponseWithPotentialBody_StreamIsDisposed() => await AssertResponse(false, expectedDisposed: true); + + [Fact] + public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true); + + [Fact] + public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); + + [Fact] + public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); + + [Fact] + public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + + [Fact] + public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + + private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true, int memoryStreamCreateExpected = -1) + where T : TransportResponse, new() { - var settings = _settings; - var requestData = new RequestData(httpMethod, "/", null, settings, null, null, null, default) + var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; + var memoryStreamFactory = new TrackMemoryStreamFactory(); + + var requestData = new RequestData(httpMethod, "/", null, settings, null, null, memoryStreamFactory, default) { Node = new Node(new Uri("http://localhost:9200")) }; var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); + var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); response.Should().NotBeNull(); - stream.IsDisposed.Should().Be(expectedDisposed); + + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected : disableDirectStreaming ? 1 : 0); + if (disableDirectStreaming) + { + var memoryStream = memoryStreamFactory.Created[0]; + stream.IsDisposed.Should().BeTrue(); + memoryStream.IsDisposed.Should().Be(expectedDisposed); + } + else + { + stream.IsDisposed.Should().Be(expectedDisposed); + } stream = new TrackDisposeStream(); var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, cancellationToken: ct); response.Should().NotBeNull(); - stream.IsDisposed.Should().Be(expectedDisposed); + + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1: disableDirectStreaming ? 2 : 0); + if (disableDirectStreaming) + { + var memoryStream = memoryStreamFactory.Created[0]; + stream.IsDisposed.Should().BeTrue(); + memoryStream.IsDisposed.Should().Be(expectedDisposed); + } + else + { + stream.IsDisposed.Should().Be(expectedDisposed); + } } private class TrackDisposeStream : MemoryStream @@ -69,4 +117,30 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } } + + private class TrackMemoryStreamFactory : MemoryStreamFactory + { + public IList Created { get; } = []; + + public override MemoryStream Create() + { + var stream = new TrackDisposeStream(); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes) + { + var stream = new TrackDisposeStream(bytes); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes, int index, int count) + { + var stream = new TrackDisposeStream(bytes, index, count); + Created.Add(stream); + return stream; + } + } } From 00dcca4dc617c16da994d5495e0a967c7fbcbfa1 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Mon, 28 Oct 2024 15:02:06 +0000 Subject: [PATCH 4/9] Refactor variable and add proper disposal handling Renamed variable from `receive` to `receivedResponse` in `RequestCoreAsync` method across `HttpRequestInvoker.cs` and `HttpWebRequestInvoker.cs` for better clarity. Updated `using` statements accordingly. Added `try-catch` blocks around response handling logic to ensure `receivedResponse` is disposed of properly, preventing resource leaks. Adjusted finalizer in `StreamResponse` to use `receivedResponse.Dispose()`. Properly indented OpenTelemetry attributes setting logic within the `try` block. --- .../TransportClient/HttpRequestInvoker.cs | 48 +++++++++++-------- .../TransportClient/HttpWebRequestInvoker.cs | 45 +++++++++++------ 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index e58df63..38c1554 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -75,7 +75,7 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req Exception ex = null; string mimeType = null; long contentLength = -1; - IDisposable receive = DiagnosticSources.SingletonDisposable; + IDisposable receivedResponse = DiagnosticSources.SingletonDisposable; ReadOnlyDictionary tcpStats = null; ReadOnlyDictionary threadPoolStats = null; Dictionary> responseHeaders = null; @@ -118,7 +118,7 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req responseMessage = client.SendAsync(requestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).GetAwaiter().GetResult(); #endif - receive = responseMessage; + receivedResponse = responseMessage; statusCode = (int)responseMessage.StatusCode; } @@ -156,33 +156,41 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); - using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receive) + using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receivedResponse) { TResponse response; - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + try + { + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receive.Dispose(); + // Defer disposal of the response message + if (response is StreamResponse sr) + sr.Finalizer = () => receivedResponse.Dispose(); - if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) - return response; + if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) + return response; - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - if (attributes is null) return response; + if (attributes is null) return response; - foreach (var attribute in attributes) - Activity.Current?.SetTag(attribute.Key, attribute.Value); + foreach (var attribute in attributes) + Activity.Current?.SetTag(attribute.Key, attribute.Value); - return response; + return response; + } + catch + { + receivedResponse.Dispose(); // if there's an exception, ensure we always release the response so that the connection is freed. + throw; + } } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index 4e5b859..adaaab6 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -68,6 +68,7 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req Exception ex = null; string mimeType = null; long contentLength = -1; + IDisposable receivedResponse = DiagnosticSources.SingletonDisposable; ReadOnlyDictionary tcpStats = null; ReadOnlyDictionary threadPoolStats = null; Dictionary> responseHeaders = null; @@ -146,6 +147,8 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req httpWebResponse = (HttpWebResponse)request.GetResponse(); } + receivedResponse = httpWebResponse; + HandleResponse(httpWebResponse, out statusCode, out responseStream, out mimeType); responseHeaders = ParseHeaders(requestData, httpWebResponse, responseHeaders); contentLength = httpWebResponse.ContentLength; @@ -162,26 +165,38 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req unregisterWaitHandle?.Invoke(); } - TResponse response; + try + { + TResponse response; + + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + // Defer disposal of the response message + if (response is StreamResponse sr) + sr.Finalizer = () => receivedResponse.Dispose(); - if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) - { - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - foreach (var attribute in attributes) + if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) { - Activity.Current?.SetTag(attribute.Key, attribute.Value); + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + foreach (var attribute in attributes) + { + Activity.Current?.SetTag(attribute.Key, attribute.Value); + } } - } - return response; + return response; + } + catch + { + receivedResponse.Dispose(); // ensure we always release the response so the connection is freed. + throw; + } } private static Dictionary> ParseHeaders(RequestData requestData, HttpWebResponse responseMessage, Dictionary> responseHeaders) From d09925698937932d95adc2772ebc7b5cf37a6057 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Tue, 29 Oct 2024 09:39:03 +0000 Subject: [PATCH 5/9] Improve response handling and add LeaveOpen property Add `LeaveOpen` property to `StreamResponse` and `TransportResponse` classes for flexible stream management. --- .../Pipeline/DefaultResponseBuilder.cs | 95 +++++++++---------- .../TransportClient/HttpRequestInvoker.cs | 56 +++++------ .../TransportClient/HttpWebRequestInvoker.cs | 3 + .../Responses/Special/StreamResponse.cs | 2 + .../Responses/TransportResponse.cs | 4 +- 5 files changed, 83 insertions(+), 77 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 86352e9..5935da4 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -224,70 +224,69 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, details.ResponseBodyInBytes = bytes; } - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) - { - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + if (details.HttpStatusCode.HasValue && + requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + return null; - if (details.HttpStatusCode.HasValue && - requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) - return null; + var serializer = requestData.ConnectionSettings.RequestResponseSerializer; - var serializer = requestData.ConnectionSettings.RequestResponseSerializer; + TResponse response; + if (requestData.CustomResponseBuilder != null) + { + var beforeTicks = Stopwatch.GetTimestamp(); - TResponse response; - if (requestData.CustomResponseBuilder != null) - { - var beforeTicks = Stopwatch.GetTimestamp(); + if (isAsync) + response = await requestData.CustomResponseBuilder + .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) + .ConfigureAwait(false) as TResponse; + else + response = requestData.CustomResponseBuilder + .DeserializeResponse(serializer, details, responseStream) as TResponse; - if (isAsync) - response = await requestData.CustomResponseBuilder - .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) - .ConfigureAwait(false) as TResponse; - else - response = requestData.CustomResponseBuilder - .DeserializeResponse(serializer, details, responseStream) as TResponse; + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + return response; + } + // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! + // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. + try + { + if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) + { + response = new TResponse(); + SetErrorOnResponse(response, error); return response; } - // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! - // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. - try - { - if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) - { - response = new TResponse(); - SetErrorOnResponse(response, error); - return response; - } + if (!requestData.ValidateResponseContentType(mimeType)) + return default; - if (!requestData.ValidateResponseContentType(mimeType)) - return default; + var beforeTicks = Stopwatch.GetTimestamp(); - var beforeTicks = Stopwatch.GetTimestamp(); + if (isAsync) + response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); + else + response = serializer.Deserialize(responseStream); - if (isAsync) - response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); - else - response = serializer.Deserialize(responseStream); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + if (!response.LeaveOpen) + responseStream.Dispose(); - return response; - } - catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) - { - return default; - } + return response; + } + catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) + { + responseStream.Dispose(); + return default; } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index 38c1554..582facf 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -154,43 +154,43 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req ex = e; } - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + TResponse response; - using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receivedResponse) + try { - TResponse response; + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - try - { - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + // Defer disposal of the response message + if (response is StreamResponse sr) + sr.Finalizer = () => receivedResponse.Dispose(); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receivedResponse.Dispose(); + if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) + return response; - if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) - return response; + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + if (attributes is null) return response; - if (attributes is null) return response; + foreach (var attribute in attributes) + Activity.Current?.SetTag(attribute.Key, attribute.Value); - foreach (var attribute in attributes) - Activity.Current?.SetTag(attribute.Key, attribute.Value); + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. + if (!response.LeaveOpen) + receivedResponse.Dispose(); - return response; - } - catch - { - receivedResponse.Dispose(); // if there's an exception, ensure we always release the response so that the connection is freed. - throw; - } + return response; + } + catch + { + receivedResponse.Dispose(); // if there's an exception, ensure we always release the response so that the connection is freed. + throw; } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index adaaab6..11e5823 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -190,6 +190,9 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req } } + if (!response.LeaveOpen) + receivedResponse.Dispose(); + return response; } catch diff --git a/src/Elastic.Transport/Responses/Special/StreamResponse.cs b/src/Elastic.Transport/Responses/Special/StreamResponse.cs index d920a06..3cb9144 100644 --- a/src/Elastic.Transport/Responses/Special/StreamResponse.cs +++ b/src/Elastic.Transport/Responses/Special/StreamResponse.cs @@ -40,6 +40,8 @@ public StreamResponse(Stream body, string? mimeType) MimeType = mimeType ?? string.Empty; } + internal override bool LeaveOpen => true; + /// /// Disposes the underlying stream. /// diff --git a/src/Elastic.Transport/Responses/TransportResponse.cs b/src/Elastic.Transport/Responses/TransportResponse.cs index e74cece..0080373 100644 --- a/src/Elastic.Transport/Responses/TransportResponse.cs +++ b/src/Elastic.Transport/Responses/TransportResponse.cs @@ -8,7 +8,7 @@ namespace Elastic.Transport; /// /// A response from an Elastic product including details about the request/response life cycle. Base class for the built in low level response -/// types, , , and +/// types, , , , and /// public abstract class TransportResponse : TransportResponse { @@ -34,5 +34,7 @@ public abstract class TransportResponse public override string ToString() => ApiCallDetails?.DebugInformation // ReSharper disable once ConstantNullCoalescingCondition ?? $"{nameof(ApiCallDetails)} not set, likely a bug, reverting to default ToString(): {base.ToString()}"; + + internal virtual bool LeaveOpen { get; } = false; } From 403147507b1abb5e2c67eda1e88b0d52b8f6719b Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Tue, 29 Oct 2024 12:08:43 +0000 Subject: [PATCH 6/9] Finalise implementation and test all possible scenarios --- .../Pipeline/DefaultResponseBuilder.cs | 43 +++-- .../TransportClient/HttpRequestInvoker.cs | 19 +- .../TransportClient/HttpWebRequestInvoker.cs | 17 +- .../Responses/Special/StreamResponse.cs | 13 +- .../Responses/TransportResponse.cs | 6 + .../Plumbing/InMemoryConnectionFactory.cs | 5 +- .../ResponseBuilderDisposeTests.cs | 162 ++++++++++++++++-- 7 files changed, 212 insertions(+), 53 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 5935da4..8e6461a 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -224,11 +224,15 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, details.ResponseBodyInBytes = bytes; } - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + if (TrySetSpecialType(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; if (details.HttpStatusCode.HasValue && requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + { + // In this scenario, we always dispose as we've explicitly skipped reading the response + responseStream.Dispose(); return null; + } var serializer = requestData.ConnectionSettings.RequestResponseSerializer; @@ -249,6 +253,9 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + if (!response.LeaveOpen) + responseStream.Dispose(); + return response; } @@ -260,11 +267,18 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, { response = new TResponse(); SetErrorOnResponse(response, error); + + if (!response.LeaveOpen) + responseStream.Dispose(); + return response; } if (!requestData.ValidateResponseContentType(mimeType)) + { + responseStream.Dispose(); return default; + } var beforeTicks = Stopwatch.GetTimestamp(); @@ -278,34 +292,36 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - if (!response.LeaveOpen) + if (response is null || !response.LeaveOpen) responseStream.Dispose(); return response; } catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) { + // Note that this is only thrown after a check if the stream length is zero. When the length is zero, + // `default` is returned by Deserialize(Async) instead. responseStream.Dispose(); return default; } } - private static bool SetSpecialTypes(string mimeType, byte[] bytes, Stream responseStream, - MemoryStreamFactory memoryStreamFactory, out TResponse cs) + private static bool TrySetSpecialType(string mimeType, byte[] bytes, Stream responseStream, + MemoryStreamFactory memoryStreamFactory, out TResponse response) where TResponse : TransportResponse, new() { - cs = null; + response = null; var responseType = typeof(TResponse); if (!SpecialTypes.Contains(responseType)) return false; if (responseType == typeof(StringResponse)) - cs = new StringResponse(bytes.Utf8String()) as TResponse; + response = new StringResponse(bytes.Utf8String()) as TResponse; else if (responseType == typeof(StreamResponse)) - cs = new StreamResponse(responseStream, mimeType) as TResponse; + response = new StreamResponse(responseStream, mimeType) as TResponse; else if (responseType == typeof(BytesResponse)) - cs = new BytesResponse(bytes) as TResponse; + response = new BytesResponse(bytes) as TResponse; else if (responseType == typeof(VoidResponse)) - cs = VoidResponse.Default as TResponse; + response = VoidResponse.Default as TResponse; else if (responseType == typeof(DynamicResponse)) { //if not json store the result under "body" @@ -315,17 +331,20 @@ private static bool SetSpecialTypes(string mimeType, byte[] bytes, St { ["body"] = new DynamicValue(bytes.Utf8String()) }; - cs = new DynamicResponse(dictionary) as TResponse; + response = new DynamicResponse(dictionary) as TResponse; } else { using var ms = memoryStreamFactory.Create(bytes); var body = LowLevelRequestResponseSerializer.Instance.Deserialize(ms); - cs = new DynamicResponse(body) as TResponse; + response = new DynamicResponse(body) as TResponse; } } - return cs != null; + if (!response.LeaveOpen) + responseStream.Dispose(); + + return response != null; } private static bool NeedsToEagerReadStream() diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index 582facf..fa215a1 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -166,9 +166,17 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receivedResponse.Dispose(); + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. In cases, where the derived response works directly on the stream, it can be left open and additional IDisposable + // resources can be linked such that their disposal is deferred. + if (response.LeaveOpen) + { + response.LinkedDisposables = [receivedResponse]; + } + else + { + receivedResponse.Dispose(); + } if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) return response; @@ -180,11 +188,6 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req foreach (var attribute in attributes) Activity.Current?.SetTag(attribute.Key, attribute.Value); - // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage - // to release the connection. - if (!response.LeaveOpen) - receivedResponse.Dispose(); - return response; } catch diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index 11e5823..265f952 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -177,9 +177,17 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receivedResponse.Dispose(); + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. In cases, where the derived response works directly on the stream, it can be left open and additional IDisposable + // resources can be linked such that their disposal is deferred. + if (response.LeaveOpen) + { + response.LinkedDisposables = [receivedResponse]; + } + else + { + receivedResponse.Dispose(); + } if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) { @@ -190,9 +198,6 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req } } - if (!response.LeaveOpen) - receivedResponse.Dispose(); - return response; } catch diff --git a/src/Elastic.Transport/Responses/Special/StreamResponse.cs b/src/Elastic.Transport/Responses/Special/StreamResponse.cs index 3cb9144..08b2de6 100644 --- a/src/Elastic.Transport/Responses/Special/StreamResponse.cs +++ b/src/Elastic.Transport/Responses/Special/StreamResponse.cs @@ -13,14 +13,10 @@ namespace Elastic.Transport; /// MUST be disposed after use to ensure the HTTP connection is freed for reuse. /// /// -public class StreamResponse : - TransportResponse, - IDisposable +public class StreamResponse : TransportResponse, IDisposable { private bool _disposed; - internal Action? Finalizer { get; set; } - /// /// The MIME type of the response, if present. /// @@ -53,7 +49,12 @@ protected virtual void Dispose(bool disposing) if (disposing) { Body.Dispose(); - Finalizer?.Invoke(); + + if (LinkedDisposables is not null) + { + foreach (var disposable in LinkedDisposables) + disposable.Dispose(); + } } _disposed = true; diff --git a/src/Elastic.Transport/Responses/TransportResponse.cs b/src/Elastic.Transport/Responses/TransportResponse.cs index 0080373..5f65983 100644 --- a/src/Elastic.Transport/Responses/TransportResponse.cs +++ b/src/Elastic.Transport/Responses/TransportResponse.cs @@ -2,6 +2,8 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System; +using System.Collections.Generic; using System.Text.Json.Serialization; namespace Elastic.Transport; @@ -35,6 +37,10 @@ public override string ToString() => ApiCallDetails?.DebugInformation // ReSharper disable once ConstantNullCoalescingCondition ?? $"{nameof(ApiCallDetails)} not set, likely a bug, reverting to default ToString(): {base.ToString()}"; + [JsonIgnore] + internal IEnumerable? LinkedDisposables { get; set; } + + [JsonIgnore] internal virtual bool LeaveOpen { get; } = false; } diff --git a/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs b/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs index c38bb05..f4eb21d 100644 --- a/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs +++ b/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs @@ -3,16 +3,17 @@ // See the LICENSE file in the project root for more information using System; +using Elastic.Transport.Products; namespace Elastic.Transport.Tests.Plumbing { public static class InMemoryConnectionFactory { - public static TransportConfiguration Create() + public static TransportConfiguration Create(ProductRegistration productRegistration = null) { var invoker = new InMemoryRequestInvoker(); var pool = new SingleNodePool(new Uri("http://localhost:9200")); - var settings = new TransportConfiguration(pool, invoker); + var settings = new TransportConfiguration(pool, invoker, productRegistration: productRegistration); return settings; } } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index de9fa2e..7e143a2 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -5,8 +5,10 @@ using System; using System.Collections.Generic; using System.IO; +using System.Text; using System.Threading; using System.Threading.Tasks; +using Elastic.Transport.Products; using Elastic.Transport.Tests.Plumbing; using FluentAssertions; using Xunit; @@ -19,52 +21,114 @@ public class ResponseBuilderDisposeTests private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); [Fact] - public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => await AssertResponse(false, expectedDisposed: false); + public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => + await AssertResponse(false, expectedDisposed: false); [Fact] - public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => await AssertResponse(true, expectedDisposed: false); + public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => + await AssertResponse(true, expectedDisposed: false); [Fact] - public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); + public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => + await AssertResponse(false, 204); [Fact] - public async Task StreamResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); + public async Task StreamResponseForHeadRequest_StreamIsDisposed() => + await AssertResponse(false, httpMethod: HttpMethod.HEAD); [Fact] - public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => + await AssertResponse(false, contentLength: 0); [Fact] - public async Task ResponseWithPotentialBody_StreamIsDisposed() => await AssertResponse(false, expectedDisposed: true); + public async Task ResponseWithPotentialBody_StreamIsDisposed() => + await AssertResponse(false, expectedDisposed: true); [Fact] - public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true); + public async Task ResponseWithPotentialBodyButInvalidMimeType_StreamIsDisposed() => + await AssertResponse(false, mimeType: "application/not-valid", expectedDisposed: true); [Fact] - public async Task ResponseWith204StatusCode_StreamIsDisposed() => await AssertResponse(false, 204); + public async Task ResponseWithPotentialBodyButSkippedStatusCode_StreamIsDisposed() => + await AssertResponse(false, skipStatusCode: 200, expectedDisposed: true); [Fact] - public async Task ResponseForHeadRequest_StreamIsDisposed() => await AssertResponse(false, httpMethod: HttpMethod.HEAD); + public async Task ResponseWithPotentialBodyButEmptyJson_StreamIsDisposed() => + await AssertResponse(false, responseJson: " ", expectedDisposed: true); [Fact] - public async Task ResponseWithZeroContentLength_StreamIsDisposed() => await AssertResponse(false, contentLength: 0); + // NOTE: The empty string here hits a fast path in STJ which returns default if the stream length is zero. + public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_StreamIsDisposed() => + await AssertResponse(false, responseJson: "", expectedDisposed: true); [Fact] - public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_StreamIsDisposed() => + await AssertResponse(false, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true); - private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true, int memoryStreamCreateExpected = -1) - where T : TransportResponse, new() + [Fact] + // NOTE: We expect one memory stream factory creation when handling error responses + public async Task ResponseWithPotentialBodyAndErrorResponse_StreamIsDisposed() => + await AssertResponse(false, productRegistration: new TestProductRegistration(), expectedDisposed: true, memoryStreamCreateExpected: 1); + + [Fact] + public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => + await AssertResponse(true, expectedDisposed: true); + + [Fact] + public async Task ResponseWith204StatusCode_StreamIsDisposed() => + await AssertResponse(false, 204); + + [Fact] + public async Task ResponseForHeadRequest_StreamIsDisposed() => + await AssertResponse(false, httpMethod: HttpMethod.HEAD); + + [Fact] + public async Task ResponseWithZeroContentLength_StreamIsDisposed() => + await AssertResponse(false, contentLength: 0); + + [Fact] + public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => + await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + + private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, + bool expectedDisposed = true, string mimeType = "application/json", string responseJson = "{}", int skipStatusCode = -1, + CustomResponseBuilder customResponseBuilder = null, ProductRegistration productRegistration = null, int memoryStreamCreateExpected = -1) + where T : TransportResponse, new() { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); + ITransportConfiguration config; - var requestData = new RequestData(httpMethod, "/", null, settings, null, null, memoryStreamFactory, default) + if (skipStatusCode > -1 ) + { + config = InMemoryConnectionFactory.Create(productRegistration) + .DisableDirectStreaming(disableDirectStreaming) + .SkipDeserializationForStatusCodes(skipStatusCode); + } + else if (productRegistration is not null) + { + config = InMemoryConnectionFactory.Create(productRegistration) + .DisableDirectStreaming(disableDirectStreaming); + } + else + { + config = disableDirectStreaming ? _settingsDisableDirectStream : _settings; + } + + var memoryStreamFactory = new TrackMemoryStreamFactory(); + + var requestData = new RequestData(httpMethod, "/", null, config, null, customResponseBuilder, memoryStreamFactory, default) { Node = new Node(new Uri("http://localhost:9200")) }; var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, contentLength, null, null); + if (!string.IsNullOrEmpty(responseJson)) + { + stream.Write(Encoding.UTF8.GetBytes(responseJson), 0, responseJson.Length); + stream.Position = 0; + } + + var response = config.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, mimeType, contentLength, null, null); response.Should().NotBeNull(); @@ -83,12 +147,12 @@ private async Task AssertResponse(bool disableDirectStreaming, int statusCode stream = new TrackDisposeStream(); var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + response = await config.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, cancellationToken: ct); response.Should().NotBeNull(); - memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1: disableDirectStreaming ? 2 : 0); + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1 : disableDirectStreaming ? 2 : 0); if (disableDirectStreaming) { var memoryStream = memoryStreamFactory.Created[0]; @@ -101,6 +165,66 @@ private async Task AssertResponse(bool disableDirectStreaming, int statusCode } } + private class TestProductRegistration : ProductRegistration + { + public override string DefaultMimeType => "application/json"; + public override string Name => "name"; + public override string ServiceIdentifier => "id"; + public override bool SupportsPing => false; + public override bool SupportsSniff => false; + public override HeadersList ResponseHeadersToParse => []; + public override MetaHeaderProvider MetaHeaderProvider => null; + public override string ProductAssemblyVersion => "0.0.0"; + public override IReadOnlyDictionary DefaultOpenTelemetryAttributes => new Dictionary(); + public override RequestData CreatePingRequestData(Node node, RequestConfiguration requestConfiguration, ITransportConfiguration global, MemoryStreamFactory memoryStreamFactory) => throw new NotImplementedException(); + public override RequestData CreateSniffRequestData(Node node, IRequestConfiguration requestConfiguration, ITransportConfiguration settings, MemoryStreamFactory memoryStreamFactory) => throw new NotImplementedException(); + public override IReadOnlyCollection DefaultHeadersToParse() => []; + public override bool HttpStatusCodeClassifier(HttpMethod method, int statusCode) => true; + public override bool NodePredicate(Node node) => throw new NotImplementedException(); + public override Dictionary ParseOpenTelemetryAttributesFromApiCallDetails(ApiCallDetails callDetails) => throw new NotImplementedException(); + public override TransportResponse Ping(IRequestInvoker requestInvoker, RequestData pingData) => throw new NotImplementedException(); + public override Task PingAsync(IRequestInvoker requestInvoker, RequestData pingData, CancellationToken cancellationToken) => throw new NotImplementedException(); + public override Tuple> Sniff(IRequestInvoker requestInvoker, bool forceSsl, RequestData requestData) => throw new NotImplementedException(); + public override Task>> SniffAsync(IRequestInvoker requestInvoker, bool forceSsl, RequestData requestData, CancellationToken cancellationToken) => throw new NotImplementedException(); + public override int SniffOrder(Node node) => throw new NotImplementedException(); + public override bool TryGetServerErrorReason(TResponse response, out string reason) => throw new NotImplementedException(); + public override ResponseBuilder ResponseBuilder => new TestErrorResponseBuilder(); + } + + private class TestError : ErrorResponse + { + public string MyError { get; set; } + + public override bool HasError() => true; + } + + private class TestErrorResponseBuilder : DefaultResponseBuilder + { + protected override void SetErrorOnResponse(TResponse response, TestError error) + { + // nothing to do in this scenario + } + + protected override bool TryGetError(ApiCallDetails apiCallDetails, RequestData requestData, Stream responseStream, out TestError error) + { + error = new TestError(); + return true; + } + + protected override bool RequiresErrorDeserialization(ApiCallDetails details, RequestData requestData) => true; + } + + private class TestCustomResponseBuilder : CustomResponseBuilder + { + public override object DeserializeResponse(Serializer serializer, ApiCallDetails response, Stream stream) => + new TestResponse { ApiCallDetails = response }; + + public override Task DeserializeResponseAsync(Serializer serializer, ApiCallDetails response, Stream stream, CancellationToken ctx = default) => + Task.FromResult(new TestResponse { ApiCallDetails = response }); + } + + private class TestRequestParameters : RequestParameters { } + private class TrackDisposeStream : MemoryStream { public TrackDisposeStream() { } From 1b4c515dfcede828cdc0b43bfa59cf627b1c1b71 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Tue, 29 Oct 2024 15:45:59 +0000 Subject: [PATCH 7/9] Refactoring based on latest decisions --- .../Pipeline/DefaultResponseBuilder.cs | 54 +++++++++---------- .../TransportClient/HttpRequestInvoker.cs | 7 ++- .../TransportClient/HttpWebRequestInvoker.cs | 7 ++- .../Responses/TransportResponse.cs | 3 ++ .../Http/StreamResponseTests.cs | 43 +++++++++++++-- .../ResponseBuilderDisposeTests.cs | 6 +-- 6 files changed, 82 insertions(+), 38 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 8e6461a..533c545 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -66,11 +66,8 @@ IReadOnlyDictionary tcpStats // Only attempt to set the body if the response may have content if (MayHaveBody(statusCode, requestData.Method, contentLength)) response = SetBody(details, requestData, responseStream, mimeType); - else - responseStream.Dispose(); response ??= new TResponse(); - response.ApiCallDetails = details; return response; } @@ -101,11 +98,8 @@ public override async Task ToResponseAsync( if (MayHaveBody(statusCode, requestData.Method, contentLength)) response = await SetBodyAsync(details, requestData, responseStream, mimeType, cancellationToken).ConfigureAwait(false); - else - responseStream.Dispose(); response ??= new TResponse(); - response.ApiCallDetails = details; return response; } @@ -211,6 +205,8 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, var disableDirectStreaming = requestData.PostData?.DisableDirectStreaming ?? requestData.ConnectionSettings.DisableDirectStreaming; var requiresErrorDeserialization = RequiresErrorDeserialization(details, requestData); + var ownsStream = false; + if (disableDirectStreaming || NeedsToEagerReadStream() || requiresErrorDeserialization) { var inMemoryStream = requestData.MemoryStreamFactory.Create(); @@ -221,22 +217,28 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, responseStream.CopyTo(inMemoryStream, BufferSize); bytes = SwapStreams(ref responseStream, ref inMemoryStream); + ownsStream = true; details.ResponseBodyInBytes = bytes; } - if (TrySetSpecialType(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + if (TrySetSpecialType(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var response)) + { + ConditionalDisposal(responseStream, ownsStream, response); + return response; + } if (details.HttpStatusCode.HasValue && requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) { // In this scenario, we always dispose as we've explicitly skipped reading the response - responseStream.Dispose(); + if (ownsStream) + responseStream.Dispose(); + return null; } var serializer = requestData.ConnectionSettings.RequestResponseSerializer; - TResponse response; if (requestData.CustomResponseBuilder != null) { var beforeTicks = Stopwatch.GetTimestamp(); @@ -253,9 +255,7 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - if (!response.LeaveOpen) - responseStream.Dispose(); - + ConditionalDisposal(responseStream, ownsStream, response); return response; } @@ -267,16 +267,13 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, { response = new TResponse(); SetErrorOnResponse(response, error); - - if (!response.LeaveOpen) - responseStream.Dispose(); - + ConditionalDisposal(responseStream, ownsStream, response); return response; } if (!requestData.ValidateResponseContentType(mimeType)) { - responseStream.Dispose(); + ConditionalDisposal(responseStream, ownsStream, response); return default; } @@ -292,18 +289,25 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - if (response is null || !response.LeaveOpen) - responseStream.Dispose(); - + ConditionalDisposal(responseStream, ownsStream, response); return response; } catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) { - // Note that this is only thrown after a check if the stream length is zero. When the length is zero, - // `default` is returned by Deserialize(Async) instead. - responseStream.Dispose(); + // Note the exception this handles is ONLY thrown after a check if the stream length is zero. + // When the length is zero, `default` is returned by Deserialize(Async) instead. + + ConditionalDisposal(responseStream, ownsStream, response); return default; } + + static void ConditionalDisposal(Stream responseStream, bool ownsStream, TResponse response) + { + // We only dispose of the responseStream if we created it (i.e. it is a MemoryStream) we + // created via MemoryStreamFactory. + if (ownsStream && (response is null || !response.LeaveOpen)) + responseStream.Dispose(); + } } private static bool TrySetSpecialType(string mimeType, byte[] bytes, Stream responseStream, @@ -341,9 +345,6 @@ private static bool TrySetSpecialType(string mimeType, byte[] bytes, } } - if (!response.LeaveOpen) - responseStream.Dispose(); - return response != null; } @@ -356,7 +357,6 @@ private static bool NeedsToEagerReadStream() private static byte[] SwapStreams(ref Stream responseStream, ref MemoryStream ms) { var bytes = ms.ToArray(); - responseStream.Dispose(); responseStream = ms; responseStream.Position = 0; return bytes; diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index fa215a1..4fd3335 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -171,10 +171,11 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req // resources can be linked such that their disposal is deferred. if (response.LeaveOpen) { - response.LinkedDisposables = [receivedResponse]; + response.LinkedDisposables = [receivedResponse, responseStream]; } else { + responseStream.Dispose(); receivedResponse.Dispose(); } @@ -192,7 +193,9 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req } catch { - receivedResponse.Dispose(); // if there's an exception, ensure we always release the response so that the connection is freed. + // if there's an exception, ensure we always release the stream and response so that the connection is freed. + responseStream.Dispose(); + receivedResponse.Dispose(); throw; } } diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index 265f952..6ced0e1 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -182,10 +182,11 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req // resources can be linked such that their disposal is deferred. if (response.LeaveOpen) { - response.LinkedDisposables = [receivedResponse]; + response.LinkedDisposables = [receivedResponse, responseStream]; } else { + responseStream.Dispose(); receivedResponse.Dispose(); } @@ -202,7 +203,9 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req } catch { - receivedResponse.Dispose(); // ensure we always release the response so the connection is freed. + // if there's an exception, ensure we always release the stream and response so that the connection is freed. + responseStream.Dispose(); + receivedResponse.Dispose(); throw; } } diff --git a/src/Elastic.Transport/Responses/TransportResponse.cs b/src/Elastic.Transport/Responses/TransportResponse.cs index 5f65983..a23f2e2 100644 --- a/src/Elastic.Transport/Responses/TransportResponse.cs +++ b/src/Elastic.Transport/Responses/TransportResponse.cs @@ -37,6 +37,9 @@ public override string ToString() => ApiCallDetails?.DebugInformation // ReSharper disable once ConstantNullCoalescingCondition ?? $"{nameof(ApiCallDetails)} not set, likely a bug, reverting to default ToString(): {base.ToString()}"; + /// + /// Allows other disposable resources to to be disposed along with the response. + /// [JsonIgnore] internal IEnumerable? LinkedDisposables { get; set; } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs index e006567..1933da8 100644 --- a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -33,6 +33,23 @@ public async Task StreamResponse_ShouldNotBeDisposed() _ = sr.ReadToEndAsync(); } + //[Fact] + //public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() + //{ + // var nodePool = new SingleNodePool(Server.Uri); + // var memoryStreamFactory = new TrackMemoryStreamFactory(); + // var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + // .MemoryStreamFactory(memoryStreamFactory); + + // var transport = new DistributedTransport(config); + + // _ = await transport.PostAsync(Path, PostData.String("{}")); + + // var memoryStream = memoryStreamFactory.Created.Last(); + + // memoryStream.IsDisposed.Should().BeFalse(); + //} + [Fact] public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() { @@ -52,13 +69,12 @@ public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() } [Fact] - public async Task StringResponse_MemoryStreamShouldBeDisposed() + public async Task StringResponse_MemoryStreamShouldBeDisposed() { var nodePool = new SingleNodePool(Server.Uri); var memoryStreamFactory = new TrackMemoryStreamFactory(); var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) - .MemoryStreamFactory(memoryStreamFactory) - .DisableDirectStreaming(true); + .MemoryStreamFactory(memoryStreamFactory); var transport = new DistributedTransport(config); @@ -69,6 +85,27 @@ public async Task StringResponse_MemoryStreamShouldBeDisposed() memoryStream.IsDisposed.Should().BeTrue(); } + [Fact] + public async Task Response_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + var memoryStream = memoryStreamFactory.Created.Last(); + + memoryStream.IsDisposed.Should().BeTrue(); + } + + private class TestResponse : TransportResponse + { + } + private class TrackDisposeStream : MemoryStream { public TrackDisposeStream() { } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index 7e143a2..742fe4f 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -29,8 +29,8 @@ public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_Memor await AssertResponse(true, expectedDisposed: false); [Fact] - public async Task StreamResponseWith204StatusCode_StreamIsDisposed() => - await AssertResponse(false, 204); + public async Task StreamResponseWith204StatusCode_MemoryStreamIsDisposed() => + await AssertResponse(true, 204); [Fact] public async Task StreamResponseForHeadRequest_StreamIsDisposed() => @@ -223,8 +223,6 @@ public override Task DeserializeResponseAsync(Serializer serializer, Api Task.FromResult(new TestResponse { ApiCallDetails = response }); } - private class TestRequestParameters : RequestParameters { } - private class TrackDisposeStream : MemoryStream { public TrackDisposeStream() { } From af68fe71ea68bfcafa49a3115f411835c96577e2 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Wed, 30 Oct 2024 08:45:11 +0000 Subject: [PATCH 8/9] Update tests and fix post data missed disposal --- .../Pipeline/DefaultResponseBuilder.cs | 6 +- .../Requests/Body/PostData.cs | 6 + .../Http/StreamResponseTests.cs | 106 +++++++++++++----- .../ResponseBuilderDisposeTests.cs | 72 +++--------- 4 files changed, 103 insertions(+), 87 deletions(-) diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 533c545..e7b2a77 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -230,10 +230,7 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, if (details.HttpStatusCode.HasValue && requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) { - // In this scenario, we always dispose as we've explicitly skipped reading the response - if (ownsStream) - responseStream.Dispose(); - + ConditionalDisposal(responseStream, ownsStream, response); return null; } @@ -296,7 +293,6 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, { // Note the exception this handles is ONLY thrown after a check if the stream length is zero. // When the length is zero, `default` is returned by Deserialize(Async) instead. - ConditionalDisposal(responseStream, ownsStream, response); return default; } diff --git a/src/Elastic.Transport/Requests/Body/PostData.cs b/src/Elastic.Transport/Requests/Body/PostData.cs index 68f2c52..9e59775 100644 --- a/src/Elastic.Transport/Requests/Body/PostData.cs +++ b/src/Elastic.Transport/Requests/Body/PostData.cs @@ -111,6 +111,7 @@ protected void FinishStream(Stream writableStream, MemoryStream buffer, ITranspo buffer.Position = 0; buffer.CopyTo(writableStream, BufferSize); WrittenBytes ??= buffer.ToArray(); + buffer.Dispose(); } /// @@ -132,5 +133,10 @@ protected async buffer.Position = 0; await buffer.CopyToAsync(writableStream, BufferSize, ctx).ConfigureAwait(false); WrittenBytes ??= buffer.ToArray(); +#if NET + await buffer.DisposeAsync().ConfigureAwait(false); +#else + buffer.Dispose(); +#endif } } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs index 1933da8..7ec6720 100644 --- a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using System.Text.Json; +using System.Text; using System.Threading.Tasks; using Elastic.Transport.IntegrationTests.Plumbing; using Elastic.Transport.Products.Elasticsearch; @@ -33,23 +33,6 @@ public async Task StreamResponse_ShouldNotBeDisposed() _ = sr.ReadToEndAsync(); } - //[Fact] - //public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() - //{ - // var nodePool = new SingleNodePool(Server.Uri); - // var memoryStreamFactory = new TrackMemoryStreamFactory(); - // var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) - // .MemoryStreamFactory(memoryStreamFactory); - - // var transport = new DistributedTransport(config); - - // _ = await transport.PostAsync(Path, PostData.String("{}")); - - // var memoryStream = memoryStreamFactory.Created.Last(); - - // memoryStream.IsDisposed.Should().BeFalse(); - //} - [Fact] public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() { @@ -63,9 +46,9 @@ public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() _ = await transport.PostAsync(Path, PostData.String("{}")); - var memoryStream = memoryStreamFactory.Created.Last(); - - memoryStream.IsDisposed.Should().BeFalse(); + // When disable direct streaming, we have 1 for the original content, 1 for the buffered request bytes and the last for the buffered response + memoryStreamFactory.Created.Count.Should().Be(3); + memoryStreamFactory.Created.Last().IsDisposed.Should().BeFalse(); } [Fact] @@ -80,13 +63,36 @@ public async Task StringResponse_MemoryStreamShouldBeDisposed() _ = await transport.PostAsync(Path, PostData.String("{}")); - var memoryStream = memoryStreamFactory.Created.Last(); + memoryStreamFactory.Created.Count.Should().Be(2); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } + } + + [Fact] + public async Task WhenInvalidJson_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + var payload = new Payload { ResponseJsonString = " " }; + _ = await transport.PostAsync(Path, PostData.Serializable(payload)); - memoryStream.IsDisposed.Should().BeTrue(); + memoryStreamFactory.Created.Count.Should().Be(3); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } } [Fact] - public async Task Response_MemoryStreamShouldBeDisposed() + public async Task WhenNoContent_MemoryStreamShouldBeDisposed() { var nodePool = new SingleNodePool(Server.Uri); var memoryStreamFactory = new TrackMemoryStreamFactory(); @@ -95,11 +101,37 @@ public async Task Response_MemoryStreamShouldBeDisposed() var transport = new DistributedTransport(config); - _ = await transport.PostAsync(Path, PostData.String("{}")); + var payload = new Payload { ResponseJsonString = "", StatusCode = 204 }; + _ = await transport.PostAsync(Path, PostData.Serializable(payload)); - var memoryStream = memoryStreamFactory.Created.Last(); + // We expect one for sending the request payload, but as the response is 204, we shouldn't + // see other memory streams being created for the response. + memoryStreamFactory.Created.Count.Should().Be(1); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } + } + + [Fact] + public async Task PlainText_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + var payload = new Payload { ResponseJsonString = "text", ContentType = "text/plain" }; + _ = await transport.PostAsync(Path, PostData.Serializable(payload)); - memoryStream.IsDisposed.Should().BeTrue(); + memoryStreamFactory.Created.Count.Should().Be(3); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } } private class TestResponse : TransportResponse @@ -150,9 +182,27 @@ public override MemoryStream Create(byte[] bytes, int index, int count) } } +public class Payload +{ + public string ResponseJsonString { get; set; } = "{}"; + public string ContentType { get; set; } = "application/json"; + public int StatusCode { get; set; } = 200; +} + [ApiController, Route("[controller]")] public class StreamResponseController : ControllerBase { [HttpPost] - public Task Post([FromBody] JsonElement body) => Task.FromResult(body); + public async Task Post([FromBody] Payload payload) + { + Response.ContentType = payload.ContentType; + + if (payload.StatusCode != 204) + { + await Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes(payload.ResponseJsonString)); + await Response.BodyWriter.CompleteAsync(); + } + + return StatusCode(payload.StatusCode); + } } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index 742fe4f..8b79c6a 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -29,66 +29,34 @@ public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_Memor await AssertResponse(true, expectedDisposed: false); [Fact] - public async Task StreamResponseWith204StatusCode_MemoryStreamIsDisposed() => - await AssertResponse(true, 204); + public async Task ResponseWithPotentialBodyButInvalidMimeType_MemoryStreamIsDisposed() => + await AssertResponse(true, mimeType: "application/not-valid", expectedDisposed: true); [Fact] - public async Task StreamResponseForHeadRequest_StreamIsDisposed() => - await AssertResponse(false, httpMethod: HttpMethod.HEAD); + public async Task ResponseWithPotentialBodyButSkippedStatusCode_MemoryStreamIsDisposed() => + await AssertResponse(true, skipStatusCode: 200, expectedDisposed: true); [Fact] - public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() => - await AssertResponse(false, contentLength: 0); - - [Fact] - public async Task ResponseWithPotentialBody_StreamIsDisposed() => - await AssertResponse(false, expectedDisposed: true); - - [Fact] - public async Task ResponseWithPotentialBodyButInvalidMimeType_StreamIsDisposed() => - await AssertResponse(false, mimeType: "application/not-valid", expectedDisposed: true); - - [Fact] - public async Task ResponseWithPotentialBodyButSkippedStatusCode_StreamIsDisposed() => - await AssertResponse(false, skipStatusCode: 200, expectedDisposed: true); - - [Fact] - public async Task ResponseWithPotentialBodyButEmptyJson_StreamIsDisposed() => - await AssertResponse(false, responseJson: " ", expectedDisposed: true); + public async Task ResponseWithPotentialBodyButEmptyJson_MemoryStreamIsDisposed() => + await AssertResponse(true, responseJson: " ", expectedDisposed: true); [Fact] // NOTE: The empty string here hits a fast path in STJ which returns default if the stream length is zero. - public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_StreamIsDisposed() => - await AssertResponse(false, responseJson: "", expectedDisposed: true); + public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_MemoryStreamIsDisposed() => + await AssertResponse(true, responseJson: "", expectedDisposed: true); [Fact] - public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_StreamIsDisposed() => - await AssertResponse(false, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true); + public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_MemoryStreamIsDisposed() => + await AssertResponse(true, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true); [Fact] // NOTE: We expect one memory stream factory creation when handling error responses public async Task ResponseWithPotentialBodyAndErrorResponse_StreamIsDisposed() => - await AssertResponse(false, productRegistration: new TestProductRegistration(), expectedDisposed: true, memoryStreamCreateExpected: 1); - - [Fact] - public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => - await AssertResponse(true, expectedDisposed: true); - - [Fact] - public async Task ResponseWith204StatusCode_StreamIsDisposed() => - await AssertResponse(false, 204); - - [Fact] - public async Task ResponseForHeadRequest_StreamIsDisposed() => - await AssertResponse(false, httpMethod: HttpMethod.HEAD); - - [Fact] - public async Task ResponseWithZeroContentLength_StreamIsDisposed() => - await AssertResponse(false, contentLength: 0); + await AssertResponse(true, productRegistration: new TestProductRegistration(), expectedDisposed: true); [Fact] public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => - await AssertResponse(true, expectedDisposed: true, memoryStreamCreateExpected: 1); + await AssertResponse(false, expectedDisposed: true, memoryStreamCreateExpected: 1); private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, bool expectedDisposed = true, string mimeType = "application/json", string responseJson = "{}", int skipStatusCode = -1, @@ -136,13 +104,11 @@ private async Task AssertResponse(bool disableDirectStreaming, int statusCode if (disableDirectStreaming) { var memoryStream = memoryStreamFactory.Created[0]; - stream.IsDisposed.Should().BeTrue(); memoryStream.IsDisposed.Should().Be(expectedDisposed); } - else - { - stream.IsDisposed.Should().Be(expectedDisposed); - } + + // The latest implementation should never dispose the incoming stream and assumes the caller will handler disposal + stream.IsDisposed.Should().Be(false); stream = new TrackDisposeStream(); var ct = new CancellationToken(); @@ -156,13 +122,11 @@ private async Task AssertResponse(bool disableDirectStreaming, int statusCode if (disableDirectStreaming) { var memoryStream = memoryStreamFactory.Created[0]; - stream.IsDisposed.Should().BeTrue(); memoryStream.IsDisposed.Should().Be(expectedDisposed); } - else - { - stream.IsDisposed.Should().Be(expectedDisposed); - } + + // The latest implementation should never dispose the incoming stream and assumes the caller will handler disposal + stream.IsDisposed.Should().Be(false); } private class TestProductRegistration : ProductRegistration From 20bcdf2650cf5d3f2fbbfd3fae343800bf3f4ce5 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Wed, 30 Oct 2024 08:49:55 +0000 Subject: [PATCH 9/9] Update comments --- src/Elastic.Transport/Responses/TransportResponse.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/Elastic.Transport/Responses/TransportResponse.cs b/src/Elastic.Transport/Responses/TransportResponse.cs index a23f2e2..4e7f4a8 100644 --- a/src/Elastic.Transport/Responses/TransportResponse.cs +++ b/src/Elastic.Transport/Responses/TransportResponse.cs @@ -40,9 +40,20 @@ public override string ToString() => ApiCallDetails?.DebugInformation /// /// Allows other disposable resources to to be disposed along with the response. /// + /// + /// While it's slightly confusing to have this on the base type which is NOT IDisposable, it avoids + /// specialised type checking in the request invoker and response builder code. Currently, only used by + /// StreamResponse and kept internal. If we later make this public, we might need to refine this. + /// [JsonIgnore] internal IEnumerable? LinkedDisposables { get; set; } + /// + /// Allows the response to identify that the response stream should NOT be automatically disposed. + /// + /// + /// Currently only used by StreamResponse and therefore internal. + /// [JsonIgnore] internal virtual bool LeaveOpen { get; } = false; }