configurable filter chains #169

Open · wants to merge 6 commits into base: main
17 changes: 17 additions & 0 deletions pkg/ext-proc/backend/datastore.go
@@ -29,10 +29,27 @@ type K8sDatastore struct {
inferencePool *v1alpha1.InferencePool
InferenceModels *sync.Map
pods *sync.Map

filterConfigMap *corev1.ConfigMap
}

type K8sDatastoreOption func(*K8sDatastore)

func (ds *K8sDatastore) GetFilterConfigMap() *corev1.ConfigMap {
if ds == nil {
return nil
}
ds.poolMu.RLock()
defer ds.poolMu.RUnlock()
return ds.filterConfigMap
}

func WithFilterConfigMap(filterConfigMap *corev1.ConfigMap) K8sDatastoreOption {
return func(store *K8sDatastore) {
store.filterConfigMap = filterConfigMap
}
}

// WithPods can be used in tests to override the pods.
func WithPods(pods []*PodMetrics) K8sDatastoreOption {
return func(store *K8sDatastore) {
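Aside: a usage sketch of the new option and accessor (not part of the diff). It assumes a NewK8sDatastore constructor that applies K8sDatastoreOption values, mirroring the WithPods pattern above; the constructor itself is outside this hunk, so treat it and the "filter" data key as assumptions.

package main

import (
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

	"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"
)

func main() {
	// Hypothetical seed config; the data schema is not fixed by this PR.
	cm := &corev1.ConfigMap{
		ObjectMeta: metav1.ObjectMeta{Name: "filter-config", Namespace: "default"},
		Data:       map[string]string{"filter": "{}"},
	}

	// NewK8sDatastore is an assumed constructor that applies the options.
	ds := backend.NewK8sDatastore(backend.WithFilterConfigMap(cm))

	// GetFilterConfigMap takes the pool read lock and returns nil once the
	// ConfigMap is deleted, so callers must nil-check.
	if current := ds.GetFilterConfigMap(); current != nil {
		_ = current.Data
	}
}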
53 changes: 53 additions & 0 deletions pkg/ext-proc/backend/filterconfig_reconciler.go
@@ -0,0 +1,53 @@
package backend

import (
"context"

corev1 "k8s.io/api/core/v1"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/predicate"
)

type FilterConfigReconciler struct {
client.Client
Datastore *K8sDatastore
}

func (c *FilterConfigReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
cm := &corev1.ConfigMap{}
if err := c.Get(ctx, req.NamespacedName, cm); err != nil {
if client.IgnoreNotFound(err) != nil {
klog.Errorf("unable to get ConfigMap, err: %v", err)
return ctrl.Result{}, err
}
c.Datastore.poolMu.Lock()
defer c.Datastore.poolMu.Unlock()
klog.V(1).Info("filter config deleted, reset filter config")
c.Datastore.filterConfigMap = nil
return ctrl.Result{}, nil
}

c.Datastore.poolMu.Lock()
defer c.Datastore.poolMu.Unlock()

if cm.DeletionTimestamp != nil {
klog.V(1).Info("filter config deleting, reset filter config")
c.Datastore.filterConfigMap = nil
return ctrl.Result{}, nil
}

klog.V(1).Infof("update filter config to: %++v", cm.Data)
c.Datastore.filterConfigMap = cm.DeepCopy()
return ctrl.Result{}, nil
}

func (c *FilterConfigReconciler) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
For(&corev1.ConfigMap{}).
WithEventFilter(predicate.NewPredicateFuncs(func(object client.Object) bool {
return object.GetName() == "filter-config" && object.GetNamespace() == "default"
})).
Complete(c)
}
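Aside: a minimal reconcile test sketch (not in the PR) using controller-runtime's fake client to exercise the store path above. NewK8sDatastore and the "filter" data key are assumptions; the PR does not pin a config schema here.

package backend

import (
	"context"
	"testing"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

func TestFilterConfigReconcile(t *testing.T) {
	cm := &corev1.ConfigMap{
		ObjectMeta: metav1.ObjectMeta{Name: "filter-config", Namespace: "default"},
		Data:       map[string]string{"filter": "{}"}, // hypothetical key
	}
	r := &FilterConfigReconciler{
		Client:    fake.NewClientBuilder().WithObjects(cm).Build(),
		Datastore: NewK8sDatastore(), // assumed constructor
	}
	req := ctrl.Request{NamespacedName: types.NamespacedName{Namespace: "default", Name: "filter-config"}}

	// Reconcile should store a deep copy of the ConfigMap in the datastore.
	if _, err := r.Reconcile(context.Background(), req); err != nil {
		t.Fatal(err)
	}
	if got := r.Datastore.GetFilterConfigMap(); got == nil || got.Data["filter"] != "{}" {
		t.Fatalf("config not stored: %+v", got)
	}
}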
21 changes: 20 additions & 1 deletion pkg/ext-proc/main.go
@@ -59,6 +59,11 @@ var (
"refreshMetricsInterval",
50*time.Millisecond,
"interval to refresh metrics")
enableFilterConfiguration = flag.Bool(
"enableFilterConfiguration",
false,
"Whether to enable configuring filters in `default/filter-config` configmap, ONLY FOR DEV NOW.",
)

scheme = runtime.NewScheme()
)
@@ -133,6 +138,15 @@ func main() {
klog.Fatalf("Failed setting up EndpointSliceReconciler: %v", err)
}

if *enableFilterConfiguration {
if err := (&backend.FilterConfigReconciler{
Datastore: datastore,
Client: mgr.GetClient(),
}).SetupWithManager(mgr); err != nil {
klog.Errorf("Failed setting up FilterConfigReconciler: %v", err)
}
}

// Start health and ext-proc servers in goroutines
healthSvr := startHealthServer(datastore, *grpcHealthPort)
extProcSvr := startExternalProcessorServer(
@@ -193,6 +207,11 @@ func startExternalProcessorServer(
) *grpc.Server {
svr := grpc.NewServer()

var orchestrator scheduling.FilterOrchestrator
if *enableFilterConfiguration {
orchestrator = scheduling.NewFilterOrchestrator(datastore)
}

go func() {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
@@ -209,7 +228,7 @@
// Register ext_proc handlers
extProcPb.RegisterExternalProcessorServer(
svr,
handlers.NewServer(pp, scheduling.NewScheduler(pp), targetPodHeader, datastore),
handlers.NewServer(pp, scheduling.NewScheduler(pp, scheduling.WithOrchestrator(orchestrator)), targetPodHeader, datastore),
)

// Blocking and will return when shutdown is complete.
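Note: scheduling.NewFilterOrchestrator and scheduling.WithOrchestrator are introduced elsewhere in this PR's commits and are not shown in this diff. Inferred only from the call sites above, the orchestrator plausibly has a shape like the sketch below — treat every identifier in it as an assumption, not the PR's actual code.

// Hypothetical interface, inferred from how main.go wires it above.
type FilterOrchestrator interface {
	// Orchestrate returns the FilterChain the scheduler should run, rebuilt
	// from the datastore's filter ConfigMap when that changes. A nil
	// orchestrator leaves the scheduler on its default, hard-coded chain.
	Orchestrate() FilterChain
}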
45 changes: 28 additions & 17 deletions pkg/ext-proc/scheduling/filter.go
@@ -4,43 +4,46 @@ import (
"errors"
"math"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"

klog "k8s.io/klog/v2"
)

type Filter interface {
type FilterChain interface {
Name() string
Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
}

// filter applies current filterFunc, and then recursively applies next filters depending success or
// filterChainImpl applies the current filter, and then recursively applies next filters depending on the
// success or failure of the current filter.
// It can be used to construct a flow chart algorithm.
type filter struct {
type filterChainImpl struct {
name string
filter filterFunc
filter filter
// nextOnSuccess filter will be applied after successfully applying the current filter.
// The filtered results will be passed to the next filter.
nextOnSuccess *filter
nextOnSuccess *filterChainImpl
// nextOnFailure filter will be applied if current filter fails.
// The original input will be passed to the next filter.
nextOnFailure *filter
nextOnFailure *filterChainImpl
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
// success or failure of the current filter.
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
nextOnSuccessOrFailure *filter
nextOnSuccessOrFailure *filterChainImpl
}

func (f *filter) Name() string {
func (f *filterChainImpl) Name() string {
if f == nil {
return "nil"
}
return f.name
}

func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
func (f *filterChainImpl) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, req, len(pods))

filtered, err := f.filter(req, pods)
@@ -71,11 +74,11 @@ func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
}
}

// filterFunc filters a set of input pods to a subset.
type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
// filter filters a set of input pods to a subset.
type filter func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)

// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
func toFilterFunc(pp podPredicate) filterFunc {
// toFilter is a helper function to convert a per-pod predicate into a filter.
func toFilter(pp podPredicate) filter {
return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
@@ -120,10 +123,6 @@ func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
return filtered, nil
}

func lowQueueingPodPredicate(_ *LLMRequest, pod *backend.PodMetrics) bool {
return pod.WaitingQueueSize < queueingThresholdLoRA
}

// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range
// (max-min) by the number of pods, and finds the pods that fall into the first range.
// The intuition is that if there are multiple pods that share similar KV cache in the low range, we
@@ -152,6 +151,12 @@ func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
return filtered, nil
}

func dropRequestFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
klog.Infof("Dropping request %v", req)
return []*backend.PodMetrics{}, status.Errorf(
codes.ResourceExhausted, "dropping request due to limited backend resources")
}

// podPredicate is a filter function to check whether a pod is desired.
type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool

@@ -185,3 +190,9 @@ func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
}
}

func lowQueueingPodPredicate(queueingThresholdLoRA int) podPredicate {
return func(_ *LLMRequest, pod *backend.PodMetrics) bool {
return pod.WaitingQueueSize < queueingThresholdLoRA
}
}
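Aside: the renamed types compose into a decision tree. The following in-package sketch (the types here are unexported) wires a chain purely from pieces defined in this file; the thresholds are illustrative, not defaults set by this PR. Per the Filter method above, a nil nextOnSuccess/nextOnFailure simply ends the chain with the current result.

func sketchChain() FilterChain {
	return &filterChainImpl{
		name:   "low queue and KV cache",
		filter: toFilter(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
		// Success: narrow further to the least-queued pods.
		nextOnSuccess: &filterChainImpl{
			name:   "least queuing",
			filter: leastQueuingFilterFunc,
		},
		// Failure: shed load with a gRPC ResourceExhausted error.
		nextOnFailure: &filterChainImpl{
			name:   "drop request",
			filter: dropRequestFilterFunc,
		},
	}
}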
10 changes: 5 additions & 5 deletions pkg/ext-proc/scheduling/filter_test.go
@@ -15,11 +15,11 @@ func TestFilter(t *testing.T) {
input []*backend.PodMetrics
output []*backend.PodMetrics
err bool
filter *filter
filter *filterChainImpl
}{
{
name: "simple filter without successor, failure",
filter: &filter{filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
filter: &filterChainImpl{filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
return nil, errors.New("filter error")
}},
err: true,
@@ -216,7 +216,7 @@ func TestFilter(t *testing.T) {
func TestFilterFunc(t *testing.T) {
tests := []struct {
name string
f filterFunc
f filter
req *LLMRequest
input []*backend.PodMetrics
output []*backend.PodMetrics
@@ -302,7 +302,7 @@ func TestFilterFunc(t *testing.T) {
},
{
name: "noQueueAndLessThanKVCacheThresholdPredicate",
f: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
f: toFilter(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
input: []*backend.PodMetrics{
{
// This pod should be returned.
@@ -337,7 +337,7 @@ func TestFilterFunc(t *testing.T) {
},
{
name: "low LoRA cost",
f: toFilterFunc(lowLoRACostPredicate),
f: toFilter(lowLoRACostPredicate),
req: &LLMRequest{
Model: "model",
ResolvedTargetModel: "model",