From c32694ff67dae7654a85a8a5e063035d7036500b Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 19 Sep 2024 15:08:02 +0200 Subject: [PATCH] Use `WithRequestLimit` with `0` to skip rate limit --- context.go | 16 ++++++---------- go.mod | 2 +- go.sum | 2 -- limiter.go | 24 ++++++++++++++++-------- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/context.go b/context.go index 138110d..3988dd6 100644 --- a/context.go +++ b/context.go @@ -13,20 +13,16 @@ func WithIncrement(ctx context.Context, value int) context.Context { return context.WithValue(ctx, incrementKey, value) } -func getIncrement(ctx context.Context) int { - if value, ok := ctx.Value(incrementKey).(int); ok { - return value - } - return 1 +func getIncrement(ctx context.Context) (int, bool) { + value, ok := ctx.Value(incrementKey).(int) + return value, ok } func WithRequestLimit(ctx context.Context, value int) context.Context { return context.WithValue(ctx, requestLimitKey, value) } -func getRequestLimit(ctx context.Context) int { - if value, ok := ctx.Value(requestLimitKey).(int); ok { - return value - } - return 0 +func getRequestLimit(ctx context.Context) (int, bool) { + value, ok := ctx.Value(requestLimitKey).(int) + return value, ok } diff --git a/go.mod b/go.mod index 998cbf5..a63a858 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,4 @@ go 1.17 require github.com/cespare/xxhash/v2 v2.3.0 -require golang.org/x/sync v0.7.0 // indirect +require golang.org/x/sync v0.7.0 diff --git a/go.sum b/go.sum index 09aebbf..2cc97dd 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= diff --git a/limiter.go b/limiter.go index d5ef436..fe52545 100644 --- a/limiter.go +++ b/limiter.go @@ -72,13 +72,26 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string currentWindow := time.Now().UTC().Truncate(l.windowLength) ctx := r.Context() - limit := l.requestLimit - if val := getRequestLimit(ctx); val > 0 { - limit = val + limit, ok := getRequestLimit(ctx) + if !ok { + limit = l.requestLimit } + + if limit <= 0 { + return false + } + setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) + increment, ok := getIncrement(r.Context()) + if !ok { + increment = 1 + } + if increment > 1 { + setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) + } + l.mu.Lock() _, rateFloat, err := l.calculateRate(key, limit) if err != nil { @@ -88,11 +101,6 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string } rate := int(math.Round(rateFloat)) - increment := getIncrement(r.Context()) - if increment > 1 { - setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) - } - if rate+increment > limit { setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate))