diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index c72ec5874..d0e77cca3 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -19,6 +19,7 @@ package slo_aware_router import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "sync" @@ -93,6 +94,10 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl } } + if err := parameters.validate(); err != nil { + return nil, fmt.Errorf("invalid SLOAwareRouter config: %w", err) + } + predictor, err := startPredictor(handle) if err != nil { return nil, fmt.Errorf("failed to start latency predictor: %w", err) @@ -101,6 +106,50 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl return NewSLOAwareRouter(parameters, predictor).WithName(name), nil } +func (c *Config) validate() error { + var errs []error + + if c.SamplingMean <= 0 { + errs = append(errs, fmt.Errorf("samplingMean must be > 0, got %f", c.SamplingMean)) + } + + if c.MaxSampledTokens <= 0 { + errs = append(errs, fmt.Errorf("maxSampledTokens must be > 0, got %d", c.MaxSampledTokens)) + } + + if c.SLOBufferFactor <= 0 { + errs = append(errs, fmt.Errorf("sloBufferFactor must be > 0, got %f", c.SLOBufferFactor)) + } + + if c.NegHeadroomTTFTWeight < 0 || c.NegHeadroomTPOTWeight < 0 || + c.HeadroomTTFTWeight < 0 || c.HeadroomTPOTWeight < 0 { + errs = append(errs, errors.New("all headroom weights must be >= 0")) + } + + if c.CompositeKVWeight < 0 || c.CompositeQueueWeight < 0 || c.CompositePrefixWeight < 0 { + errs = append(errs, errors.New("composite weights must be >= 0")) + } + + if c.EpsilonExploreSticky < 0 || c.EpsilonExploreSticky > 1 { + errs = append(errs, fmt.Errorf("epsilonExploreSticky must be in [0, 1], got %f", c.EpsilonExploreSticky)) + } + if c.EpsilonExploreNeg < 0 || c.EpsilonExploreNeg > 1 { + errs = append(errs, fmt.Errorf("epsilonExploreNeg must be in [0, 1], got %f", c.EpsilonExploreNeg)) + } + + if c.AffinityGateTau <= 0 || c.AffinityGateTau > 1 { + errs = append(errs, fmt.Errorf("affinityGateTau must be in (0, 1], got %f", c.AffinityGateTau)) + } + if c.AffinityGateTauGlobal <= 0 || c.AffinityGateTauGlobal > 1 { + errs = append(errs, fmt.Errorf("affinityGateTauGlobal must be in (0, 1], got %f", c.AffinityGateTauGlobal)) + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + func NewSLOAwareRouter(config Config, predictor latencypredictor.PredictorInterface) *SLOAwareRouter { strategy := headroomStrategy(config.HeadroomSelectionStrategy) if strategy == "" { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go index 21aaa375d..e4a28dae6 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -18,6 +18,7 @@ package slo_aware_router import ( "context" + "encoding/json" "errors" "fmt" "strconv" @@ -31,6 +32,7 @@ import ( schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" + "sigs.k8s.io/gateway-api-inference-extension/test/utils" ) // mockPredictor implements PredictorInterface for testing @@ -534,3 +536,161 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) { }) } } + +func TestSLOAwareRouterFactory(t *testing.T) { + tests := []struct { + name string + pluginName string + jsonParams string + expectErr bool + }{ + { + name: "valid config with all fields", + pluginName: "full-config", + jsonParams: `{ + "samplingMean": 150.0, + "maxSampledTokens": 30, + "sloBufferFactor": 1.2, + "negHeadroomTTFTWeight": 0.7, + "negHeadroomTPOTWeight": 0.3, + "headroomTTFTWeight": 0.9, + "headroomTPOTWeight": 0.1, + "headroomSelectionStrategy": "least", + "compositeKVWeight": 1.0, + "compositeQueueWeight": 0.8, + "compositePrefixWeight": 0.5, + "epsilonExploreSticky": 0.02, + "epsilonExploreNeg": 0.03, + "affinityGateTau": 0.85, + "affinityGateTauGlobal": 0.95, + "selectionMode": "linear" + }`, + expectErr: false, + }, + { + name: "valid config with minimal override (uses defaults)", + pluginName: "minimal", + jsonParams: `{}`, + expectErr: false, + }, + { + name: "valid config with composite strategy", + pluginName: "composite", + jsonParams: `{ + "headroomSelectionStrategy": "composite-least", + "selectionMode": "linear" + }`, + expectErr: false, + }, + { + name: "invalid samplingMean <= 0", + pluginName: "bad-sampling-mean", + jsonParams: `{"samplingMean": -1.0}`, + expectErr: true, + }, + { + name: "invalid maxSampledTokens <= 0", + pluginName: "bad-max-tokens", + jsonParams: `{"maxSampledTokens": 0}`, + expectErr: true, + }, + { + name: "invalid sloBufferFactor <= 0", + pluginName: "bad-buffer", + jsonParams: `{"sloBufferFactor": 0}`, + expectErr: true, + }, + { + name: "negative headroom weight", + pluginName: "neg-weight", + jsonParams: `{"negHeadroomTTFTWeight": -0.1}`, + expectErr: true, + }, + { + name: "epsilonExploreSticky > 1", + pluginName: "epsilon-too-high", + jsonParams: `{"epsilonExploreSticky": 1.1}`, + expectErr: true, + }, + { + name: "epsilonExploreNeg < 0", + pluginName: "epsilon-negative", + jsonParams: `{"epsilonExploreNeg": -0.1}`, + expectErr: true, + }, + { + name: "affinityGateTau out of (0,1]", + pluginName: "tau-invalid", + jsonParams: `{"affinityGateTau": 1.5}`, + expectErr: true, + }, + { + name: "affinityGateTauGlobal <= 0", + pluginName: "tau-global-zero", + jsonParams: `{"affinityGateTauGlobal": 0}`, + expectErr: true, + }, + { + name: "multiple validation errors", + pluginName: "multi-error", + jsonParams: `{ + "samplingMean": -1, + "maxSampledTokens": 0, + "epsilonExploreSticky": 2.0, + "headroomSelectionStrategy": "unknown" + }`, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handle := utils.NewTestHandle(context.Background()) + rawParams := json.RawMessage(tt.jsonParams) + plugin, err := SLOAwareRouterFactory(tt.pluginName, rawParams, handle) + + if tt.expectErr { + assert.Error(t, err) + assert.Nil(t, plugin) + } else { + assert.NoError(t, err) + assert.NotNil(t, plugin) + } + }) + } +} + +func TestSLOAwareRouterFactoryInvalidJSON(t *testing.T) { + invalidTests := []struct { + name string + jsonParams string + }{ + { + name: "malformed JSON", + jsonParams: `{"samplingMean": 100.0, "maxSampledTokens":`, // incomplete + }, + { + name: "samplingMean as string", + jsonParams: `{"samplingMean": "100"}`, + }, + { + name: "maxSampledTokens as float", + jsonParams: `{"maxSampledTokens": 20.5}`, + }, + { + name: "headroomSelectionStrategy as number", + jsonParams: `{"headroomSelectionStrategy": 123}`, + }, + } + + for _, tt := range invalidTests { + t.Run(tt.name, func(t *testing.T) { + handle := utils.NewTestHandle(context.Background()) + rawParams := json.RawMessage(tt.jsonParams) + plugin, err := SLOAwareRouterFactory("test", rawParams, handle) + + assert.Error(t, err) + assert.Nil(t, plugin) + }) + } +}