Skip to content

Commit

Permalink
Refactor RequestData (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mpdreamz authored Oct 31, 2024
1 parent f4c42c8 commit d5141a9
Show file tree
Hide file tree
Showing 48 changed files with 988 additions and 1,037 deletions.
15 changes: 8 additions & 7 deletions Elastic.Transport.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;/Entry.Match&gt;&#xD;
&lt;Entry.SortBy&gt;&#xD;
&lt;Kind Is="Member" /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="Fields"&gt;&#xD;
Expand All @@ -119,7 +119,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;Entry.SortBy&gt;&#xD;
&lt;Access /&gt;&#xD;
&lt;Readonly /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="Constructors"&gt;&#xD;
Expand All @@ -139,7 +139,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;/Entry.Match&gt;&#xD;
&lt;Entry.SortBy&gt;&#xD;
&lt;Access /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="Setup/Teardown Methods" Priority="100"&gt;&#xD;
Expand Down Expand Up @@ -203,7 +203,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;/Entry.Match&gt;&#xD;
&lt;Entry.SortBy&gt;&#xD;
&lt;Kind Is="Member" /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="Fields"&gt;&#xD;
Expand All @@ -218,7 +218,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;Entry.SortBy&gt;&#xD;
&lt;Access /&gt;&#xD;
&lt;Readonly /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="Constructors"&gt;&#xD;
Expand All @@ -238,7 +238,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;/Entry.Match&gt;&#xD;
&lt;Entry.SortBy&gt;&#xD;
&lt;Access /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="Interface Implementations"&gt;&#xD;
Expand All @@ -251,7 +251,7 @@ See the LICENSE file in the project root for more information</s:String>
&lt;Entry.SortBy&gt;&#xD;
&lt;ImplementsInterface Name="IDisposable" /&gt;&#xD;
&lt;Access /&gt;&#xD;
&lt;Name Is="Enter Pattern Here" /&gt;&#xD;
&lt;Name /&gt;&#xD;
&lt;/Entry.SortBy&gt;&#xD;
&lt;/Entry&gt;&#xD;
&lt;Entry DisplayName="All other members" /&gt;&#xD;
Expand Down Expand Up @@ -505,6 +505,7 @@ See the LICENSE file in the project root for more information</s:String>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ECSharpKeepExistingMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ECSharpPlaceEmbeddedOnSameLineMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ECSharpUseContinuousIndentInsideBracesMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002EMemberReordering_002EMigrations_002ECSharpFileLayoutPatternRemoveIsAttributeUpgrade/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EAlwaysTreatStructAsNotReorderableMigration/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/Environment/SettingsMigration/IsMigratorApplied/=JetBrains_002EReSharper_002EPsi_002ECSharp_002ECodeStyle_002ESettingsUpgrade_002EMigrateBlankLinesAroundFieldToBlankLinesAroundProperty/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/HighlightingManager/HighlightingEnabledByDefault/@EntryValue">False</s:Boolean>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

#nullable enable
using System;
using System.Collections.Generic;
using System.IO;
Expand Down Expand Up @@ -34,7 +35,7 @@ public class VirtualClusterRequestInvoker : IRequestInvoker
{
private static readonly object Lock = new();

private static byte[] _defaultResponseBytes;
private static byte[]? _defaultResponseBytes;

private VirtualCluster _cluster;
private readonly TestableDateTimeProvider _dateTimeProvider;
Expand All @@ -45,7 +46,9 @@ public class VirtualClusterRequestInvoker : IRequestInvoker

internal VirtualClusterRequestInvoker(VirtualCluster cluster, TestableDateTimeProvider dateTimeProvider)
{
UpdateCluster(cluster);
_cluster = cluster;
_calls = cluster.Nodes.ToDictionary(n => n.Uri.Port, v => new State());
_productRegistration = cluster.ProductRegistration;
_dateTimeProvider = dateTimeProvider;
_productRegistration = cluster.ProductRegistration;
_inMemoryRequestInvoker = new InMemoryRequestInvoker();
Expand Down Expand Up @@ -100,52 +103,55 @@ private static object DefaultResponse

private void UpdateCluster(VirtualCluster cluster)
{
if (cluster == null) return;

lock (Lock)
{
_cluster = cluster;
_calls = cluster.Nodes.ToDictionary(n => n.Uri.Port, v => new State());
_productRegistration = cluster.ProductRegistration;
}

}

private bool IsSniffRequest(RequestData requestData) => _productRegistration.IsSniffRequest(requestData);
private bool IsSniffRequest(Endpoint endpoint) => _productRegistration.IsSniffRequest(endpoint);

private bool IsPingRequest(RequestData requestData) => _productRegistration.IsPingRequest(requestData);
private bool IsPingRequest(Endpoint endpoint) => _productRegistration.IsPingRequest(endpoint);

/// <inheritdoc cref="IRequestInvoker.RequestAsync{TResponse}"/>>
public Task<TResponse> RequestAsync<TResponse>(RequestData requestData, CancellationToken cancellationToken)
public Task<TResponse> RequestAsync<TResponse>(Endpoint endpoint, RequestData requestData, PostData? postData, CancellationToken cancellationToken)
where TResponse : TransportResponse, new() =>
Task.FromResult(Request<TResponse>(requestData));
Task.FromResult(Request<TResponse>(endpoint, requestData, postData));

/// <inheritdoc cref="IRequestInvoker.Request{TResponse}"/>>
public TResponse Request<TResponse>(RequestData requestData)
public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData, PostData? postData)
where TResponse : TransportResponse, new()
{
if (!_calls.ContainsKey(requestData.Uri.Port))
throw new Exception($"Expected a call to happen on port {requestData.Uri.Port} but received none");
if (!_calls.ContainsKey(endpoint.Uri.Port))
throw new Exception($"Expected a call to happen on port {endpoint.Uri.Port} but received none");

try
{
var state = _calls[requestData.Uri.Port];
if (IsSniffRequest(requestData))
var state = _calls[endpoint.Uri.Port];
if (IsSniffRequest(endpoint))
{
_ = Interlocked.Increment(ref state.Sniffed);
return HandleRules<TResponse, ISniffRule>(
endpoint,
requestData,
postData,
nameof(VirtualCluster.Sniff),
_cluster.SniffingRules,
requestData.RequestTimeout,
(r) => UpdateCluster(r.NewClusterState),
(r) => _productRegistration.CreateSniffResponseBytes(_cluster.Nodes, _cluster.ElasticsearchVersion, _cluster.PublishAddressOverride, _cluster.SniffShouldReturnFqnd)
);
}
if (IsPingRequest(requestData))
if (IsPingRequest(endpoint))
{
_ = Interlocked.Increment(ref state.Pinged);
return HandleRules<TResponse, IRule>(
endpoint,
requestData,
postData,
nameof(VirtualCluster.Ping),
_cluster.PingingRules,
requestData.PingTimeout,
Expand All @@ -155,7 +161,9 @@ public TResponse Request<TResponse>(RequestData requestData)
}
_ = Interlocked.Increment(ref state.Called);
return HandleRules<TResponse, IClientCallRule>(
endpoint,
requestData,
postData,
nameof(VirtualCluster.ClientCalls),
_cluster.ClientCallRules,
requestData.RequestTimeout,
Expand All @@ -165,22 +173,23 @@ public TResponse Request<TResponse>(RequestData requestData)
}
catch (TheException e)
{
return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse<TResponse>(requestData, e, null, null, Stream.Null, null, -1, null, null);
return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse<TResponse>(endpoint, requestData, postData, e, null, null, Stream.Null, null, -1, null, null);
}
}

private TResponse HandleRules<TResponse, TRule>(
Endpoint endpoint,
RequestData requestData,
PostData? postData,
string origin,
IList<TRule> rules,
TimeSpan timeout,
Action<TRule> beforeReturn,
Func<TRule, byte[]> successResponse
Func<TRule, byte[]?> successResponse
)
where TResponse : TransportResponse, new()
where TRule : IRule
{
requestData.MadeItToResponse = true;
if (rules.Count == 0)
throw new Exception($"No {origin} defined for the current VirtualCluster, so we do not know how to respond");

Expand All @@ -189,32 +198,31 @@ Func<TRule, byte[]> successResponse
var always = rule.Times.Match(t => true, t => false);
var times = rule.Times.Match(t => -1, t => t);

if (rule.OnPort == null || rule.OnPort.Value != requestData.Uri.Port) continue;
if (rule.OnPort == null || rule.OnPort.Value != endpoint.Uri.Port) continue;

if (always)
return Always<TResponse, TRule>(requestData, timeout, beforeReturn, successResponse, rule);
return Always<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);

if (rule.ExecuteCount > times) continue;

return Sometimes<TResponse, TRule>(requestData, timeout, beforeReturn, successResponse, rule);
return Sometimes<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
}
foreach (var rule in rules.Where(s => !s.OnPort.HasValue))
{
var always = rule.Times.Match(t => true, t => false);
var times = rule.Times.Match(t => -1, t => t);
if (always)
return Always<TResponse, TRule>(requestData, timeout, beforeReturn, successResponse, rule);
return Always<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);

if (rule.ExecuteCount > times) continue;

return Sometimes<TResponse, TRule>(requestData, timeout, beforeReturn, successResponse, rule);
return Sometimes<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
}
var count = _calls.Select(kv => kv.Value.Called).Sum();
throw new Exception($@"No global or port specific {origin} rule ({requestData.Uri.Port}) matches any longer after {count} calls in to the cluster");
throw new Exception($@"No global or port specific {origin} rule ({endpoint.Uri.Port}) matches any longer after {count} calls in to the cluster");
}

private TResponse Always<TResponse, TRule>(RequestData requestData, TimeSpan timeout, Action<TRule> beforeReturn,
Func<TRule, byte[]> successResponse, TRule rule
private TResponse Always<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
Expand All @@ -231,12 +239,12 @@ private TResponse Always<TResponse, TRule>(RequestData requestData, TimeSpan tim
}

return rule.Succeeds
? Success<TResponse, TRule>(requestData, beforeReturn, successResponse, rule)
: Fail<TResponse, TRule>(requestData, rule);
? Success<TResponse, TRule>(endpoint, requestData, postData, beforeReturn, successResponse, rule)
: Fail<TResponse, TRule>(endpoint, requestData, postData, rule);
}

private TResponse Sometimes<TResponse, TRule>(
RequestData requestData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]> successResponse, TRule rule
Endpoint endpoint, RequestData requestData, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
Expand All @@ -253,16 +261,16 @@ private TResponse Sometimes<TResponse, TRule>(
}

if (rule.Succeeds)
return Success<TResponse, TRule>(requestData, beforeReturn, successResponse, rule);
return Success<TResponse, TRule>(endpoint, requestData, postData, beforeReturn, successResponse, rule);

return Fail<TResponse, TRule>(requestData, rule);
return Fail<TResponse, TRule>(endpoint, requestData, postData, rule);
}

private TResponse Fail<TResponse, TRule>(RequestData requestData, TRule rule, RuleOption<Exception, int> returnOverride = null)
private TResponse Fail<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, TRule rule, RuleOption<Exception, int>? returnOverride = null)
where TResponse : TransportResponse, new()
where TRule : IRule
{
var state = _calls[requestData.Uri.Port];
var state = _calls[endpoint.Uri.Port];
_ = Interlocked.Increment(ref state.Failures);
var ret = returnOverride ?? rule.Return;
rule.RecordExecuted();
Expand All @@ -271,25 +279,25 @@ private TResponse Fail<TResponse, TRule>(RequestData requestData, TRule rule, Ru
throw new TheException();

return ret.Match(
(e) => throw e,
(statusCode) => _inMemoryRequestInvoker.BuildResponse<TResponse>(requestData, CallResponse(rule),
e => throw e,
statusCode => _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, requestData, postData, CallResponse(rule),
//make sure we never return a valid status code in Fail responses because of a bad rule.
statusCode >= 200 && statusCode < 300 ? 502 : statusCode, rule.ReturnContentType)
);
}

private TResponse Success<TResponse, TRule>(RequestData requestData, Action<TRule> beforeReturn, Func<TRule, byte[]> successResponse,
private TResponse Success<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse,
TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
{
var state = _calls[requestData.Uri.Port];
var state = _calls[endpoint.Uri.Port];
_ = Interlocked.Increment(ref state.Successes);
rule.RecordExecuted();

beforeReturn?.Invoke(rule);
return _inMemoryRequestInvoker.BuildResponse<TResponse>(requestData, successResponse(rule), contentType: rule.ReturnContentType);
return _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, requestData, postData, successResponse(rule), contentType: rule.ReturnContentType);
}

private static byte[] CallResponse<TRule>(TRule rule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public class VirtualizedCluster
private Func<ITransport<ITransportConfiguration>, Func<RequestConfigurationDescriptor, IRequestConfiguration>, Task<TransportResponse>> _asyncCall;
private Func<ITransport<ITransportConfiguration>, Func<RequestConfigurationDescriptor, IRequestConfiguration>, TransportResponse> _syncCall;

private class VirtualResponse : TransportResponse { }
private class VirtualResponse : TransportResponse;

private static readonly EndpointPath RootPath = new(HttpMethod.GET, "/");

internal VirtualizedCluster(TestableDateTimeProvider dateTimeProvider, TransportConfiguration settings)
{
Expand All @@ -27,24 +29,20 @@ internal VirtualizedCluster(TestableDateTimeProvider dateTimeProvider, Transport
_exposingRequestPipeline = new ExposingPipelineFactory<ITransportConfiguration>(settings, _dateTimeProvider);

_syncCall = (t, r) => t.Request<VirtualResponse>(
method: HttpMethod.GET,
path: "/",
path: RootPath,
postData: PostData.Serializable(new { }),
requestParameters: new DefaultRequestParameters(),
openTelemetryData: default,
localConfiguration: r?.Invoke(new RequestConfigurationDescriptor(null)),
localConfiguration: r?.Invoke(new RequestConfigurationDescriptor()),
responseBuilder: null
);
_asyncCall = async (t, r) =>
{
var res = await t.RequestAsync<VirtualResponse>
(
method: HttpMethod.GET,
path: "/",
path: RootPath,
postData: PostData.Serializable(new { }),
requestParameters: new DefaultRequestParameters(),
openTelemetryData: default,
localConfiguration: r?.Invoke(new RequestConfigurationDescriptor(null)),
localConfiguration: r?.Invoke(new RequestConfigurationDescriptor()),
responseBuilder: null,
CancellationToken.None
).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ public sealed class ElasticsearchMockProductRegistration : MockProductRegistrati
public override byte[] CreateSniffResponseBytes(IReadOnlyList<Node> nodes, string stackVersion, string publishAddressOverride, bool returnFullyQualifiedDomainNames) =>
ElasticsearchSniffResponseFactory.Create(nodes, stackVersion, publishAddressOverride, returnFullyQualifiedDomainNames);

public override bool IsSniffRequest(RequestData requestData) =>
requestData.PathAndQuery.StartsWith(ElasticsearchProductRegistration.SniffPath, StringComparison.Ordinal);
public override bool IsSniffRequest(Endpoint endpoint) =>
endpoint.PathAndQuery.StartsWith(ElasticsearchProductRegistration.SniffPath, StringComparison.Ordinal);

public override bool IsPingRequest(RequestData requestData) =>
requestData.Method == HttpMethod.HEAD &&
(requestData.PathAndQuery == string.Empty || requestData.PathAndQuery.StartsWith("?"));
public override bool IsPingRequest(Endpoint endpoint) =>
endpoint.Method == HttpMethod.HEAD && (endpoint.PathAndQuery == string.Empty || endpoint.PathAndQuery.StartsWith("?"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public abstract class MockProductRegistration
/// see <see cref="VirtualClusterRequestInvoker.Request{TResponse}"/> uses this to determine if the current request is a sniff request and should follow
/// the sniffing rules
/// </summary>
public abstract bool IsSniffRequest(RequestData requestData);
public abstract bool IsSniffRequest(Endpoint endpoint);

public abstract bool IsPingRequest(RequestData requestData);
public abstract bool IsPingRequest(Endpoint endpoint);
}
Loading

0 comments on commit d5141a9

Please sign in to comment.