Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package slo_aware_router
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"sync"
Expand Down Expand Up @@ -93,6 +94,10 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl
}
}

if err := parameters.validate(); err != nil {
return nil, fmt.Errorf("invalid SLOAwareRouter config: %w", err)
}

predictor, err := startPredictor(handle)
if err != nil {
return nil, fmt.Errorf("failed to start latency predictor: %w", err)
Expand All @@ -101,6 +106,50 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl
return NewSLOAwareRouter(parameters, predictor).WithName(name), nil
}

func (c *Config) validate() error {
var errs []error

if c.SamplingMean <= 0 {
errs = append(errs, fmt.Errorf("samplingMean must be > 0, got %f", c.SamplingMean))
}

if c.MaxSampledTokens <= 0 {
errs = append(errs, fmt.Errorf("maxSampledTokens must be > 0, got %d", c.MaxSampledTokens))
}

if c.SLOBufferFactor <= 0 {
errs = append(errs, fmt.Errorf("sloBufferFactor must be > 0, got %f", c.SLOBufferFactor))
}

if c.NegHeadroomTTFTWeight < 0 || c.NegHeadroomTPOTWeight < 0 ||
c.HeadroomTTFTWeight < 0 || c.HeadroomTPOTWeight < 0 {
errs = append(errs, errors.New("all headroom weights must be >= 0"))
}

if c.CompositeKVWeight < 0 || c.CompositeQueueWeight < 0 || c.CompositePrefixWeight < 0 {
errs = append(errs, errors.New("composite weights must be >= 0"))
}

if c.EpsilonExploreSticky < 0 || c.EpsilonExploreSticky > 1 {
errs = append(errs, fmt.Errorf("epsilonExploreSticky must be in [0, 1], got %f", c.EpsilonExploreSticky))
}
if c.EpsilonExploreNeg < 0 || c.EpsilonExploreNeg > 1 {
errs = append(errs, fmt.Errorf("epsilonExploreNeg must be in [0, 1], got %f", c.EpsilonExploreNeg))
}

if c.AffinityGateTau <= 0 || c.AffinityGateTau > 1 {
Copy link
Contributor

@kaushikmitr kaushikmitr Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AffinityGateTau can be 0

errs = append(errs, fmt.Errorf("affinityGateTau must be in (0, 1], got %f", c.AffinityGateTau))
}
if c.AffinityGateTauGlobal <= 0 || c.AffinityGateTauGlobal > 1 {
errs = append(errs, fmt.Errorf("affinityGateTauGlobal must be in (0, 1], got %f", c.AffinityGateTauGlobal))
}

if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}

func NewSLOAwareRouter(config Config, predictor latencypredictor.PredictorInterface) *SLOAwareRouter {
strategy := headroomStrategy(config.HeadroomSelectionStrategy)
if strategy == "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package slo_aware_router

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
Expand All @@ -31,6 +32,7 @@ import (
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
)

// mockPredictor implements PredictorInterface for testing
Expand Down Expand Up @@ -534,3 +536,161 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) {
})
}
}

func TestSLOAwareRouterFactory(t *testing.T) {
tests := []struct {
name string
pluginName string
jsonParams string
expectErr bool
}{
{
name: "valid config with all fields",
pluginName: "full-config",
jsonParams: `{
"samplingMean": 150.0,
"maxSampledTokens": 30,
"sloBufferFactor": 1.2,
"negHeadroomTTFTWeight": 0.7,
"negHeadroomTPOTWeight": 0.3,
"headroomTTFTWeight": 0.9,
"headroomTPOTWeight": 0.1,
"headroomSelectionStrategy": "least",
"compositeKVWeight": 1.0,
"compositeQueueWeight": 0.8,
"compositePrefixWeight": 0.5,
"epsilonExploreSticky": 0.02,
"epsilonExploreNeg": 0.03,
"affinityGateTau": 0.85,
"affinityGateTauGlobal": 0.95,
"selectionMode": "linear"
}`,
expectErr: false,
},
{
name: "valid config with minimal override (uses defaults)",
pluginName: "minimal",
jsonParams: `{}`,
expectErr: false,
},
{
name: "valid config with composite strategy",
pluginName: "composite",
jsonParams: `{
"headroomSelectionStrategy": "composite-least",
"selectionMode": "linear"
}`,
expectErr: false,
},
{
name: "invalid samplingMean <= 0",
pluginName: "bad-sampling-mean",
jsonParams: `{"samplingMean": -1.0}`,
expectErr: true,
},
{
name: "invalid maxSampledTokens <= 0",
pluginName: "bad-max-tokens",
jsonParams: `{"maxSampledTokens": 0}`,
expectErr: true,
},
{
name: "invalid sloBufferFactor <= 0",
pluginName: "bad-buffer",
jsonParams: `{"sloBufferFactor": 0}`,
expectErr: true,
},
{
name: "negative headroom weight",
pluginName: "neg-weight",
jsonParams: `{"negHeadroomTTFTWeight": -0.1}`,
expectErr: true,
},
{
name: "epsilonExploreSticky > 1",
pluginName: "epsilon-too-high",
jsonParams: `{"epsilonExploreSticky": 1.1}`,
expectErr: true,
},
{
name: "epsilonExploreNeg < 0",
pluginName: "epsilon-negative",
jsonParams: `{"epsilonExploreNeg": -0.1}`,
expectErr: true,
},
{
name: "affinityGateTau out of (0,1]",
pluginName: "tau-invalid",
jsonParams: `{"affinityGateTau": 1.5}`,
expectErr: true,
},
{
name: "affinityGateTauGlobal <= 0",
pluginName: "tau-global-zero",
jsonParams: `{"affinityGateTauGlobal": 0}`,
expectErr: true,
},
{
name: "multiple validation errors",
pluginName: "multi-error",
jsonParams: `{
"samplingMean": -1,
"maxSampledTokens": 0,
"epsilonExploreSticky": 2.0,
"headroomSelectionStrategy": "unknown"
}`,
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handle := utils.NewTestHandle(context.Background())
rawParams := json.RawMessage(tt.jsonParams)
plugin, err := SLOAwareRouterFactory(tt.pluginName, rawParams, handle)

if tt.expectErr {
assert.Error(t, err)
assert.Nil(t, plugin)
} else {
assert.NoError(t, err)
assert.NotNil(t, plugin)
}
})
}
}

func TestSLOAwareRouterFactoryInvalidJSON(t *testing.T) {
invalidTests := []struct {
name string
jsonParams string
}{
{
name: "malformed JSON",
jsonParams: `{"samplingMean": 100.0, "maxSampledTokens":`, // incomplete
},
{
name: "samplingMean as string",
jsonParams: `{"samplingMean": "100"}`,
},
{
name: "maxSampledTokens as float",
jsonParams: `{"maxSampledTokens": 20.5}`,
},
{
name: "headroomSelectionStrategy as number",
jsonParams: `{"headroomSelectionStrategy": 123}`,
},
}

for _, tt := range invalidTests {
t.Run(tt.name, func(t *testing.T) {
handle := utils.NewTestHandle(context.Background())
rawParams := json.RawMessage(tt.jsonParams)
plugin, err := SLOAwareRouterFactory("test", rawParams, handle)

assert.Error(t, err)
assert.Nil(t, plugin)
})
}
}