Skip to content

Commit 2fda22a

Browse files
committed
add config validation in predicted-latency-scorer plugin
Signed-off-by: CYJiang <[email protected]>
1 parent a51e074 commit 2fda22a

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package slo_aware_router
1919
import (
2020
"context"
2121
"encoding/json"
22+
"errors"
2223
"fmt"
2324
"math/rand"
2425
"sync"
@@ -93,6 +94,10 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl
9394
}
9495
}
9596

97+
if err := parameters.validate(); err != nil {
98+
return nil, fmt.Errorf("invalid SLOAwareRouter config: %w", err)
99+
}
100+
96101
predictor, err := startPredictor(handle)
97102
if err != nil {
98103
return nil, fmt.Errorf("failed to start latency predictor: %w", err)
@@ -101,6 +106,50 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl
101106
return NewSLOAwareRouter(parameters, predictor).WithName(name), nil
102107
}
103108

109+
func (c *Config) validate() error {
110+
var errs []error
111+
112+
if c.SamplingMean <= 0 {
113+
errs = append(errs, fmt.Errorf("samplingMean must be > 0, got %f", c.SamplingMean))
114+
}
115+
116+
if c.MaxSampledTokens <= 0 {
117+
errs = append(errs, fmt.Errorf("maxSampledTokens must be > 0, got %d", c.MaxSampledTokens))
118+
}
119+
120+
if c.SLOBufferFactor <= 0 {
121+
errs = append(errs, fmt.Errorf("sloBufferFactor must be > 0, got %f", c.SLOBufferFactor))
122+
}
123+
124+
if c.NegHeadroomTTFTWeight < 0 || c.NegHeadroomTPOTWeight < 0 ||
125+
c.HeadroomTTFTWeight < 0 || c.HeadroomTPOTWeight < 0 {
126+
errs = append(errs, fmt.Errorf("all headroom weights must be >= 0"))
127+
}
128+
129+
if c.CompositeKVWeight < 0 || c.CompositeQueueWeight < 0 || c.CompositePrefixWeight < 0 {
130+
errs = append(errs, fmt.Errorf("composite weights must be >= 0"))
131+
}
132+
133+
if c.EpsilonExploreSticky < 0 || c.EpsilonExploreSticky > 1 {
134+
errs = append(errs, fmt.Errorf("epsilonExploreSticky must be in [0, 1], got %f", c.EpsilonExploreSticky))
135+
}
136+
if c.EpsilonExploreNeg < 0 || c.EpsilonExploreNeg > 1 {
137+
errs = append(errs, fmt.Errorf("epsilonExploreNeg must be in [0, 1], got %f", c.EpsilonExploreNeg))
138+
}
139+
140+
if c.AffinityGateTau <= 0 || c.AffinityGateTau > 1 {
141+
errs = append(errs, fmt.Errorf("affinityGateTau must be in (0, 1], got %f", c.AffinityGateTau))
142+
}
143+
if c.AffinityGateTauGlobal <= 0 || c.AffinityGateTauGlobal > 1 {
144+
errs = append(errs, fmt.Errorf("affinityGateTauGlobal must be in (0, 1], got %f", c.AffinityGateTauGlobal))
145+
}
146+
147+
if len(errs) > 0 {
148+
return errors.Join(errs...)
149+
}
150+
return nil
151+
}
152+
104153
func NewSLOAwareRouter(config Config, predictor latencypredictor.PredictorInterface) *SLOAwareRouter {
105154
strategy := headroomStrategy(config.HeadroomSelectionStrategy)
106155
if strategy == "" {

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package slo_aware_router
1818

1919
import (
2020
"context"
21+
"encoding/json"
2122
"errors"
2223
"fmt"
2324
"strconv"
@@ -31,6 +32,7 @@ import (
3132
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3233
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3334
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
35+
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
3436
)
3537

3638
// mockPredictor implements PredictorInterface for testing
@@ -534,3 +536,161 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) {
534536
})
535537
}
536538
}
539+
540+
func TestSLOAwareRouterFactory(t *testing.T) {
541+
tests := []struct {
542+
name string
543+
pluginName string
544+
jsonParams string
545+
expectErr bool
546+
}{
547+
{
548+
name: "valid config with all fields",
549+
pluginName: "full-config",
550+
jsonParams: `{
551+
"samplingMean": 150.0,
552+
"maxSampledTokens": 30,
553+
"sloBufferFactor": 1.2,
554+
"negHeadroomTTFTWeight": 0.7,
555+
"negHeadroomTPOTWeight": 0.3,
556+
"headroomTTFTWeight": 0.9,
557+
"headroomTPOTWeight": 0.1,
558+
"headroomSelectionStrategy": "least",
559+
"compositeKVWeight": 1.0,
560+
"compositeQueueWeight": 0.8,
561+
"compositePrefixWeight": 0.5,
562+
"epsilonExploreSticky": 0.02,
563+
"epsilonExploreNeg": 0.03,
564+
"affinityGateTau": 0.85,
565+
"affinityGateTauGlobal": 0.95,
566+
"selectionMode": "linear"
567+
}`,
568+
expectErr: false,
569+
},
570+
{
571+
name: "valid config with minimal override (uses defaults)",
572+
pluginName: "minimal",
573+
jsonParams: `{}`,
574+
expectErr: false,
575+
},
576+
{
577+
name: "valid config with composite strategy",
578+
pluginName: "composite",
579+
jsonParams: `{
580+
"headroomSelectionStrategy": "composite-least",
581+
"selectionMode": "linear"
582+
}`,
583+
expectErr: false,
584+
},
585+
{
586+
name: "invalid samplingMean <= 0",
587+
pluginName: "bad-sampling-mean",
588+
jsonParams: `{"samplingMean": -1.0}`,
589+
expectErr: true,
590+
},
591+
{
592+
name: "invalid maxSampledTokens <= 0",
593+
pluginName: "bad-max-tokens",
594+
jsonParams: `{"maxSampledTokens": 0}`,
595+
expectErr: true,
596+
},
597+
{
598+
name: "invalid sloBufferFactor <= 0",
599+
pluginName: "bad-buffer",
600+
jsonParams: `{"sloBufferFactor": 0}`,
601+
expectErr: true,
602+
},
603+
{
604+
name: "negative headroom weight",
605+
pluginName: "neg-weight",
606+
jsonParams: `{"negHeadroomTTFTWeight": -0.1}`,
607+
expectErr: true,
608+
},
609+
{
610+
name: "epsilonExploreSticky > 1",
611+
pluginName: "epsilon-too-high",
612+
jsonParams: `{"epsilonExploreSticky": 1.1}`,
613+
expectErr: true,
614+
},
615+
{
616+
name: "epsilonExploreNeg < 0",
617+
pluginName: "epsilon-negative",
618+
jsonParams: `{"epsilonExploreNeg": -0.1}`,
619+
expectErr: true,
620+
},
621+
{
622+
name: "affinityGateTau out of (0,1]",
623+
pluginName: "tau-invalid",
624+
jsonParams: `{"affinityGateTau": 1.5}`,
625+
expectErr: true,
626+
},
627+
{
628+
name: "affinityGateTauGlobal <= 0",
629+
pluginName: "tau-global-zero",
630+
jsonParams: `{"affinityGateTauGlobal": 0}`,
631+
expectErr: true,
632+
},
633+
{
634+
name: "multiple validation errors",
635+
pluginName: "multi-error",
636+
jsonParams: `{
637+
"samplingMean": -1,
638+
"maxSampledTokens": 0,
639+
"epsilonExploreSticky": 2.0,
640+
"headroomSelectionStrategy": "unknown"
641+
}`,
642+
expectErr: true,
643+
},
644+
}
645+
646+
for _, tt := range tests {
647+
t.Run(tt.name, func(t *testing.T) {
648+
handle := utils.NewTestHandle(context.Background())
649+
rawParams := json.RawMessage(tt.jsonParams)
650+
plugin, err := SLOAwareRouterFactory(tt.pluginName, rawParams, handle)
651+
652+
if tt.expectErr {
653+
assert.Error(t, err)
654+
assert.Nil(t, plugin)
655+
} else {
656+
assert.NoError(t, err)
657+
assert.NotNil(t, plugin)
658+
}
659+
})
660+
}
661+
}
662+
663+
func TestSLOAwareRouterFactoryInvalidJSON(t *testing.T) {
664+
invalidTests := []struct {
665+
name string
666+
jsonParams string
667+
}{
668+
{
669+
name: "malformed JSON",
670+
jsonParams: `{"samplingMean": 100.0, "maxSampledTokens":`, // incomplete
671+
},
672+
{
673+
name: "samplingMean as string",
674+
jsonParams: `{"samplingMean": "100"}`,
675+
},
676+
{
677+
name: "maxSampledTokens as float",
678+
jsonParams: `{"maxSampledTokens": 20.5}`,
679+
},
680+
{
681+
name: "headroomSelectionStrategy as number",
682+
jsonParams: `{"headroomSelectionStrategy": 123}`,
683+
},
684+
}
685+
686+
for _, tt := range invalidTests {
687+
t.Run(tt.name, func(t *testing.T) {
688+
handle := utils.NewTestHandle(context.Background())
689+
rawParams := json.RawMessage(tt.jsonParams)
690+
plugin, err := SLOAwareRouterFactory("test", rawParams, handle)
691+
692+
assert.Error(t, err)
693+
assert.Nil(t, plugin)
694+
})
695+
}
696+
}

0 commit comments

Comments
 (0)