@@ -18,6 +18,7 @@ package slo_aware_router
1818
1919import (
2020 "context"
21+ "encoding/json"
2122 "errors"
2223 "fmt"
2324 "strconv"
@@ -31,6 +32,7 @@ import (
3132 schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3233 requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3334 latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
35+ "sigs.k8s.io/gateway-api-inference-extension/test/utils"
3436)
3537
3638// mockPredictor implements PredictorInterface for testing
@@ -534,3 +536,161 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) {
534536 })
535537 }
536538}
539+
540+ func TestSLOAwareRouterFactory (t * testing.T ) {
541+ tests := []struct {
542+ name string
543+ pluginName string
544+ jsonParams string
545+ expectErr bool
546+ }{
547+ {
548+ name : "valid config with all fields" ,
549+ pluginName : "full-config" ,
550+ jsonParams : `{
551+ "samplingMean": 150.0,
552+ "maxSampledTokens": 30,
553+ "sloBufferFactor": 1.2,
554+ "negHeadroomTTFTWeight": 0.7,
555+ "negHeadroomTPOTWeight": 0.3,
556+ "headroomTTFTWeight": 0.9,
557+ "headroomTPOTWeight": 0.1,
558+ "headroomSelectionStrategy": "least",
559+ "compositeKVWeight": 1.0,
560+ "compositeQueueWeight": 0.8,
561+ "compositePrefixWeight": 0.5,
562+ "epsilonExploreSticky": 0.02,
563+ "epsilonExploreNeg": 0.03,
564+ "affinityGateTau": 0.85,
565+ "affinityGateTauGlobal": 0.95,
566+ "selectionMode": "linear"
567+ }` ,
568+ expectErr : false ,
569+ },
570+ {
571+ name : "valid config with minimal override (uses defaults)" ,
572+ pluginName : "minimal" ,
573+ jsonParams : `{}` ,
574+ expectErr : false ,
575+ },
576+ {
577+ name : "valid config with composite strategy" ,
578+ pluginName : "composite" ,
579+ jsonParams : `{
580+ "headroomSelectionStrategy": "composite-least",
581+ "selectionMode": "linear"
582+ }` ,
583+ expectErr : false ,
584+ },
585+ {
586+ name : "invalid samplingMean <= 0" ,
587+ pluginName : "bad-sampling-mean" ,
588+ jsonParams : `{"samplingMean": -1.0}` ,
589+ expectErr : true ,
590+ },
591+ {
592+ name : "invalid maxSampledTokens <= 0" ,
593+ pluginName : "bad-max-tokens" ,
594+ jsonParams : `{"maxSampledTokens": 0}` ,
595+ expectErr : true ,
596+ },
597+ {
598+ name : "invalid sloBufferFactor <= 0" ,
599+ pluginName : "bad-buffer" ,
600+ jsonParams : `{"sloBufferFactor": 0}` ,
601+ expectErr : true ,
602+ },
603+ {
604+ name : "negative headroom weight" ,
605+ pluginName : "neg-weight" ,
606+ jsonParams : `{"negHeadroomTTFTWeight": -0.1}` ,
607+ expectErr : true ,
608+ },
609+ {
610+ name : "epsilonExploreSticky > 1" ,
611+ pluginName : "epsilon-too-high" ,
612+ jsonParams : `{"epsilonExploreSticky": 1.1}` ,
613+ expectErr : true ,
614+ },
615+ {
616+ name : "epsilonExploreNeg < 0" ,
617+ pluginName : "epsilon-negative" ,
618+ jsonParams : `{"epsilonExploreNeg": -0.1}` ,
619+ expectErr : true ,
620+ },
621+ {
622+ name : "affinityGateTau out of (0,1]" ,
623+ pluginName : "tau-invalid" ,
624+ jsonParams : `{"affinityGateTau": 1.5}` ,
625+ expectErr : true ,
626+ },
627+ {
628+ name : "affinityGateTauGlobal <= 0" ,
629+ pluginName : "tau-global-zero" ,
630+ jsonParams : `{"affinityGateTauGlobal": 0}` ,
631+ expectErr : true ,
632+ },
633+ {
634+ name : "multiple validation errors" ,
635+ pluginName : "multi-error" ,
636+ jsonParams : `{
637+ "samplingMean": -1,
638+ "maxSampledTokens": 0,
639+ "epsilonExploreSticky": 2.0,
640+ "headroomSelectionStrategy": "unknown"
641+ }` ,
642+ expectErr : true ,
643+ },
644+ }
645+
646+ for _ , tt := range tests {
647+ t .Run (tt .name , func (t * testing.T ) {
648+ handle := utils .NewTestHandle (context .Background ())
649+ rawParams := json .RawMessage (tt .jsonParams )
650+ plugin , err := SLOAwareRouterFactory (tt .pluginName , rawParams , handle )
651+
652+ if tt .expectErr {
653+ assert .Error (t , err )
654+ assert .Nil (t , plugin )
655+ } else {
656+ assert .NoError (t , err )
657+ assert .NotNil (t , plugin )
658+ }
659+ })
660+ }
661+ }
662+
663+ func TestSLOAwareRouterFactoryInvalidJSON (t * testing.T ) {
664+ invalidTests := []struct {
665+ name string
666+ jsonParams string
667+ }{
668+ {
669+ name : "malformed JSON" ,
670+ jsonParams : `{"samplingMean": 100.0, "maxSampledTokens":` , // incomplete
671+ },
672+ {
673+ name : "samplingMean as string" ,
674+ jsonParams : `{"samplingMean": "100"}` ,
675+ },
676+ {
677+ name : "maxSampledTokens as float" ,
678+ jsonParams : `{"maxSampledTokens": 20.5}` ,
679+ },
680+ {
681+ name : "headroomSelectionStrategy as number" ,
682+ jsonParams : `{"headroomSelectionStrategy": 123}` ,
683+ },
684+ }
685+
686+ for _ , tt := range invalidTests {
687+ t .Run (tt .name , func (t * testing.T ) {
688+ handle := utils .NewTestHandle (context .Background ())
689+ rawParams := json .RawMessage (tt .jsonParams )
690+ plugin , err := SLOAwareRouterFactory ("test" , rawParams , handle )
691+
692+ assert .Error (t , err )
693+ assert .Nil (t , plugin )
694+ })
695+ }
696+ }
0 commit comments