diff --git a/pkg/kthena-router/datastore/store.go b/pkg/kthena-router/datastore/store.go index ab9e1211f..c85240aac 100644 --- a/pkg/kthena-router/datastore/store.go +++ b/pkg/kthena-router/datastore/store.go @@ -1099,7 +1099,7 @@ func (s *store) DeleteModelRoute(namespacedName string) error { s.triggerCallbacks("ModelRoute", EventData{ EventType: EventDelete, ModelName: modelName, - ModelRoute: nil, + ModelRoute: deletedRoute, }) return nil } diff --git a/pkg/kthena-router/filters/ratelimit/ratelimit.go b/pkg/kthena-router/filters/ratelimit/ratelimit.go index 1d40d1e7a..20f44e3da 100644 --- a/pkg/kthena-router/filters/ratelimit/ratelimit.go +++ b/pkg/kthena-router/filters/ratelimit/ratelimit.go @@ -98,7 +98,7 @@ func NewTokenRateLimiter() *TokenRateLimiter { } // RateLimit checks if the request is within rate limits for both input and output tokens -func (r *TokenRateLimiter) RateLimit(model, prompt string) error { +func (r *TokenRateLimiter) RateLimit(limiterKey, prompt string) error { // Estimate input tokens tokens, err := r.tokenizer.CalculateTokenNum(prompt) if err != nil { @@ -107,8 +107,8 @@ func (r *TokenRateLimiter) RateLimit(model, prompt string) error { } r.mutex.RLock() - inputLimiter, hasInputLimit := r.inputLimiter[model] - outputLimiter, hasOutputLimit := r.outputLimiter[model] + inputLimiter, hasInputLimit := r.inputLimiter[limiterKey] + outputLimiter, hasOutputLimit := r.outputLimiter[limiterKey] r.mutex.RUnlock() // Check input token rate limit @@ -126,9 +126,9 @@ func (r *TokenRateLimiter) RateLimit(model, prompt string) error { } // RecordOutputTokens records the actual output tokens consumed after response generation -func (r *TokenRateLimiter) RecordOutputTokens(model string, tokenCount int) { +func (r *TokenRateLimiter) RecordOutputTokens(limiterKey string, tokenCount int) { r.mutex.RLock() - outputLimiter, exists := r.outputLimiter[model] + outputLimiter, exists := r.outputLimiter[limiterKey] r.mutex.RUnlock() if exists { @@ -137,7 +137,7 @@ func (r *TokenRateLimiter) RecordOutputTokens(model string, tokenCount int) { } // AddOrUpdateLimiter adds or updates rate limiter for a model -func (r *TokenRateLimiter) AddOrUpdateLimiter(model string, ratelimit *networkingv1alpha1.RateLimit) error { +func (r *TokenRateLimiter) AddOrUpdateLimiter(limiterKey string, ratelimit *networkingv1alpha1.RateLimit) error { r.mutex.Lock() defer r.mutex.Unlock() @@ -161,10 +161,10 @@ func (r *TokenRateLimiter) AddOrUpdateLimiter(model string, ratelimit *networkin // Create global rate limiters if ratelimit.InputTokensPerUnit != nil { - r.inputLimiter[model] = NewGlobalRateLimiter( + r.inputLimiter[limiterKey] = NewGlobalRateLimiter( r.redisClient, "kthena:ratelimit", - model, + limiterKey, "input", *ratelimit.InputTokensPerUnit, ratelimit.Unit, @@ -172,10 +172,10 @@ func (r *TokenRateLimiter) AddOrUpdateLimiter(model string, ratelimit *networkin } if ratelimit.OutputTokensPerUnit != nil { - r.outputLimiter[model] = NewGlobalRateLimiter( + r.outputLimiter[limiterKey] = NewGlobalRateLimiter( r.redisClient, "kthena:ratelimit", - model, + limiterKey, "output", *ratelimit.OutputTokensPerUnit, ratelimit.Unit, @@ -186,14 +186,14 @@ func (r *TokenRateLimiter) AddOrUpdateLimiter(model string, ratelimit *networkin duration := getTimeUnitDuration(ratelimit.Unit) if ratelimit.InputTokensPerUnit != nil { - r.inputLimiter[model] = NewLocalLimiter( + r.inputLimiter[limiterKey] = NewLocalLimiter( rate.Limit(float64(*ratelimit.InputTokensPerUnit)/duration.Seconds()), int(*ratelimit.InputTokensPerUnit), ) } if ratelimit.OutputTokensPerUnit != nil { - r.outputLimiter[model] = NewLocalLimiter( + r.outputLimiter[limiterKey] = NewLocalLimiter( rate.Limit(float64(*ratelimit.OutputTokensPerUnit)/duration.Seconds()), int(*ratelimit.OutputTokensPerUnit), ) @@ -204,12 +204,12 @@ func (r *TokenRateLimiter) AddOrUpdateLimiter(model string, ratelimit *networkin } // DeleteLimiter deletes rate limiter for a model -func (r *TokenRateLimiter) DeleteLimiter(model string) { +func (r *TokenRateLimiter) DeleteLimiter(limiterKey string) { r.mutex.Lock() defer r.mutex.Unlock() - delete(r.inputLimiter, model) - delete(r.outputLimiter, model) + delete(r.inputLimiter, limiterKey) + delete(r.outputLimiter, limiterKey) } func getTimeUnitDuration(unit networkingv1alpha1.RateLimitUnit) time.Duration { diff --git a/pkg/kthena-router/filters/ratelimit/ratelimit_test.go b/pkg/kthena-router/filters/ratelimit/ratelimit_test.go index cdaad0365..d9a3759ea 100644 --- a/pkg/kthena-router/filters/ratelimit/ratelimit_test.go +++ b/pkg/kthena-router/filters/ratelimit/ratelimit_test.go @@ -17,6 +17,7 @@ limitations under the License. package ratelimit import ( + "fmt" "testing" "time" @@ -25,26 +26,26 @@ import ( func TestTokenRateLimiter_Basic(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" + limiterKey := testLimiterKey("default", "test-route") prompt := "hello world" // 3 tokens tokens := uint32(10) unit := networkingv1alpha1.Second - rl.AddOrUpdateLimiter(model, &networkingv1alpha1.RateLimit{ + rl.AddOrUpdateLimiter(limiterKey, &networkingv1alpha1.RateLimit{ InputTokensPerUnit: &tokens, Unit: unit, }) // Should allow up to 10 tokens immediately for i := 0; i < 3; i++ { - err := rl.RateLimit(model, prompt) + err := rl.RateLimit(limiterKey, prompt) if err != nil { t.Fatalf("unexpected error on allowed request: %v, %d", err, i) } } // 4th request should be rate limited - err := rl.RateLimit(model, prompt) + err := rl.RateLimit(limiterKey, prompt) if err == nil { t.Fatalf("expected rate limit error, got nil") } @@ -56,7 +57,7 @@ func TestTokenRateLimiter_Basic(t *testing.T) { func TestTokenRateLimiter_NoLimiter(t *testing.T) { rl := NewTokenRateLimiter() // No limiter added, should always allow - err := rl.RateLimit("unknown-model", "test") + err := rl.RateLimit(testLimiterKey("unknown", "route"), "test") if err != nil { t.Fatalf("expected nil error for unknown model, got %v", err) } @@ -64,25 +65,25 @@ func TestTokenRateLimiter_NoLimiter(t *testing.T) { func TestTokenRateLimiter_ResetAfterTime(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" + limiterKey := testLimiterKey("default", "test-route") prompt := "hello world" tokens := uint32(10) unit := networkingv1alpha1.Second - rl.AddOrUpdateLimiter(model, &networkingv1alpha1.RateLimit{ + rl.AddOrUpdateLimiter(limiterKey, &networkingv1alpha1.RateLimit{ InputTokensPerUnit: &tokens, Unit: unit, }) // Use up tokens for i := 0; i < 3; i++ { - err := rl.RateLimit(model, prompt) + err := rl.RateLimit(limiterKey, prompt) if err != nil { t.Fatalf("unexpected error: %v", err) } } // Should be rate limited now - err := rl.RateLimit(model, prompt) + err := rl.RateLimit(limiterKey, prompt) if err == nil { t.Fatalf("expected rate limit error, got nil") } @@ -92,7 +93,7 @@ func TestTokenRateLimiter_ResetAfterTime(t *testing.T) { // Wait for refill time.Sleep(1100 * time.Millisecond) - err = rl.RateLimit(model, prompt) + err = rl.RateLimit(limiterKey, prompt) if err != nil { t.Fatalf("expected nil after refill, got %v", err) } @@ -100,48 +101,48 @@ func TestTokenRateLimiter_ResetAfterTime(t *testing.T) { func TestTokenRateLimiter_OutputTokenRecording(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" + limiterKey := testLimiterKey("default", "test-route") tokens := uint32(10) unit := networkingv1alpha1.Second - rl.AddOrUpdateLimiter(model, &networkingv1alpha1.RateLimit{ + rl.AddOrUpdateLimiter(limiterKey, &networkingv1alpha1.RateLimit{ OutputTokensPerUnit: &tokens, Unit: unit, }) // Record output tokens - this should not block/error - rl.RecordOutputTokens(model, 5) - rl.RecordOutputTokens(model, 3) - rl.RecordOutputTokens(model, 2) // Total: 10 tokens consumed + rl.RecordOutputTokens(limiterKey, 5) + rl.RecordOutputTokens(limiterKey, 3) + rl.RecordOutputTokens(limiterKey, 2) // Total: 10 tokens consumed // Recording more tokens should still work (just consumes from the bucket) - rl.RecordOutputTokens(model, 1) + rl.RecordOutputTokens(limiterKey, 1) } func TestTokenRateLimiter_CombinedInputOutput(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" + limiterKey := testLimiterKey("default", "test-route") prompt := "hello world hello world" // Should be ~6 tokens inputTokens := uint32(8) // Allow only one request (6 tokens < 8, but two requests = 12 > 8) outputTokens := uint32(10) // Allow output recording unit := networkingv1alpha1.Second - rl.AddOrUpdateLimiter(model, &networkingv1alpha1.RateLimit{ + rl.AddOrUpdateLimiter(limiterKey, &networkingv1alpha1.RateLimit{ InputTokensPerUnit: &inputTokens, OutputTokensPerUnit: &outputTokens, Unit: unit, }) // First request should be allowed - err := rl.RateLimit(model, prompt) + err := rl.RateLimit(limiterKey, prompt) if err != nil { t.Fatalf("unexpected error on first request: %v", err) } // Record output tokens used - rl.RecordOutputTokens(model, 2) + rl.RecordOutputTokens(limiterKey, 2) // Second request should be rate limited due to input token exhaustion - err = rl.RateLimit(model, prompt) + err = rl.RateLimit(limiterKey, prompt) if err == nil { t.Fatalf("expected rate limit error after exhausting input tokens") } @@ -153,72 +154,72 @@ func TestTokenRateLimiter_CombinedInputOutput(t *testing.T) { func TestTokenRateLimiter_OutputNoLimiter(t *testing.T) { rl := NewTokenRateLimiter() // No limiter added, should not error when recording output tokens - rl.RecordOutputTokens("unknown-model", 100) + rl.RecordOutputTokens(testLimiterKey("unknown", "route"), 100) // RecordOutputTokens doesn't return error, just silently does nothing } func TestTokenRateLimiter_DeleteLimiter(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" + limiterKey := testLimiterKey("default", "test-route") inputTokens := uint32(3) // Very restrictive outputTokens := uint32(5) unit := networkingv1alpha1.Second - rl.AddOrUpdateLimiter(model, &networkingv1alpha1.RateLimit{ + rl.AddOrUpdateLimiter(limiterKey, &networkingv1alpha1.RateLimit{ InputTokensPerUnit: &inputTokens, OutputTokensPerUnit: &outputTokens, Unit: unit, }) // Verify limiter exists and restricts - err := rl.RateLimit(model, "hello world") // ~3 tokens + err := rl.RateLimit(limiterKey, "hello world") // ~3 tokens if err != nil { t.Fatalf("first request should be allowed: %v", err) } - err = rl.RateLimit(model, "hello world") // Should be rate limited + err = rl.RateLimit(limiterKey, "hello world") // Should be rate limited if err == nil { t.Fatalf("expected rate limit error") } // Delete limiters - rl.DeleteLimiter(model) + rl.DeleteLimiter(limiterKey) // Should now be unrestricted for i := 0; i < 10; i++ { - err = rl.RateLimit(model, "hello world") + err = rl.RateLimit(limiterKey, "hello world") if err != nil { t.Fatalf("expected nil after deletion, got %v", err) } } // Recording output tokens should work without error - rl.RecordOutputTokens(model, 100) + rl.RecordOutputTokens(limiterKey, 100) } func TestTokenRateLimiter_OutputRateLimit(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" + limiterKey := testLimiterKey("default", "test-route") prompt := "hello world" outputTokens := uint32(5) // Very low limit unit := networkingv1alpha1.Second - rl.AddOrUpdateLimiter(model, &networkingv1alpha1.RateLimit{ + rl.AddOrUpdateLimiter(limiterKey, &networkingv1alpha1.RateLimit{ OutputTokensPerUnit: &outputTokens, Unit: unit, }) // First request should be allowed (has 5 tokens available) - err := rl.RateLimit(model, prompt) + err := rl.RateLimit(limiterKey, prompt) if err != nil { t.Fatalf("first request should be allowed: %v", err) } // Consume most tokens - rl.RecordOutputTokens(model, 5) + rl.RecordOutputTokens(limiterKey, 5) // Next request should be blocked due to insufficient output tokens - err = rl.RateLimit(model, prompt) + err = rl.RateLimit(limiterKey, prompt) if err == nil { t.Fatalf("expected output rate limit error") } @@ -229,19 +230,19 @@ func TestTokenRateLimiter_OutputRateLimit(t *testing.T) { func TestTokenRateLimiter_InputAndOutputErrors(t *testing.T) { rl := NewTokenRateLimiter() - model := "test-model" longPrompt := "hello world hello world hello world" // Should be ~9 tokens inputTokens := uint32(5) // Very low input limit outputTokens := uint32(10) // Higher output limit unit := networkingv1alpha1.Second // Test input rate limit error - rl.AddOrUpdateLimiter(model+"-input", &networkingv1alpha1.RateLimit{ + inputLimiterKey := testLimiterKey("default", "test-route-input") + rl.AddOrUpdateLimiter(inputLimiterKey, &networkingv1alpha1.RateLimit{ InputTokensPerUnit: &inputTokens, Unit: unit, }) - err := rl.RateLimit(model+"-input", longPrompt) + err := rl.RateLimit(inputLimiterKey, longPrompt) if err == nil { t.Fatalf("expected input rate limit error") } @@ -250,25 +251,26 @@ func TestTokenRateLimiter_InputAndOutputErrors(t *testing.T) { } // Test output rate limit error - rl.AddOrUpdateLimiter(model+"-output", &networkingv1alpha1.RateLimit{ + outputLimiterKey := testLimiterKey("default", "test-route-output") + rl.AddOrUpdateLimiter(outputLimiterKey, &networkingv1alpha1.RateLimit{ OutputTokensPerUnit: &outputTokens, Unit: unit, }) // First make a successful request to establish the limiter - err = rl.RateLimit(model+"-output", "short") + err = rl.RateLimit(outputLimiterKey, "short") if err != nil { t.Fatalf("first request should succeed: %v", err) } // Consume all available output tokens - rl.RecordOutputTokens(model+"-output", 10) // Consume all 10 tokens + rl.RecordOutputTokens(outputLimiterKey, 10) // Consume all 10 tokens // Wait a bit for the tokens to be recorded time.Sleep(10 * time.Millisecond) // Next request should be blocked due to insufficient output tokens (< 1 token available) - err = rl.RateLimit(model+"-output", "short") // Short prompt to avoid input limit + err = rl.RateLimit(outputLimiterKey, "short") // Short prompt to avoid input limit if err == nil { t.Fatalf("expected output rate limit error") } @@ -276,3 +278,7 @@ func TestTokenRateLimiter_InputAndOutputErrors(t *testing.T) { t.Fatalf("expected OutputRateLimitExceededError, got %T: %v", err, err) } } + +func testLimiterKey(namespace, modelRouteName string) string { + return fmt.Sprintf("%s/%s", namespace, modelRouteName) +} diff --git a/pkg/kthena-router/router/router.go b/pkg/kthena-router/router/router.go index d0f2edce8..d1f9bb7d4 100644 --- a/pkg/kthena-router/router/router.go +++ b/pkg/kthena-router/router/router.go @@ -101,21 +101,27 @@ func NewRouter(store datastore.Store, routerConfigPath string) *Router { tokenizerInstance := tokenizer.NewSimpleEstimateTokenizer() store.RegisterCallback("ModelRoute", func(data datastore.EventData) { + routeKey := fmt.Sprintf("%s/%s", + data.ModelRoute.Namespace, + data.ModelRoute.Name, + ) switch data.EventType { case datastore.EventAdd, datastore.EventUpdate: if data.ModelRoute == nil || data.ModelRoute.Spec.RateLimit == nil { return } - klog.Infof("add or update rate limit for model %s", data.ModelName) + // Use namespace/routename as the rate limit key + klog.Infof("add or update rate limit for route %s", routeKey) - // Configure the unified rate limiter for this model - if err := loadRateLimiter.AddOrUpdateLimiter(data.ModelName, data.ModelRoute.Spec.RateLimit); err != nil { - klog.Errorf("failed to configure rate limiter for model %s: %v", data.ModelName, err) + // Configure the unified rate limiter for this route + if err := loadRateLimiter.AddOrUpdateLimiter(routeKey, data.ModelRoute.Spec.RateLimit); err != nil { + klog.Errorf("failed to configure rate limiter for route %s: %v", routeKey, err) } case datastore.EventDelete: - klog.Infof("delete rate limit for model %s", data.ModelName) - loadRateLimiter.DeleteLimiter(data.ModelName) + // Use namespace/routename as the rate limit key + klog.Infof("delete rate limit for route %s", routeKey) + loadRateLimiter.DeleteLimiter(routeKey) } }) @@ -228,6 +234,28 @@ func (r *Router) HandlerFunc() gin.HandlerFunc { // Store model name in context for metrics middleware c.Set("model", modelName) + // Get gateway key from context if available (set by Gateway listener) + var gatewayKey string + if key, exists := c.Get(GatewayKey); exists { + if k, ok := key.(string); ok { + gatewayKey = k + } + } + if gatewayKey != "" { + accesslog.SetGatewayAPIInfo(c, gatewayKey, "", "") + } + + // Early route matching + matchedModelServerName, matchedIsLora, matchedModelRoute, matchedMatchError := r.store.MatchModelServer(modelName, c.Request, gatewayKey) + c.Set("matchedModelServerName", matchedModelServerName) + c.Set("matchedIsLora", matchedIsLora) + if matchedModelRoute != nil { + c.Set("matchedModelRoute", matchedModelRoute) + } + if matchedMatchError != nil { + c.Set("matchedMatchError", matchedMatchError) + } + // Create metrics recorder for this request path := c.Request.URL.Path metricsRecorder := metrics.NewRequestMetricsRecorder(r.metrics, modelName, path) @@ -274,8 +302,18 @@ func (r *Router) HandlerFunc() gin.HandlerFunc { // Record input tokens immediately metricsRecorder.RecordInputTokens(inputTokens) + // Determine rate limit key + var rateLimitKey string + + if matchedModelRoute != nil { + rateLimitKey = fmt.Sprintf("%s/%s", matchedModelRoute.Namespace, matchedModelRoute.Name) + } else { + // HTTPRoute or fallback: use model-scoped + rateLimitKey = modelName + } + // Apply rate limiting using the unified rate limiter - if err := r.loadRateLimiter.RateLimit(modelName, promptStr); err != nil { + if err := r.loadRateLimiter.RateLimit(rateLimitKey, promptStr); err != nil { var errorMsg string var errorType string var tokenType string @@ -348,8 +386,23 @@ func (r *Router) doLoadbalance(c *gin.Context, modelRequest ModelRequest) { var isLora bool var err error - // Try to match ModelRoute first - modelServerName, isLora, modelRoute, err = r.store.MatchModelServer(modelName, c.Request, gatewayKey) + // Retrieve cached ModelRoute matching results from the context + if cachedServerName, exists := c.Get("matchedModelServerName"); exists { + modelServerName = cachedServerName.(types.NamespacedName) + if cachedIsLora, ok := c.Get("matchedIsLora"); ok { + isLora = cachedIsLora.(bool) + } + if cachedRoute, ok := c.Get("matchedModelRoute"); ok { + modelRoute = cachedRoute.(*v1alpha1.ModelRoute) + } + if cachedErr, ok := c.Get("matchedMatchError"); ok { + err = cachedErr.(error) + } + } else { + // Fallback to match if not cached + modelServerName, isLora, modelRoute, err = r.store.MatchModelServer(modelName, c.Request, gatewayKey) + } + if err != nil { accesslog.SetError(c, "model_server_matching", fmt.Sprintf("can't find corresponding model server: %v", err)) } @@ -806,9 +859,11 @@ func (r *Router) proxyModelEndpoint( if resp.Usage.TotalTokens <= 0 { return } - // Record output tokens for rate limiting + // Record output tokens for rate limiting using the route key if r.loadRateLimiter != nil { - r.loadRateLimiter.RecordOutputTokens(modelName, resp.Usage.CompletionTokens) + if rateLimitKeyVal, ok := c.Get("rateLimitKey"); ok { + r.loadRateLimiter.RecordOutputTokens(rateLimitKeyVal.(string), resp.Usage.CompletionTokens) + } } // Update access log with output tokens if accessCtx := accesslog.GetAccessLogContext(c); accessCtx != nil { @@ -1110,7 +1165,9 @@ func (r *Router) proxyToPDDisaggregated( // Record output tokens for rate limiting if outputTokens > 0 && r.loadRateLimiter != nil { - r.loadRateLimiter.RecordOutputTokens(ctx.Model, outputTokens) + if rateLimitKeyVal, ok := c.Get("rateLimitKey"); ok { + r.loadRateLimiter.RecordOutputTokens(rateLimitKeyVal.(string), outputTokens) + } } // Record output token metrics