Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added an option to avoid interrupting the request pipeline #425

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public ClientRateLimitMiddleware(RequestDelegate next,
IClientPolicyStore policyStore,
IRateLimitConfiguration config,
ILogger<ClientRateLimitMiddleware> logger)
: base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, policyStore, processingStrategy), config)
: base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, policyStore, processingStrategy), config, logger)
{
_logger = logger;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public IpRateLimitMiddleware(RequestDelegate next,
IRateLimitConfiguration config,
ILogger<IpRateLimitMiddleware> logger
)
: base(next, options?.Value, new IpRateLimitProcessor(options?.Value, policyStore, processingStrategy), config)
: base(next, options?.Value, new IpRateLimitProcessor(options?.Value, policyStore, processingStrategy), config, logger)
{
_logger = logger;
}
Expand Down
108 changes: 62 additions & 46 deletions src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;

namespace AspNetCoreRateLimit
{
Expand All @@ -14,17 +16,20 @@ public abstract class RateLimitMiddleware<TProcessor>
private readonly TProcessor _processor;
private readonly RateLimitOptions _options;
private readonly IRateLimitConfiguration _config;
private readonly ILogger<RateLimitMiddleware<TProcessor>> _logger;

protected RateLimitMiddleware(
RequestDelegate next,
RateLimitOptions options,
TProcessor processor,
IRateLimitConfiguration config)
IRateLimitConfiguration config,
ILogger<RateLimitMiddleware<TProcessor>> logger)
{
_next = next;
_options = options;
_processor = processor;
_config = config;
_logger = logger;
_config.RegisterResolvers();
}

Expand All @@ -47,29 +52,51 @@ public async Task Invoke(HttpContext context)
return;
}

var rules = await _processor.GetMatchingRulesAsync(identity, context.RequestAborted);

var rulesDict = new Dictionary<RateLimitRule, RateLimitCounter>();

foreach (var rule in rules)
try
{
// increment counter
var rateLimitCounter = await _processor.ProcessRequestAsync(identity, rule, context.RequestAborted);
var rules = await _processor.GetMatchingRulesAsync(identity, context.RequestAborted);

if (rule.Limit > 0)
var rulesDict = new Dictionary<RateLimitRule, RateLimitCounter>();

foreach (var rule in rules)
{
// check if key expired
if (rateLimitCounter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow)
{
continue;
}
// increment counter
var rateLimitCounter = await _processor.ProcessRequestAsync(identity, rule, context.RequestAborted);

// check if limit is reached
if (rateLimitCounter.Count > rule.Limit)
if (rule.Limit > 0)
{
//compute retry after value
var retryAfter = rateLimitCounter.Timestamp.RetryAfterFrom(rule);
// check if key expired
if (rateLimitCounter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow)
{
continue;
}

// check if limit is reached
if (rateLimitCounter.Count > rule.Limit)
{
//compute retry after value
var retryAfter = rateLimitCounter.Timestamp.RetryAfterFrom(rule);

// log blocked request
LogBlockedRequest(context, identity, rateLimitCounter, rule);

if (_options.RequestBlockedBehaviorAsync != null)
{
await _options.RequestBlockedBehaviorAsync(context, identity, rateLimitCounter, rule);
}

if (!rule.MonitorMode)
{
// break execution
await ReturnQuotaExceededResponse(context, rule, retryAfter);

return;
}
}
}
// if limit is zero or less, block the request.
else
{
// log blocked request
LogBlockedRequest(context, identity, rateLimitCounter, rule);

Expand All @@ -80,45 +107,34 @@ public async Task Invoke(HttpContext context)

if (!rule.MonitorMode)
{
// break execution
await ReturnQuotaExceededResponse(context, rule, retryAfter);
// break execution (Int32 max used to represent infinity)
await ReturnQuotaExceededResponse(context, rule,
int.MaxValue.ToString(CultureInfo.InvariantCulture));

return;
}
}

rulesDict.Add(rule, rateLimitCounter);
}
// if limit is zero or less, block the request.
else
{
// log blocked request
LogBlockedRequest(context, identity, rateLimitCounter, rule);

if (_options.RequestBlockedBehaviorAsync != null)
{
await _options.RequestBlockedBehaviorAsync(context, identity, rateLimitCounter, rule);
}
// set X-Rate-Limit headers for the longest period
if (rulesDict.Any() && !_options.DisableRateLimitHeaders)
{
var rule = rulesDict.OrderByDescending(x => x.Key.PeriodTimespan).FirstOrDefault();
var headers = _processor.GetRateLimitHeaders(rule.Value, rule.Key, context.RequestAborted);

if (!rule.MonitorMode)
{
// break execution (Int32 max used to represent infinity)
await ReturnQuotaExceededResponse(context, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture));
headers.Context = context;

return;
}
context.Response.OnStarting(SetRateLimitHeaders, state: headers);
}

rulesDict.Add(rule, rateLimitCounter);
}

// set X-Rate-Limit headers for the longest period
if (rulesDict.Any() && !_options.DisableRateLimitHeaders)
catch (Exception e)
{
var rule = rulesDict.OrderByDescending(x => x.Key.PeriodTimespan).FirstOrDefault();
var headers = _processor.GetRateLimitHeaders(rule.Value, rule.Key, context.RequestAborted);

headers.Context = context;

context.Response.OnStarting(SetRateLimitHeaders, state: headers);
if (_options.DoNotInterruptRequestPipelineOnFailure)
_logger.LogError(e, "An error occured while processing the rate limit");
else
throw;
}

await _next.Invoke(context);
Expand Down
5 changes: 5 additions & 0 deletions src/AspNetCoreRateLimit/Models/RateLimitOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,10 @@ public class RateLimitOptions
/// Gets or sets behavior after the request is blocked
/// </summary>
public Func<HttpContext, ClientRequestIdentity, RateLimitCounter, RateLimitRule, Task> RequestBlockedBehaviorAsync { get; set; }

/// <summary>
/// Gets or sets the behavior that determines whether the request pipeline should be aborted in case of any rate limiting issues (i.e. Redis or SQLServer is not available when used as a distributed counter store).
/// </summary>
public bool DoNotInterruptRequestPipelineOnFailure { get; set; }
}
}