From ab54af96ab5a9180af92c5deb55c84cebdc1c497 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 13:33:10 +0300 Subject: [PATCH 1/7] Added the ability for plugins to receive the request headers and modify them --- pkg/epp/handlers/request.go | 10 +- pkg/epp/handlers/server.go | 16 +++- pkg/epp/scheduling/scheduler.go | 1 + pkg/epp/scheduling/scheduler_test.go | 136 ++++++++++++++++++++------- pkg/epp/scheduling/types/types.go | 20 ++-- 5 files changed, 136 insertions(+), 47 deletions(-) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 8d30e543d..9a343b06f 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -66,6 +66,7 @@ func (s *StreamingServer) HandleRequestBody( Model: model, ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, + Headers: reqCtx.RequestHeaders, } logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) @@ -104,7 +105,7 @@ func (s *StreamingServer) HandleRequestBody( reqCtx.TargetPod = targetPod.NamespacedName.String() reqCtx.TargetEndpoint = endpoint - s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes)) + s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders) reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{ // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header @@ -146,7 +147,12 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ return err } endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) - s.populateRequestHeaderResponse(reqCtx, endpoint, 0) + s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil) } + + for _, header := range req.RequestHeaders.Headers.Headers { + reqCtx.RequestHeaders[header.Key] = header.Value + } + return nil } diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 5e23c7a0a..4b5c7c56a 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -86,6 +86,8 @@ type RequestContext struct { RequestState StreamRequestState modelServerStreaming bool + RequestHeaders map[string]string + reqHeaderResp *extProcPb.ProcessingResponse reqBodyResp *extProcPb.ProcessingResponse reqTrailerResp *extProcPb.ProcessingResponse @@ -117,7 +119,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) // Create request context to share states during life time of an HTTP request. // See https://github.com/envoyproxy/envoy/issues/17540. reqCtx := &RequestContext{ - RequestState: RequestReceived, + RequestState: RequestReceived, + RequestHeaders: make(map[string]string), } var body []byte @@ -358,7 +361,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces return nil } -func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) { +func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) { headers := []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ @@ -377,6 +380,15 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, }, }) } + // Add headers added by filters/scorers + for key, value := range mutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } targetEndpointValue := &structpb.Struct{ Fields: map[string]*structpb.Value{ diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 04d24ea24..bd1e81f72 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -119,6 +119,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types s.runPostSchedulePlugins(sCtx, result) + result.MutatedHeaders = sCtx.MutatedHeaders return result, nil } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 311f44e9f..9fc86008b 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -108,6 +108,7 @@ func TestSchedule(t *testing.T) { }, }, }, + MutatedHeaders: make(map[string]string), }, }, { @@ -171,6 +172,7 @@ func TestSchedule(t *testing.T) { }, }, }, + MutatedHeaders: make(map[string]string), }, }, { @@ -241,18 +243,27 @@ func TestSchedule(t *testing.T) { func TestSchedulePlugins(t *testing.T) { tp1 := &TestPlugin{ - NameRes: "test1", - ScoreRes: 0.3, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + ReceivedRequestHeaders: make(map[string]string), } tp2 := &TestPlugin{ - NameRes: "test2", - ScoreRes: 0.8, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + ReceivedRequestHeaders: make(map[string]string), } tp_filterAll := &TestPlugin{ - NameRes: "filter all", - FilterRes: []k8stypes.NamespacedName{}, + NameRes: "filter all", + FilterRes: []k8stypes.NamespacedName{}, + ReceivedRequestHeaders: make(map[string]string), + } + tp_headers := &TestPlugin{ + NameRes: "headers", + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, + ReceivedRequestHeaders: make(map[string]string), } pickerPlugin := &TestPlugin{ NameRes: "picker", @@ -260,11 +271,13 @@ func TestSchedulePlugins(t *testing.T) { } tests := []struct { - name string - config SchedulerConfig - input []*backendmetrics.FakePodMetrics - wantTargetPod k8stypes.NamespacedName - targetPodScore float64 + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + requestHeaders map[string]string + wantTargetPod k8stypes.NamespacedName + wantMutatedHeaders map[string]string + targetPodScore float64 // Number of expected pods to score (after filter) numPodsToScore int err bool @@ -286,10 +299,12 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 1.1, - numPodsToScore: 2, - err: false, + requestHeaders: make(map[string]string), + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: make(map[string]string), + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, }, { name: "all plugins executed successfully, different scorers weights", @@ -308,10 +323,12 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 50, - numPodsToScore: 2, - err: false, + requestHeaders: make(map[string]string), + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: make(map[string]string), + targetPodScore: 50, + numPodsToScore: 2, + err: false, }, { name: "filter all", @@ -330,9 +347,37 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, + requestHeaders: make(map[string]string), numPodsToScore: 0, err: true, // no available pods to server after filter all }, + { + name: "Mutate a header", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp_headers}, + scorers: map[plugins.Scorer]int{ + tp1: 1, + tp2: 1, + }, + picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + requestHeaders: map[string]string{ + "Content-type": "application/json", + "x-session-id": "qazw-edcr-tgby-nhyu", + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, // no available pods to server after filter all + }, } for _, test := range tests { @@ -355,7 +400,10 @@ func TestSchedulePlugins(t *testing.T) { // Initialize the scheduler scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) - req := &types.LLMRequest{Model: "test-model"} + req := &types.LLMRequest{ + Model: "test-model", + Headers: test.requestHeaders, + } got, err := scheduler.Schedule(context.Background(), req) // Validate error state @@ -371,7 +419,10 @@ func TestSchedulePlugins(t *testing.T) { wantPod := &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: test.wantTargetPod}, } - wantRes := &types.Result{TargetPod: wantPod} + wantRes := &types.Result{ + TargetPod: wantPod, + MutatedHeaders: test.wantMutatedHeaders, + } if diff := cmp.Diff(wantRes, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } @@ -389,6 +440,9 @@ func TestSchedulePlugins(t *testing.T) { if tp.FilterCallCount != 1 { t.Errorf("Plugin %s Filter() called %d times, expected 1", plugin.Name(), tp.FilterCallCount) } + if len(test.requestHeaders) != len(tp.ReceivedRequestHeaders) { + t.Errorf("Count of received request headers is %d, expected %d", len(tp.ReceivedRequestHeaders), len(test.requestHeaders)) + } } for plugin := range test.config.scorers { @@ -418,6 +472,10 @@ func TestSchedulePlugins(t *testing.T) { t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", plugin.Name(), tp.PostScheduleCallCount) } } + + if len(test.wantMutatedHeaders) != len(got.MutatedHeaders) { + t.Errorf("Count of mutated headers is %d, expected %d", len(got.MutatedHeaders), len(test.wantMutatedHeaders)) + } }) } } @@ -436,18 +494,20 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { // TestPlugin is an implementation useful in unit tests. type TestPlugin struct { - NameRes string - ScoreCallCount int - NumOfScoredPods int - ScoreRes float64 - FilterCallCount int - FilterRes []k8stypes.NamespacedName - PreScheduleCallCount int - PostScheduleCallCount int - PickCallCount int - NumOfPickerCandidates int - PickRes k8stypes.NamespacedName - WinnderPodScore float64 + NameRes string + ScoreCallCount int + NumOfScoredPods int + ScoreRes float64 + FilterCallCount int + FilterRes []k8stypes.NamespacedName + PreScheduleCallCount int + PostScheduleCallCount int + PickCallCount int + NumOfPickerCandidates int + PickRes k8stypes.NamespacedName + WinnderPodScore float64 + ExtraHeaders map[string]string + ReceivedRequestHeaders map[string]string } func (tp *TestPlugin) Name() string { return tp.NameRes } @@ -458,6 +518,12 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) { func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { tp.FilterCallCount++ + for key, value := range tp.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } + for key, value := range ctx.Req.Headers { + tp.ReceivedRequestHeaders[key] = value + } return findPods(ctx, tp.FilterRes...) } diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 5198515be..64cd73c0f 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -31,6 +31,7 @@ type LLMRequest struct { // Target models is a map of target model name to weight. TargetModels map[string]int Prompt string + Headers map[string]string // Resolved target model is the final target model after traffic split. ResolvedTargetModel string Critical bool @@ -54,9 +55,10 @@ type ScoredPod struct { // SchedulingContext holds contextual information during a scheduling operation. type SchedulingContext struct { context.Context - Logger logr.Logger - Req *LLMRequest - PodsSnapshot []Pod + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []Pod + MutatedHeaders map[string]string } func (pm *PodMetrics) String() string { @@ -82,10 +84,11 @@ type PodMetrics struct { func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ - Context: ctx, - Logger: logger, - Req: req, - PodsSnapshot: pods, + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + MutatedHeaders: make(map[string]string), } } @@ -99,5 +102,6 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { // Result captures the scheduler result. type Result struct { - TargetPod Pod + TargetPod Pod + MutatedHeaders map[string]string } From 0a1093d0bbacc453fc84db5238ec4a2d4ac2e8cc Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 13:33:10 +0300 Subject: [PATCH 2/7] Added the ability for plugins to receive the request headers and modify them --- pkg/epp/handlers/request.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index bebacd8f0..989aab6cc 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -72,6 +72,7 @@ func (s *StreamingServer) HandleRequestBody( Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, Headers: reqCtx.RequestHeaders, Prompt: prompt, + Headers: reqCtx.RequestHeaders, } logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) From 840f7d4662e453353dd48dff843fe5feb7c6977d Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 15:03:48 +0300 Subject: [PATCH 3/7] Updated test after rebase --- pkg/epp/scheduling/scheduler_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 0e3679a04..fab17f08e 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -365,9 +365,9 @@ func TestSchedulePlugins(t *testing.T) { postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, }, input: []*backendmetrics.FakePodMetrics{ - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, requestHeaders: map[string]string{ "Content-type": "application/json", From 8fd2ca51ef9abca821860eb7001f8b656fcff58e Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 15:28:20 +0300 Subject: [PATCH 4/7] Removed duplicate field --- pkg/epp/handlers/request.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 989aab6cc..bebacd8f0 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -72,7 +72,6 @@ func (s *StreamingServer) HandleRequestBody( Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, Headers: reqCtx.RequestHeaders, Prompt: prompt, - Headers: reqCtx.RequestHeaders, } logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) From b7224e0e4ab028e301ec263213942f285ed5f167 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 15:34:02 +0300 Subject: [PATCH 5/7] Added a go doc comment for the new field --- pkg/epp/scheduling/types/types.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 19d4b3757..cca716d94 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -32,6 +32,8 @@ type LLMRequest struct { // Target models is a map of target model name to weight. TargetModels map[string]int Prompt string + // Headers during request processing contains all of the request headers. + // During response processing it contains all of the response headers. Headers map[string]string // Resolved target model is the final target model after traffic split. ResolvedTargetModel string From e5846d154b1c7086580c68794e9aaa92e2ab7f95 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 15:36:01 +0300 Subject: [PATCH 6/7] Added a go doc comment for the new field --- pkg/epp/scheduling/types/types.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index cca716d94..103a9d78c 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -61,6 +61,7 @@ type SchedulingContext struct { Logger logr.Logger Req *LLMRequest PodsSnapshot []Pod + // MutatedHeaders is used by the plugins to add/modify headers MutatedHeaders map[string]string } From 796f5b56d4dec50437f3055281e6085e4a300c46 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 30 Apr 2025 15:52:25 +0300 Subject: [PATCH 7/7] Ran go fmt on the updated file --- pkg/epp/scheduling/types/types.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 103a9d78c..f0e49452d 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -34,7 +34,7 @@ type LLMRequest struct { Prompt string // Headers during request processing contains all of the request headers. // During response processing it contains all of the response headers. - Headers map[string]string + Headers map[string]string // Resolved target model is the final target model after traffic split. ResolvedTargetModel string Critical bool @@ -58,9 +58,9 @@ type ScoredPod struct { // SchedulingContext holds contextual information during a scheduling operation. type SchedulingContext struct { context.Context - Logger logr.Logger - Req *LLMRequest - PodsSnapshot []Pod + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []Pod // MutatedHeaders is used by the plugins to add/modify headers MutatedHeaders map[string]string }