Skip to content

Commit

Permalink
Use WithRequestLimit with 0 to skip rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon committed Sep 19, 2024
1 parent ae11543 commit c32694f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
16 changes: 6 additions & 10 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@ func WithIncrement(ctx context.Context, value int) context.Context {
return context.WithValue(ctx, incrementKey, value)
}

func getIncrement(ctx context.Context) int {
if value, ok := ctx.Value(incrementKey).(int); ok {
return value
}
return 1
func getIncrement(ctx context.Context) (int, bool) {
value, ok := ctx.Value(incrementKey).(int)
return value, ok
}

func WithRequestLimit(ctx context.Context, value int) context.Context {
return context.WithValue(ctx, requestLimitKey, value)
}

func getRequestLimit(ctx context.Context) int {
if value, ok := ctx.Value(requestLimitKey).(int); ok {
return value
}
return 0
func getRequestLimit(ctx context.Context) (int, bool) {
value, ok := ctx.Value(requestLimitKey).(int)
return value, ok
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ go 1.17

require github.com/cespare/xxhash/v2 v2.3.0

require golang.org/x/sync v0.7.0 // indirect
require golang.org/x/sync v0.7.0
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
Expand Down
24 changes: 16 additions & 8 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,26 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string
currentWindow := time.Now().UTC().Truncate(l.windowLength)
ctx := r.Context()

limit := l.requestLimit
if val := getRequestLimit(ctx); val > 0 {
limit = val
limit, ok := getRequestLimit(ctx)
if !ok {
limit = l.requestLimit
}

if limit <= 0 {
return false
}

setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit))
setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))

increment, ok := getIncrement(r.Context())
if !ok {
increment = 1
}
if increment > 1 {
setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment))
}

l.mu.Lock()
_, rateFloat, err := l.calculateRate(key, limit)
if err != nil {
Expand All @@ -88,11 +101,6 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string
}
rate := int(math.Round(rateFloat))

increment := getIncrement(r.Context())
if increment > 1 {
setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment))
}

if rate+increment > limit {
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate))

Expand Down

0 comments on commit c32694f

Please sign in to comment.