Skip to content

Added the ability for plugins to receive the request headers and modify them #760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
10 changes: 8 additions & 2 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (s *StreamingServer) HandleRequestBody(
Model: model,
ResolvedTargetModel: modelName,
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
Headers: reqCtx.RequestHeaders,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lines 73 and 75 are duplicates?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, fixed

Prompt: prompt,
}
logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq)
Expand Down Expand Up @@ -109,7 +110,7 @@ func (s *StreamingServer) HandleRequestBody(
reqCtx.TargetPod = targetPod.NamespacedName.String()
reqCtx.TargetEndpoint = endpoint

s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes))
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders)

reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
Expand Down Expand Up @@ -151,7 +152,12 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
return err
}
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil)
}

for _, header := range req.RequestHeaders.Headers.Headers {
reqCtx.RequestHeaders[header.Key] = header.Value
}

Comment on lines +158 to +161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part was already merged into main

return nil
}
16 changes: 14 additions & 2 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ type RequestContext struct {
RequestState StreamRequestState
modelServerStreaming bool

RequestHeaders map[string]string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is also available already in main but in different field - Request, which contains body and headers.


reqHeaderResp *extProcPb.ProcessingResponse
reqBodyResp *extProcPb.ProcessingResponse
reqTrailerResp *extProcPb.ProcessingResponse
Expand Down Expand Up @@ -117,7 +119,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
// Create request context to share states during life time of an HTTP request.
// See https://github.com/envoyproxy/envoy/issues/17540.
reqCtx := &RequestContext{
RequestState: RequestReceived,
RequestState: RequestReceived,
RequestHeaders: make(map[string]string),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already in main in Request field

}

var body []byte
Expand Down Expand Up @@ -358,7 +361,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return nil
}

func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) {
func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) {
headers := []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Expand All @@ -377,6 +380,15 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext,
},
})
}
// Add headers added by filters/scorers
for key, value := range mutatedHeaders {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The headers is a list. If the plugins updated the value of an existing header key, should we update the header here instead of just appending?

headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: key,
RawValue: []byte(value),
},
})
}

targetEndpointValue := &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down
1 change: 1 addition & 0 deletions pkg/epp/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types

s.runPostSchedulePlugins(sCtx, result)

result.MutatedHeaders = sCtx.MutatedHeaders
return result, nil
}

Expand Down
136 changes: 101 additions & 35 deletions pkg/epp/scheduling/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func TestSchedule(t *testing.T) {
},
},
},
MutatedHeaders: make(map[string]string),
},
},
{
Expand Down Expand Up @@ -172,6 +173,7 @@ func TestSchedule(t *testing.T) {
},
},
},
MutatedHeaders: make(map[string]string),
},
},
{
Expand Down Expand Up @@ -242,30 +244,41 @@ func TestSchedule(t *testing.T) {

func TestSchedulePlugins(t *testing.T) {
tp1 := &TestPlugin{
NameRes: "test1",
ScoreRes: 0.3,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}},
NameRes: "test1",
ScoreRes: 0.3,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}},
ReceivedRequestHeaders: make(map[string]string),
}
tp2 := &TestPlugin{
NameRes: "test2",
ScoreRes: 0.8,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
NameRes: "test2",
ScoreRes: 0.8,
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
ReceivedRequestHeaders: make(map[string]string),
}
tp_filterAll := &TestPlugin{
NameRes: "filter all",
FilterRes: []k8stypes.NamespacedName{},
NameRes: "filter all",
FilterRes: []k8stypes.NamespacedName{},
ReceivedRequestHeaders: make(map[string]string),
}
tp_headers := &TestPlugin{
NameRes: "headers",
FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}},
ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"},
ReceivedRequestHeaders: make(map[string]string),
}
pickerPlugin := &TestPlugin{
NameRes: "picker",
PickRes: k8stypes.NamespacedName{Name: "pod1"},
}

tests := []struct {
name string
config SchedulerConfig
input []*backendmetrics.FakePodMetrics
wantTargetPod k8stypes.NamespacedName
targetPodScore float64
name string
config SchedulerConfig
input []*backendmetrics.FakePodMetrics
requestHeaders map[string]string
wantTargetPod k8stypes.NamespacedName
wantMutatedHeaders map[string]string
targetPodScore float64
// Number of expected pods to score (after filter)
numPodsToScore int
err bool
Expand All @@ -287,10 +300,12 @@ func TestSchedulePlugins(t *testing.T) {
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
targetPodScore: 1.1,
numPodsToScore: 2,
err: false,
requestHeaders: make(map[string]string),
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
wantMutatedHeaders: make(map[string]string),
targetPodScore: 1.1,
numPodsToScore: 2,
err: false,
},
{
name: "all plugins executed successfully, different scorers weights",
Expand All @@ -309,10 +324,12 @@ func TestSchedulePlugins(t *testing.T) {
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
targetPodScore: 50,
numPodsToScore: 2,
err: false,
requestHeaders: make(map[string]string),
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
wantMutatedHeaders: make(map[string]string),
targetPodScore: 50,
numPodsToScore: 2,
err: false,
},
{
name: "filter all",
Expand All @@ -331,9 +348,37 @@ func TestSchedulePlugins(t *testing.T) {
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
requestHeaders: make(map[string]string),
numPodsToScore: 0,
err: true, // no available pods to server after filter all
},
{
name: "Mutate a header",
config: SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{tp1, tp2},
filters: []plugins.Filter{tp_headers},
scorers: map[plugins.Scorer]int{
tp1: 1,
tp2: 1,
},
picker: pickerPlugin,
postSchedulePlugins: []plugins.PostSchedule{tp1, tp2},
},
input: []*backendmetrics.FakePodMetrics{
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
},
requestHeaders: map[string]string{
"Content-type": "application/json",
"x-session-id": "qazw-edcr-tgby-nhyu",
},
wantTargetPod: k8stypes.NamespacedName{Name: "pod1"},
wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"},
targetPodScore: 1.1,
numPodsToScore: 2,
err: false, // no available pods to server after filter all
},
}

for _, test := range tests {
Expand All @@ -356,7 +401,10 @@ func TestSchedulePlugins(t *testing.T) {
// Initialize the scheduler
scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)

req := &types.LLMRequest{Model: "test-model"}
req := &types.LLMRequest{
Model: "test-model",
Headers: test.requestHeaders,
}
got, err := scheduler.Schedule(context.Background(), req)

// Validate error state
Expand All @@ -372,7 +420,10 @@ func TestSchedulePlugins(t *testing.T) {
wantPod := &types.PodMetrics{
Pod: &backend.Pod{NamespacedName: test.wantTargetPod},
}
wantRes := &types.Result{TargetPod: wantPod}
wantRes := &types.Result{
TargetPod: wantPod,
MutatedHeaders: test.wantMutatedHeaders,
}
if diff := cmp.Diff(wantRes, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
Expand All @@ -390,6 +441,9 @@ func TestSchedulePlugins(t *testing.T) {
if tp.FilterCallCount != 1 {
t.Errorf("Plugin %s Filter() called %d times, expected 1", plugin.Name(), tp.FilterCallCount)
}
if len(test.requestHeaders) != len(tp.ReceivedRequestHeaders) {
t.Errorf("Count of received request headers is %d, expected %d", len(tp.ReceivedRequestHeaders), len(test.requestHeaders))
}
}

for plugin := range test.config.scorers {
Expand Down Expand Up @@ -419,6 +473,10 @@ func TestSchedulePlugins(t *testing.T) {
t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", plugin.Name(), tp.PostScheduleCallCount)
}
}

if len(test.wantMutatedHeaders) != len(got.MutatedHeaders) {
t.Errorf("Count of mutated headers is %d, expected %d", len(got.MutatedHeaders), len(test.wantMutatedHeaders))
}
})
}
}
Expand All @@ -437,18 +495,20 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {

// TestPlugin is an implementation useful in unit tests.
type TestPlugin struct {
NameRes string
ScoreCallCount int
NumOfScoredPods int
ScoreRes float64
FilterCallCount int
FilterRes []k8stypes.NamespacedName
PreScheduleCallCount int
PostScheduleCallCount int
PickCallCount int
NumOfPickerCandidates int
PickRes k8stypes.NamespacedName
WinnderPodScore float64
NameRes string
ScoreCallCount int
NumOfScoredPods int
ScoreRes float64
FilterCallCount int
FilterRes []k8stypes.NamespacedName
PreScheduleCallCount int
PostScheduleCallCount int
PickCallCount int
NumOfPickerCandidates int
PickRes k8stypes.NamespacedName
WinnderPodScore float64
ExtraHeaders map[string]string
ReceivedRequestHeaders map[string]string
}

func (tp *TestPlugin) Name() string { return tp.NameRes }
Expand All @@ -459,6 +519,12 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) {

func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
tp.FilterCallCount++
for key, value := range tp.ExtraHeaders {
ctx.MutatedHeaders[key] = value
}
for key, value := range ctx.Req.Headers {
tp.ReceivedRequestHeaders[key] = value
}
return findPods(ctx, tp.FilterRes...)

}
Expand Down
17 changes: 12 additions & 5 deletions pkg/epp/scheduling/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ type LLMRequest struct {
// Target models is a map of target model name to weight.
TargetModels map[string]int
Prompt string
// Headers during request processing contains all of the request headers.
// During response processing it contains all of the response headers.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove spaces

Headers map[string]string
// Resolved target model is the final target model after traffic split.
ResolvedTargetModel string
Critical bool
Expand All @@ -58,6 +61,8 @@ type SchedulingContext struct {
Logger logr.Logger
Req *LLMRequest
PodsSnapshot []Pod
// MutatedHeaders is used by the plugins to add/modify headers
MutatedHeaders map[string]string
}

func (pm *PodMetrics) String() string {
Expand All @@ -83,10 +88,11 @@ type PodMetrics struct {
func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
logger := log.FromContext(ctx).WithValues("request", req)
return &SchedulingContext{
Context: ctx,
Logger: logger,
Req: req,
PodsSnapshot: pods,
Context: ctx,
Logger: logger,
Req: req,
PodsSnapshot: pods,
MutatedHeaders: make(map[string]string),
}
}

Expand All @@ -100,5 +106,6 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod {

// Result captures the scheduler result.
type Result struct {
TargetPod Pod
TargetPod Pod
MutatedHeaders map[string]string
}