diff --git a/.gitignore b/.gitignore index 4442b6516..599999112 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,9 @@ go.work.sum # generated docs site + +# tokenizer lib +lib + +# local configuration files +.envrc diff --git a/.golangci.yml b/.golangci.yml index d2364062e..a42307fce 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,7 +1,7 @@ run: timeout: 5m allow-parallel-runners: true - + # Settings related to issues issues: # Report issues on new code only (since we're brining in from upstream) @@ -9,6 +9,7 @@ issues: # Which dirs to exclude: issues from them won't be reported exclude-dirs: - bin + linters: disable-all: true enable: @@ -18,7 +19,7 @@ linters: - fatcontext - ginkgolinter - gocritic - - govet + # - govet # do not enable - this causes some metalinter issue - loggercheck - misspell - perfsprint @@ -27,17 +28,13 @@ linters: - makezero - errcheck - goconst - - gofmt - - goimports - - gosimple - ineffassign - nakedret - prealloc - - typecheck - unparam - unused - + linters-settings: revive: rules: - - name: comment-spacings + - name: comment-spacings \ No newline at end of file diff --git a/.tekton/buildah-build.yaml b/.tekton/buildah-build.yaml index ad4ab4f40..d680a2333 100644 --- a/.tekton/buildah-build.yaml +++ b/.tekton/buildah-build.yaml @@ -44,6 +44,15 @@ spec: USERNAME=$(jq -r '.auths["quay.io"].username' /root/.docker/config.json) PASSWORD=$(jq -r '.auths["quay.io"].password' /root/.docker/config.json) + echo "🔐 Extracting Git credentials from workspace..." 
+ GIT_USER=$(cat /workspace/git-auth/username) + GIT_TOKEN=$(cat /workspace/git-auth/token) + + if [ -z "$GIT_USER" ] || [ -z "$GIT_TOKEN" ]; then + echo "❌ Error: Missing git-auth credentials" + exit 1 + fi + if [ "$USERNAME" = "null" ] || [ "$PASSWORD" = "null" ]; then echo "❌ Error: Missing registry credentials" exit 1 @@ -56,8 +65,10 @@ spec: export DOCKER_CONFIG=/root/.docker export BUILDER=buildah export IMG=$(params.image_tag_base):$(params.dev-version) - + export GIT_NM_USER=$GIT_USER + export NM_TOKEN=$GIT_TOKEN + echo "🚀 Calling make buildah-build with IMG=$IMG..." - make buildah-build IMG=$IMG + make buildah-build IMG=$IMG echo "$IMG" > /tekton/results/image-url diff --git a/.tekton/go-build-task.yaml b/.tekton/go-build-task.yaml index eeb117976..579d20086 100644 --- a/.tekton/go-build-task.yaml +++ b/.tekton/go-build-task.yaml @@ -12,5 +12,24 @@ spec: script: | #!/bin/bash cd $(workspaces.source.path) + + echo "🔐 Extracting Git credentials from workspace..." + GIT_USER=$(cat /workspace/git-auth/username) + GIT_TOKEN=$(cat /workspace/git-auth/token) + + if [ -z "$GIT_USER" ] || [ -z "$GIT_TOKEN" ]; then + echo "❌ Error: Missing git-auth credentials" + exit 1 + fi + + echo "🔐 Configuring Git..." 
+ git config --global user.email "ci-tag-bot@example.com" + git config --global user.name "ci-tag-bot" + git config --global url."https://${GIT_USER}:${GIT_TOKEN}@github.com".insteadOf "https://github.com" + git config --global --add safe.directory "$(pwd)" + + # required for go build with tokenizer lib linking + dnf install -y gcc-c++ libstdc++ libstdc++-devel && dnf clean all + go env -w GOFLAGS=-buildvcs=false make build diff --git a/.tekton/go-lint-task.yaml b/.tekton/go-lint-task.yaml index f42471a19..809a03223 100644 --- a/.tekton/go-lint-task.yaml +++ b/.tekton/go-lint-task.yaml @@ -11,6 +11,7 @@ spec: steps: - name: run-lint image: us.icr.io/ibm-hc4ai-operator/golangci-lint:v1.64.8 + # image: us.icr.io/ibm-hc4ai-operator/golangci-lint:v2.0.3 imagePullPolicy: IfNotPresent script: | #!/bin/bash diff --git a/.tekton/pipelinerun.yaml b/.tekton/pipelinerun.yaml index 29ef7b666..27cfe5c30 100644 --- a/.tekton/pipelinerun.yaml +++ b/.tekton/pipelinerun.yaml @@ -165,6 +165,9 @@ spec: workspaces: - name: source workspace: source + - name: git-auth + workspace: git-auth + - name: extract-version-and-registry params: @@ -328,6 +331,8 @@ spec: workspace: registry-secret - name: container-storage workspace: container-storage + - name: git-auth + workspace: git-auth - name: vulnerability-scan when: diff --git a/Dockerfile b/Dockerfile index a92cbb711..5f7631ee6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,15 +3,26 @@ FROM quay.io/projectquay/golang:1.24 AS builder ARG TARGETOS ARG TARGETARCH -# ENV GOPROXY=https://goproxy.io,direct +# Install build tools +RUN dnf install -y gcc-c++ libstdc++ libstdc++-devel && dnf clean all WORKDIR /workspace + +## NeuralMagic internal repos pull config +ARG GIT_NM_USER +ARG NM_TOKEN +### use git token +RUN echo -e "machine github.com\n\tlogin ${GIT_NM_USER}\n\tpassword ${NM_TOKEN}" >> ~/.netrc +ENV GOPRIVATE=github.com/neuralmagic +ENV GIT_TERMINAL_PROMPT=1 + # Copy the Go Modules manifests COPY go.mod go.mod COPY go.sum go.sum # 
cache deps before building and copying source so that we don't need to re-download as much # and so that source changes don't invalidate our downloaded layer RUN go mod download +RUN rm -rf ~/.netrc # remove git token # Copy the go source COPY cmd ./cmd @@ -19,12 +30,20 @@ COPY pkg ./pkg COPY internal ./internal COPY api ./api +# HuggingFace tokenizer bindings +RUN mkdir -p lib +RUN curl -L https://github.com/daulet/tokenizers/releases/download/v1.20.2/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib +RUN ranlib lib/*.a + # Build # the GOARCH has not a default value to allow the binary be built according to the host where the command # was called. For example, if we call make image-build in a local env which has the Apple Silicon M1 SO # the docker BUILDPLATFORM arg will be linux/arm64 when for Apple x86 it will be linux/amd64. Therefore, # by leaving it empty we can ensure that the container and binary shipped on it will have the same platform. -RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -o bin/epp cmd/epp/main.go cmd/epp/health.go +ENV CGO_ENABLED=1 +ENV GOOS=${TARGETOS:-linux} +ENV GOARCH=${TARGETARCH} +RUN go build -o bin/epp -ldflags="-extldflags '-L$(pwd)/lib'" cmd/epp/main.go cmd/epp/health.go # Use distroless as minimal base image to package the manager binary # Refer to https://github.com/GoogleContainerTools/distroless for more details diff --git a/Makefile b/Makefile index 0bfb19fc7..b51bc16b0 100644 --- a/Makefile +++ b/Makefile @@ -439,11 +439,20 @@ lint: check-golangci-lint ## Run lint golangci-lint run ##@ Build +LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' +CGO_ENABLED=1 # Enable CGO + +.PHONY: download-tokenizer +download-tokenizer: ## Download the HuggingFace tokenizer bindings. + @echo "Downloading HuggingFace tokenizer bindings..." 
+ mkdir -p lib + curl -L https://github.com/daulet/tokenizers/releases/download/v1.20.2/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + ranlib lib/*.a .PHONY: build -build: check-go ## +build: check-go download-tokenizer ## @printf "\033[33;1m==== Building ====\033[0m\n" - go build -o bin/epp cmd/epp/main.go cmd/epp/health.go + go build -ldflags="$(LDFLAGS)" -o bin/epp cmd/epp/main.go cmd/epp/health.go ##@ Container Build/Push @@ -456,7 +465,12 @@ buildah-build: check-builder load-version-json ## Build and push image (multi-ar for arch in amd64; do \ ARCH_TAG=$$FINAL_TAG-$$arch; \ echo "📦 Building for architecture: $$arch"; \ - buildah build --arch=$$arch --os=linux --layers -t $(IMG)-$$arch . || exit 1; \ + buildah build \ + --arch=$$arch \ + --build-arg GIT_NM_USER=$(GIT_NM_USER) \ + --build-arg NM_TOKEN=$(NM_TOKEN) \ + --os=linux \ + --layers -t $(IMG)-$$arch . || exit 1; \ echo "🚀 Pushing image: $(IMG)-$$arch"; \ buildah push $(IMG)-$$arch docker://$(IMG)-$$arch || exit 1; \ done; \ @@ -474,7 +488,11 @@ buildah-build: check-builder load-version-json ## Build and push image (multi-ar sed -e '1 s/\(^FROM\)/FROM --platform=$${BUILDPLATFORM}/' Dockerfile > Dockerfile.cross; \ - docker buildx create --use --name image-builder || true; \ docker buildx use image-builder; \ - docker buildx build --push --platform=$(PLATFORMS) --tag $(IMG) -f Dockerfile.cross . || exit 1; \ + docker buildx build --push \ + --platform=$(PLATFORMS) \ + --build-arg GIT_NM_USER=$(GIT_NM_USER)\ + --build-arg NM_TOKEN=$(NM_TOKEN) \ + --tag $(IMG) -f Dockerfile.cross . 
|| exit 1; \ docker buildx rm image-builder || true; \ rm Dockerfile.cross; \ elif [ "$(BUILDER)" = "podman" ]; then \ @@ -489,7 +507,13 @@ buildah-build: check-builder load-version-json ## Build and push image (multi-ar .PHONY: image-build image-build: check-container-tool load-version-json ## Build container image using $(CONTAINER_TOOL) @printf "\033[33;1m==== Building container image $(IMG) ====\033[0m\n" - $(CONTAINER_TOOL) build --build-arg TARGETOS=$(TARGETOS) --build-arg TARGETARCH=$(TARGETARCH) -t $(IMG) . + $(CONTAINER_TOOL) build --platform=$(TARGETOS)/$(TARGETARCH) \ + --build-arg TARGETOS=$(TARGETOS) \ + --build-arg TARGETARCH=$(TARGETARCH) \ + --build-arg GIT_NM_USER=$(GIT_NM_USER)\ + --build-arg NM_TOKEN=$(NM_TOKEN) \ + --progress=plain \ + -t $(IMG) . .PHONY: image-push image-push: check-container-tool load-version-json ## Push container image $(IMG) to registry diff --git a/README.md b/README.md index 4cdb17811..76a333eea 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,73 @@ This project offers tools for AI Inference, enabling developers to build [Inference Gateways]. 
+--- +## Temporary Fork Configuration + +To enable the KVCacheAwareScorer, the following environment variables must be configured: +``` +export ENABLE_KVCACHE_AWARE_SCORER=true +export KVCACHE_AWARE_SCORER_WEIGHT=1.0 +export KVCACHE_INDEXER_REDIS_ADDR= +export HF_TOKEN= +``` + +To enable the PrefixAwareScorer, the following environment variables must be configured: +``` +export ENABLE_PREFIX_AWARE_SCORER=true +export PREFIX_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable the LoadAwareScorer, the following environment variables must be configured: +``` +export ENABLE_LOAD_AWARE_SCORER=true +export LOAD_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable the SessionAwareScorer, the following environment variables must be configured: +``` +export ENABLE_SESSION_AWARE_SCORER=true +export SESSION_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable Prefill/Decode (PD) processing, the following environment variable must be configured: +``` +export PD_ENABLED=true +``` + +To define the prompt length threshold (requests with a prompt longer than the value defined here will be processed using the prefill-decode process), the following environment variable must be configured: +``` +export PD_PROMPT_LEN_THRESHOLD=10 +``` + +Prefill configuration: + +To enable and configure the kv cache scorer for prefill, the following environment variables must be configured: +``` +export PREFILL_ENABLE_KVCACHE_AWARE_SCORER=true +export PREFILL_KVCACHE_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable and configure the load aware scorer for prefill, the following environment variables must be configured: +``` +export PREFILL_ENABLE_LOAD_AWARE_SCORER=true +export PREFILL_LOAD_AWARE_SCORER_WEIGHT=1.0 +``` + +Decode configuration: + +To enable and configure the kv cache scorer for decode, the following environment variables must be configured: +``` +export DECODE_ENABLE_KVCACHE_AWARE_SCORER=true +export DECODE_KVCACHE_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable and configure the load aware scorer for decode, the following 
environment variables must be configured: +``` +export DECODE_ENABLE_LOAD_AWARE_SCORER=true +export DECODE_LOAD_AWARE_SCORER_WEIGHT=1.0 +``` +--- [Inference Gateways]:#concepts-and-definitions ## Concepts and Definitions @@ -79,8 +146,8 @@ See our website at https://gateway-api-inference-extension.sigs.k8s.io/ for deta ## Roadmap As Inference Gateway builds towards a GA release. We will continue to expand our capabilities, namely: -1. Prefix-cache aware load balancing with interfaces for remote caches -1. Recommended LoRA adapter pipeline for automated rollout +1. Prefix-cache aware load balancing with interfaces for remote caches +1. Recommended LoRA adapter pipeline for automated rollout 1. Fairness and priority between workloads within the same criticality band 1. HPA support for autoscaling on aggregate metrics derived from the load balancer 1. Support for large multi-modal inputs and outputs @@ -104,4 +171,3 @@ Contributions are readily welcomed, follow the [dev guide](./docs/dev.md) to sta ### Code of conduct Participation in the Kubernetes community is governed by the [Kubernetes Code of Conduct](code-of-conduct.md). 
- diff --git a/deploy/components/vllm-p2p/vllm-deployment.yaml b/deploy/components/vllm-p2p/vllm-deployment.yaml index 19fd59c21..c9964962e 100644 --- a/deploy/components/vllm-p2p/vllm-deployment.yaml +++ b/deploy/components/vllm-p2p/vllm-deployment.yaml @@ -31,13 +31,12 @@ spec: - "-c" args: - | - export LMCACHE_DISTRIBUTED_URL=$${${POD_IP}}:80 && \ + export LMCACHE_DISTRIBUTED_URL=$${${POD_IP}} && \ vllm serve ${MODEL_NAME} \ --host 0.0.0.0 \ --port 8000 \ - --enable-chunked-prefill false \ --max-model-len ${MAX_MODEL_LEN} \ - --kv-transfer-config '{"kv_connector":"LMCacheConnector","kv_role":"kv_both"}' + --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}' ports: - name: http containerPort: 8000 @@ -78,6 +77,10 @@ spec: secretKeyRef: name: ${HF_SECRET_NAME} key: ${HF_SECRET_KEY} + - name: VLLM_ENABLE_V1_MULTIPROCESSING + value: "1" + - name: VLLM_WORKER_MULTIPROC_METHOD + value: spawn - name: LMCACHE_LOOKUP_URL value: ${REDIS_HOST}:${REDIS_PORT} - name: LMCACHE_ENABLE_DEBUG diff --git a/deploy/environments/dev/kubernetes-kgateway/patch-deployments.yaml b/deploy/environments/dev/kubernetes-kgateway/patch-deployments.yaml index 00c87fbbf..a6b1d4a2b 100644 --- a/deploy/environments/dev/kubernetes-kgateway/patch-deployments.yaml +++ b/deploy/environments/dev/kubernetes-kgateway/patch-deployments.yaml @@ -29,4 +29,12 @@ spec: valueFrom: secretKeyRef: name: hf-token - key: ${HF_SECRET_KEY} \ No newline at end of file + key: ${HF_SECRET_KEY} + - name: ENABLE_KVCACHE_AWARE_SCORER + value: "true" + - name: KVCACHE_AWARE_SCORER_WEIGHT + value: "2.0" + - name: ENABLE_LOAD_AWARE_SCORER + value: "true" + - name: LOAD_AWARE_SCORER_WEIGHT + value: "1.0" diff --git a/go.mod b/go.mod index 7da237678..dff0542e9 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,15 @@ module sigs.k8s.io/gateway-api-inference-extension -go 1.24.0 +go 1.24.1 + +toolchain go1.24.2 require ( github.com/elastic/crd-ref-docs v0.1.0 github.com/envoyproxy/go-control-plane/envoy 
v1.32.4 github.com/go-logr/logr v1.4.2 github.com/google/go-cmp v0.7.0 + github.com/neuralmagic/llm-d-kv-cache-manager v0.0.0-20250430102735-86595011431d github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/prometheus/client_golang v1.22.0 @@ -41,7 +44,9 @@ require ( github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 // indirect + github.com/daulet/tokenizers v1.20.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect @@ -69,6 +74,7 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.11 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -90,6 +96,7 @@ require ( github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/redis/go-redis/v9 v9.7.3 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect @@ -104,15 +111,15 @@ require ( go.opentelemetry.io/otel/trace v1.34.0 // indirect go.opentelemetry.io/proto/otlp v1.3.1 // indirect go.uber.org/automaxprocs v1.6.0 // indirect - golang.org/x/crypto v0.36.0 // indirect + golang.org/x/crypto v0.37.0 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/mod v0.24.0 // 
indirect - golang.org/x/net v0.38.0 // indirect + golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.27.0 // indirect - golang.org/x/sync v0.12.0 // indirect + golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect - golang.org/x/term v0.30.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/term v0.31.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.31.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect diff --git a/go.sum b/go.sum index 11c244d44..ea299e2fd 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,10 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -24,10 +28,14 @@ github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 h1:boJj011Hh+874zpIySe github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 
+github.com/daulet/tokenizers v1.20.2 h1:tlq/vIOiBTKDPets3596aFvmJYLn3XI6LFKq4q9LKhQ= +github.com/daulet/tokenizers v1.20.2/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/elastic/crd-ref-docs v0.1.0 h1:Cr5kz89QB3Iuuj7dhAfLMApCrChEGAaIBTxGk/xuRKw= github.com/elastic/crd-ref-docs v0.1.0/go.mod h1:X83mMBdJt05heJUYiS3T0yJ/JkCuliuhSUNav5Gjo/U= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= @@ -100,6 +108,8 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= @@ -147,6 +157,8 @@ github.com/munnerz/goautoneg 
v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= +github.com/neuralmagic/llm-d-kv-cache-manager v0.0.0-20250430102735-86595011431d h1:6YSxvAG4ve5jy0nTLs509OMU5fuiQ3JNQdZxqiu8PgQ= +github.com/neuralmagic/llm-d-kv-cache-manager v0.0.0-20250430102735-86595011431d/go.mod h1:VB+KcEemkO1ZKpz/hgUPQMU9oSLv2uCLW6y6c+r8jkQ= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= @@ -172,6 +184,8 @@ github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= +github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -226,8 +240,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto 
v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -238,17 +252,15 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= -golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -256,13 +268,13 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod 
h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go index 901697cb4..5ebf8484e 100644 --- a/pkg/epp/backend/metrics/pod_metrics.go +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -32,7 +32,7 @@ import ( const ( fetchMetricsTimeout = 5 * time.Second - roleLabel = "llmd.org/role" + roleLabel = "llm-d.ai/role" rolePrefill = "prefill" roleDecode = "decode" roleBoth = "both" diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 203afc2f0..47cd37dee 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -31,6 +31,8 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +const emptyPrompt = "" + // HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling. func (s *StreamingServer) HandleRequestBody( ctx context.Context, @@ -68,6 +70,7 @@ func (s *StreamingServer) HandleRequestBody( Headers: reqCtx.RequestHeaders, ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, + Prompt: emptyPrompt, } logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) @@ -76,6 +79,10 @@ func (s *StreamingServer) HandleRequestBody( if llmReq.Model != llmReq.ResolvedTargetModel { requestBodyMap["model"] = llmReq.ResolvedTargetModel } + // Extract prompt from the request body. 
+ if prompt, ok := requestBodyMap["prompt"].(string); ok { + llmReq.Prompt = prompt + } requestBodyBytes, err = json.Marshal(requestBodyMap) if err != nil { @@ -152,7 +159,7 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ } for _, header := range req.RequestHeaders.Headers.Headers { - reqCtx.RequestHeaders[header.Key] = header.Value + reqCtx.RequestHeaders[header.Key] = string(header.RawValue) } return nil diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 6ea7d438c..11587fb1c 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -37,6 +37,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -66,6 +67,7 @@ type StreamingServer struct { type Scheduler interface { Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) + RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error) } // RequestContext stores context information during the life time of an HTTP request. @@ -189,6 +191,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) case *extProcPb.ProcessingRequest_RequestTrailers: // This is currently unused. 
case *extProcPb.ProcessingRequest_ResponseHeaders: + responseHeaders := make(map[string]string) for _, header := range v.ResponseHeaders.Headers.GetHeaders() { value := string(header.RawValue) @@ -199,27 +202,53 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx.modelServerStreaming = true loggerTrace.Info("model server is streaming response") } + responseHeaders[header.Key] = value } - reqCtx.RequestState = ResponseRecieved - reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - // This is for debugging purpose only. - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, + llmReq := &schedulingtypes.LLMRequest{ + Model: reqCtx.Model, + Headers: responseHeaders, + ResolvedTargetModel: reqCtx.ResolvedTargetModel, + } + + var result *types.Result + result, err = s.scheduler.RunPostResponsePlugins(ctx, llmReq, reqCtx.TargetPod) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error handling response") + reqCtx.ResponseStatusCode = errutil.ModelServerError + } else { + headers := []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + // This is for debugging purpose only. 
+ Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + } + + // Add headers added by PostResponse + for key, value := range result.MutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } + + reqCtx.RequestState = ResponseRecieved + reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, }, }, }, }, - }, + } } case *extProcPb.ProcessingRequest_ResponseBody: diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 5c64228ca..3f064fe75 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -26,6 +26,7 @@ type SchedulerConfig struct { scorers map[plugins.Scorer]int // map from scorer to weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } var defPlugin = &defaultPlugin{} @@ -40,4 +41,5 @@ var defaultConfig = &SchedulerConfig{ scorers: map[plugins.Scorer]int{}, picker: defPlugin, postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, } diff --git a/pkg/epp/scheduling/config_utils.go b/pkg/epp/scheduling/config_utils.go new file mode 100644 index 000000000..4145dbe1b --- /dev/null +++ b/pkg/epp/scheduling/config_utils.go @@ -0,0 +1,84 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import ( + "context" + "fmt" + + "github.com/go-logr/logr" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" +) + +const ( + prefillKvCacheScorerEnablementEnvVar = "PREFILL_ENABLE_KVCACHE_AWARE_SCORER" + prefillLoadAwareScorerEnablementEnvVar = "PREFILL_ENABLE_LOAD_AWARE_SCORER" + decodeKvCacheScorerEnablementEnvVar = "DECODE_ENABLE_KVCACHE_AWARE_SCORER" + decodeLoadAwareScorerEnablementEnvVar = "DECODE_ENABLE_LOAD_AWARE_SCORER" + + prefillKvCacheScorerWeightEnvVar = "PREFILL_KVCACHE_AWARE_SCORER_WEIGHT" + prefillLoadAwareScorerWeightEnvVar = "PREFILL_LOAD_AWARE_SCORER_WEIGHT" + decodeKvCacheScorerWeightEnvVar = "DECODE_KVCACHE_AWARE_SCORER_WEIGHT" + decodeLoadAwareScorerWeightEnvVar = "DECODE_LOAD_AWARE_SCORER_WEIGHT" + + pdEnabledEnvKey = "PD_ENABLED" + + pdPromptLenThresholdEnvKey = "PD_PROMPT_LEN_THRESHOLD" + pdPromptLenThresholdDefault = 10 +) + +const ( + loadAwareScorerName = "LoadAwareScorer" + kvCacheAwareScorerName = "KVCacheAwareScorer" +) + +func addScorerByEnvironment(ctx context.Context, config *SchedulerConfig, scorerName string, scorerEnabledEnvKey string, weightEnvKey string, logger logr.Logger) { + if envutil.GetEnvString(scorerEnabledEnvKey, "false", logger) != "true" { + logger.Info(fmt.Sprintf("Skipping %s creation as it is not enabled", scorerName)) + return + } + + weight := envutil.GetEnvInt(weightEnvKey, 1, logger) + scorer, err := 
createScorerByName(ctx, scorerName) + if err != nil { + logger.Error(err, "Failed to create scorrer") + return + } + + defaultConfig.scorers[scorer] = weight + logger.Info("Initialized scorer", "scorer", scorerName, "weight", weight) +} + +func createScorerByName(ctx context.Context, name string) (plugins.Scorer, error) { + switch name { + case loadAwareScorerName: + return &scorer.LoadAwareScorer{}, nil + case kvCacheAwareScorerName: + return scorer.NewKVCacheAwareScorer(ctx) + } + return nil, fmt.Errorf("invalid scorer type %s", name) +} + +func getPDEnabledFromEnvironment(logger logr.Logger) bool { + return envutil.GetEnvString(pdEnabledEnvKey, "false", logger) == "true" +} + +func getPDPromptLenThresholdFromEnvironment(logger logr.Logger) int { + return envutil.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger) +} diff --git a/pkg/epp/scheduling/local_config.go b/pkg/epp/scheduling/local_config.go index 87098ae0d..a1812b0bd 100644 --- a/pkg/epp/scheduling/local_config.go +++ b/pkg/epp/scheduling/local_config.go @@ -17,12 +17,107 @@ limitations under the License. 
package scheduling import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorers" + "context" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + kvCacheScorerEnablementEnvVar = "ENABLE_KVCACHE_AWARE_SCORER" + loadAwareScorerEnablementEnvVar = "ENABLE_LOAD_AWARE_SCORER" + prefixScorerEnablementEnvVar = "ENABLE_PREFIX_AWARE_SCORER" + sessionAwareScorerEnablementEnvVar = "ENABLE_SESSION_AWARE_SCORER" + pdFilterEnablementEnvVar = "ENABLE_PD_FILTER" + + kvCacheScorerWeightEnvVar = "KVCACHE_AWARE_SCORER_WEIGHT" + loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT" + prefixScorerWeightEnvVar = "PREFIX_AWARE_SCORER_WEIGHT" + sessionAwareScorerWeightEnvVar = "SESSION_AWARE_SCORER_WEIGHT" ) func init() { - defaultConfig.scorers[&scorers.LoadBasedScorer{}] = 1.0 + setDefaultConfig() +} + +func setDefaultConfig() { + // since the default config is a global variable, we add this function to minimize rebase conflicts. + // this configuration is a temporary state, it should be better streamlined. 
+ setLoadAwareScorer() + setSessionAwareScorer() + setKVCacheAwareScorer() + setPrefixScorer() + + defaultConfig.picker = picker.NewMaxScorePicker() +} + +func setLoadAwareScorer() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + if envutil.GetEnvString(loadAwareScorerEnablementEnvVar, "false", loggerDebug) != "true" { + loggerDebug.Info("Skipping LoadAwareScorer creation as it is not enabled") + return + } + + loadBasedScorerWeight := envutil.GetEnvInt(loadAwareScorerWeightEnvVar, 1, loggerDebug) + defaultConfig.scorers[&scorer.LoadAwareScorer{}] = loadBasedScorerWeight + loggerDebug.Info("Initialized LoadAwareScorer", "weight", loadBasedScorerWeight) +} + +func setSessionAwareScorer() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + if envutil.GetEnvString(sessionAwareScorerEnablementEnvVar, "false", loggerDebug) != "true" { + loggerDebug.Info("Skipping SessionAwareScorer creation as it is not enabled") + return + } + + sessionBasedScorerWeight := envutil.GetEnvInt(sessionAwareScorerWeightEnvVar, 1, loggerDebug) + sessionAffinity := scorer.NewSessionAffinity() + + defaultConfig.scorers[sessionAffinity] = sessionBasedScorerWeight + defaultConfig.postResponsePlugins = append(defaultConfig.postResponsePlugins, sessionAffinity) + loggerDebug.Info("Initialized SessionAwareScorer", "weight", sessionBasedScorerWeight) +} + +func setKVCacheAwareScorer() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + if envutil.GetEnvString(kvCacheScorerEnablementEnvVar, "false", loggerDebug) != "true" { + loggerDebug.Info("Skipping KVCacheAwareScorer creation as it is not enabled") + return + } + + kvCacheScorer, err := scorer.NewKVCacheAwareScorer(ctx) + if err != nil { + loggerDebug.Error(err, "Failed to create KVCacheAwareScorer") + return + } + + kvCacheScorerWeight := 
envutil.GetEnvInt(kvCacheScorerWeightEnvVar, 1, loggerDebug) + defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight + loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight) +} + +func setPrefixScorer() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + if envutil.GetEnvString(prefixScorerEnablementEnvVar, "false", loggerDebug) != "true" { + loggerDebug.Info("Skipping PrefixScorer creation as it is not enabled") + return + } + + prefixScorerWeight := envutil.GetEnvInt(prefixScorerWeightEnvVar, 1, loggerDebug) + prefixScorer := scorer.NewPrefixAwareScorer(nil) + defaultConfig.scorers[prefixScorer] = prefixScorerWeight // TODO: make configurable + defaultConfig.postSchedulePlugins = append(defaultConfig.postSchedulePlugins, prefixScorer) - // Added as a reference - // defaultConfig.filters = []plugins.Filter{filter.PDFilter} + loggerDebug.Info("Initialized PrefixAwareScorer", "weight", prefixScorerWeight) } diff --git a/pkg/epp/scheduling/pd_config.go b/pkg/epp/scheduling/pd_config.go new file mode 100644 index 000000000..3371093a1 --- /dev/null +++ b/pkg/epp/scheduling/pd_config.go @@ -0,0 +1,77 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scheduling + +import ( + "context" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +var prefillConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{filter.PrefillFilter}, + scorers: map[plugins.Scorer]int{}, + picker: picker.NewMaxScorePicker(), + postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, +} +var decodeConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{filter.DecodeFilter}, + scorers: map[plugins.Scorer]int{}, + picker: picker.NewMaxScorePicker(), + postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, +} + +var PDEnabled = false +var promptLengthThreshold int + +func init() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + loadPrefillConfiguration(ctx, loggerDebug) + loadDecodeConfiguration(ctx, loggerDebug) + + // set IsPDEnabled by environment + PDEnabled = getPDEnabledFromEnvironment(loggerDebug) + promptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug) + + // update default config if pd is enabled + if PDEnabled { + defaultConfig.filters = append(defaultConfig.filters, filter.DecodeFilter) + } +} + +func loadPrefillConfiguration(ctx context.Context, logger logr.Logger) { + // add scorers + addScorerByEnvironment(ctx, prefillConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger) + addScorerByEnvironment(ctx, prefillConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, 
loadAwareScorerWeightEnvVar, logger) +} + +func loadDecodeConfiguration(ctx context.Context, logger logr.Logger) { + // add scorers + addScorerByEnvironment(ctx, decodeConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger) + addScorerByEnvironment(ctx, decodeConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger) +} diff --git a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go new file mode 100644 index 000000000..37822201a --- /dev/null +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -0,0 +1,90 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package scheduling implements request scheduling algorithms. 
+package scheduling + +import ( + "context" + "fmt" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +const ( + prefillPodHeader = "x-prefiller-url" +) + +func NewPDScheduler(datastore Datastore) *PDScheduler { + return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig, defaultConfig) +} + +func NewPDSchedulerWithConfig(datastore Datastore, pConfig *SchedulerConfig, dConfig *SchedulerConfig, defConfig *SchedulerConfig) *PDScheduler { + return &PDScheduler{ + datastore: datastore, + prefillScheduler: NewSchedulerWithConfig(datastore, pConfig), + decodeScheduler: NewSchedulerWithConfig(datastore, dConfig), + defaultScheduler: NewSchedulerWithConfig(datastore, defConfig), + } +} + +type PDScheduler struct { + datastore Datastore + prefillScheduler *Scheduler + decodeScheduler *Scheduler + defaultScheduler *Scheduler +} + +// Schedule finds the target pod based on metrics and the requested lora adapter. +// PD scheduler uses three base schedulers to process requests, the overall configuration is currently loaded from environment variables. +// If the request prompt is short enough (defined by the threshold in the configuration) - use the default behavior +// If the request prompt is long enough to use prefill-decode process: +// 1 - find the pod for prefill, save its url in a special header. For this, use the Scheduler configured for this goal, which uses the prefill filter +// and scorers according to the configuration. 
+// 2 - find the pod for decode, use the Scheduler configured for this goal, which uses the decode filter and scorers defined in the configuration +func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { + logger := log.FromContext(ctx).WithValues("pd-schedule", req) + + if len(req.Prompt) < promptLengthThreshold { + // the prompt is short enough - use the default scheduling logic + return s.defaultScheduler.Schedule(ctx, req) + } + + sCtx, err := createSchedulerContext(ctx, req, s.datastore) + if err != nil { + return nil, err + } + + // prompt requires processing on two pods - prefill and decode + // start with calculating the prefill pod + res, err := s.prefillScheduler.scheduleWithContext(ctx, sCtx, req, logger) + if err != nil { + return nil, err + } + + if res.TargetPod != nil { + url := fmt.Sprintf("http://%s:%d", res.TargetPod.GetPod().Address, sCtx.TargetPort) + sCtx.MutatedHeaders[prefillPodHeader] = url + } + + // get decode pod + return s.decodeScheduler.scheduleWithContext(ctx, sCtx, req, logger) +} + +func (s *PDScheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { + return s.decodeScheduler.RunPostResponsePlugins(ctx, req, targetPodName) +} diff --git a/pkg/epp/scheduling/pd_scheduler_test.go b/pkg/epp/scheduling/pd_scheduler_test.go new file mode 100644 index 000000000..1cec19433 --- /dev/null +++ b/pkg/epp/scheduling/pd_scheduler_test.go @@ -0,0 +1,154 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import ( + "context" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + k8stypes "k8s.io/apimachinery/pkg/types" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +// Tests the default scheduler configuration and expected behavior. +func TestPDSchedule(t *testing.T) { + // Set configuration + PDEnabled = true + promptLengthThreshold = 10 + prefillConfig.filters = []plugins.Filter{filter.PrefillFilter} + prefillConfig.scorers = map[plugins.Scorer]int{} + decodeConfig.filters = []plugins.Filter{filter.DecodeFilter} + decodeConfig.scorers = map[plugins.Scorer]int{} + + pod1 := &backendmetrics.FakePodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, + Address: "1.2.3.4", + Role: backendmetrics.Prefill, + }, + Metrics: &backendmetrics.Metrics{}, + } + pod2 := &backendmetrics.FakePodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, + Address: "5.6.7.8", + Role: backendmetrics.Decode, + }, + Metrics: &backendmetrics.Metrics{}, + } + wantPod1 := &types.PodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, + Address: "1.2.3.4", + Role: backendmetrics.Prefill, + }, + Metrics: &backendmetrics.Metrics{ + ActiveModels: map[string]int{}, + WaitingModels: map[string]int{}, + }, + } + wantPod2 := &types.PodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, + Address: "5.6.7.8", + Role: backendmetrics.Decode, + }, + Metrics: &backendmetrics.Metrics{ + ActiveModels: 
map[string]int{}, + WaitingModels: map[string]int{}, + }, + } + + tests := []struct { + name string + req *types.LLMRequest + input []*backendmetrics.FakePodMetrics + wantRes *types.Result + err bool + }{ + { + name: "no pods in datastore", + req: &types.LLMRequest{ + Model: "any-model", + ResolvedTargetModel: "any-model", + Critical: true, + Prompt: "12345678901", + }, + input: []*backendmetrics.FakePodMetrics{}, + err: true, + }, + { + name: "one pod, short prompt", + req: &types.LLMRequest{ + Model: "critical", + ResolvedTargetModel: "critical", + Critical: true, + Prompt: "123", + }, + // pod1 will be picked because it is the only one pod + input: []*backendmetrics.FakePodMetrics{pod1}, + wantRes: &types.Result{ + TargetPod: &types.ScoredPod{ + Pod: wantPod1, + }, + MutatedHeaders: map[string]string{}, + }, + }, + { + name: "1P1D", + req: &types.LLMRequest{ + Model: "critical", + ResolvedTargetModel: "critical", + Critical: true, + Prompt: "12345678901", + }, + // pod2 will be picked because it is the decode pod + input: []*backendmetrics.FakePodMetrics{pod1, pod2}, + wantRes: &types.Result{ + TargetPod: &types.ScoredPod{ + Pod: wantPod2, + Score: 0.0, + }, + MutatedHeaders: map[string]string{"x-prefiller-url": "http://1.2.3.4:0"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scheduler := NewPDScheduler(&fakeDataStore{pods: test.input}) + got, err := scheduler.Schedule(context.Background(), test.req) + + fmt.Printf("Test %s:\n", test.name) + fmt.Printf("Result: %#v\n", got) + fmt.Printf("Expected: %#v\n", test.wantRes) + + if test.err != (err != nil) { + t.Errorf("Unexpected error, got %v, want %v", err, test.err) + } + + if diff := cmp.Diff(test.wantRes, got); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + }) + } +} diff --git a/pkg/epp/scheduling/plugins/filter/pd_filter.go b/pkg/epp/scheduling/plugins/filter/pd_filter.go index 945c615d3..fd4c5a8cc 100644 --- 
a/pkg/epp/scheduling/plugins/filter/pd_filter.go +++ b/pkg/epp/scheduling/plugins/filter/pd_filter.go @@ -16,54 +16,44 @@ limitations under the License. package filter import ( - "fmt" - "math/rand/v2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) -const ( - prefillPodHeader = "x-prefiller-url" -) - -var PDFilter = &baseFilter{ - name: "p/d filter", - filter: prefillDecodeFilterFunc, +// PrefillFilter - filters out all pods that are not marked as prefill pod role +var PrefillFilter = &baseFilter{ + name: "prefill_filter", + filter: prefillFilterFunc, } -// prefillDecodeFilterFunc implements a pod selection strategy that filters out pods, -// which role is 'prefill', in addition a header with selected prefill pod is added -// -// Initial implementation: -// 1 - select one random pod marked as 'prefill' and add it name to header -// 2 - return a random pod that marked as "decode" or "both" -// -// Returns: -// - Filtered slice of pod metrics, could contain one or zerro elements -func prefillDecodeFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { - pPods := make([]types.Pod, 0) - dPods := make([]types.Pod, 0) +// prefillFilterFunc filters out all pods that are not marked as "prefill" +func prefillFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filteredPods := make([]types.Pod, 0) for _, pod := range pods { if pod.GetPod().Role == metrics.Prefill { - pPods = append(pPods, pod) - } else if pod.GetPod().Role == metrics.Decode || pod.GetPod().Role == metrics.Both { - dPods = append(dPods, pod) + filteredPods = append(filteredPods, pod) } } - if len(pPods) > 0 { - // select a random prefill pod - randomIndex := rand.IntN(len(pPods)) - ctx.MutatedHeaders[prefillPodHeader] = fmt.Sprintf("http://%s:%d", pPods[randomIndex].GetPod().Address, ctx.TargetPort) - } + return filteredPods +} - if len(dPods) > 1 { - // leave only one 
pod - randomIndex := rand.IntN(len(dPods)) - return []types.Pod{dPods[randomIndex]} +// DecodeFilter - filters out all pods that are not marked as decode/both pod role +var DecodeFilter = &baseFilter{ + name: "decode_filter", + filter: decodeFilterFunc, +} + +// decodeFilterFunc filters out all pods that are not marked as "decode" or "both" +func decodeFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filteredPods := make([]types.Pod, 0) + + for _, pod := range pods { + if pod.GetPod().Role == metrics.Decode || pod.GetPod().Role == metrics.Both { + filteredPods = append(filteredPods, pod) + } } - return dPods + return filteredPods } diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache-aware-scorer.go b/pkg/epp/scheduling/plugins/scorer/kvcache-aware-scorer.go new file mode 100644 index 000000000..bc025751e --- /dev/null +++ b/pkg/epp/scheduling/plugins/scorer/kvcache-aware-scorer.go @@ -0,0 +1,142 @@ +/* +Copyright 2025 The Neural Magic Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scorer + +import ( + "context" + "fmt" + "os" + + kvcache "github.com/neuralmagic/llm-d-kv-cache-manager/pkg/kv-cache" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + kvCacheAwareScorerName = "kvcache-aware-scorer" + kvCacheRedisEnvVar = "KVCACHE_INDEXER_REDIS_ADDR" + huggingFaceTokenEnvVar = "HF_TOKEN" +) + +// KVCacheAwareScorer uses the KVCacheIndexer to score pods based on KVCache +// awareness. +type KVCacheAwareScorer struct { + kvCacheIndexer *kvcache.Indexer +} + +// NewKVCacheAwareScorer creates a new KVCacheAwareScorer instance. +// It initializes the KVCacheIndexer from environment variables. +// +// If the environment variables are not set, or if the indexer +// fails to initialize, an error is returned. +func NewKVCacheAwareScorer(ctx context.Context) (plugins.Scorer, error) { + config := kvcache.NewDefaultConfig() + + redisAddr := os.Getenv(kvCacheRedisEnvVar) + if redisAddr != "" { + config.KVBlockIndexerConfig.RedisKVBlockIndexerConfig.RedisAddr = redisAddr + } else { + return nil, fmt.Errorf("environment variable %s is not set", kvCacheRedisEnvVar) + } + + hfToken := os.Getenv(huggingFaceTokenEnvVar) + if hfToken != "" { + config.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = hfToken + } else { + return nil, fmt.Errorf("environment variable %s is not set", huggingFaceTokenEnvVar) + } + + kvCacheIndexer, err := kvcache.NewKVCacheIndexer(config) + if err != nil { + return nil, fmt.Errorf("failed to create KVCacheIndexer: %w", err) + } + + go kvCacheIndexer.Run(ctx) + + return &KVCacheAwareScorer{ + kvCacheIndexer: kvCacheIndexer, + }, nil +} + +// Name returns the name of the scorer. 
+func (s *KVCacheAwareScorer) Name() string { + return kvCacheAwareScorerName +} + +// Score scores the provided pod based on the KVCache index state. +// The returned scores are normalized to a range of 0-1. +func (s *KVCacheAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { + loggerDebug := log.FromContext(ctx).WithName(kvCacheAwareScorerName).V(logutil.DEBUG) + if ctx.Req == nil { + loggerDebug.Info("Request is nil, skipping scoring") + return nil + } + + scores, err := s.kvCacheIndexer.GetPodScores(ctx.Context, ctx.Req.Prompt, ctx.Req.Model, nil) + if err != nil { + loggerDebug.Error(err, "Failed to get pod scores") + return nil + } + loggerDebug.Info("Got pod scores", "scores", scores) + + return indexerScoresToNormalizedScoredPods(pods, scores) +} + +func getMinMax(scores map[string]int) (int, int) { + minScore := int(^uint(0) >> 1) // max int + maxScore := -1 + + for _, score := range scores { + if score < minScore { + minScore = score + } + if score > maxScore { + maxScore = score + } + } + + return minScore, maxScore +} + +func indexerScoresToNormalizedScoredPods(pods []types.Pod, scores map[string]int) map[types.Pod]float64 { + scoredPods := make(map[types.Pod]float64) + minScore, maxScore := getMinMax(scores) + + for _, pod := range pods { + metricsPod := pod.GetPod() + if metricsPod == nil { + continue + } + + if score, ok := scores[metricsPod.Address]; ok { + if minScore == maxScore { + scoredPods[pod] = 1.0 + continue + } + + scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore) + } else { + scoredPods[pod] = 0.0 + } + } + + return scoredPods +} diff --git a/pkg/epp/scheduling/plugins/scorers/load_based_scorer.go b/pkg/epp/scheduling/plugins/scorer/load_based_scorer.go similarity index 88% rename from pkg/epp/scheduling/plugins/scorers/load_based_scorer.go rename to pkg/epp/scheduling/plugins/scorer/load_based_scorer.go index 5bea87c95..d24f49b33 100644 --- 
a/pkg/epp/scheduling/plugins/scorers/load_based_scorer.go +++ b/pkg/epp/scheduling/plugins/scorer/load_based_scorer.go @@ -13,17 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -package scorers + +package scorer import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) -type LoadBasedScorer struct{} +type LoadAwareScorer struct{} -func (s LoadBasedScorer) Name() string { - return "load based scorer" +func (s *LoadAwareScorer) Name() string { + return "load-aware-scorer" } // Score scores the given pod in range of 0-1 @@ -33,7 +34,7 @@ func (s LoadBasedScorer) Name() string { // Pod with requests in the queue will get score between 0.5 and 0. // Score 0 will get pod with number of requests in the queue equal to the threshold used in load-based filter (QueueingThresholdLoRA) // In future pods with additional capacity will get score higher than 0.5 -func (s LoadBasedScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { +func (s *LoadAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { scoredPods := make(map[types.Pod]float64) for _, pod := range pods { diff --git a/pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer.go b/pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer.go new file mode 100644 index 000000000..8c3d673b0 --- /dev/null +++ b/pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer.go @@ -0,0 +1,134 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scorer + +import ( + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const prefixAwareScorerName = "prefix-aware-scorer" + +// PrefixAwareScorer is a routing scorer that scores pods based on the longest prefix match +// between the request's prompt and stored prefixes. The score is normalized between 0 and 1, +// where 1 represents the longest matching prefix. +type PrefixAwareScorer struct { + prefixStore *PrefixStore +} + +var _ plugins.Scorer = &PrefixAwareScorer{} + +// NewPrefixAwareScorer creates a new PrefixAwareScorer with the given +// PrefixStoreConfig. If the config is nil, default is used. +func NewPrefixAwareScorer(config *PrefixStoreConfig) *PrefixAwareScorer { + return &PrefixAwareScorer{ + prefixStore: NewPrefixStore(config), + } +} + +func (s *PrefixAwareScorer) Name() string { + return "prefix-aware-scorer" +} + +// Score scores the target pods based on the longest prefix match. 
+func (s *PrefixAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { + loggerDebug := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG) + if ctx.Req == nil { + loggerDebug.Info("Request is nil, skipping scoring") + return nil + } + + scores := s.prefixStore.FindMatchingPods(ctx.Req.Prompt, ctx.Req.Model) + loggerDebug.Info("Got pod scores", "scores", scores) + + if len(scores) == 0 { + loggerDebug.Info("No scores found for pods") + return nil + } + + podToKey := func(pod types.Pod) (string, bool) { + if pod.GetPod() == nil { + return "", false + } + + return pod.GetPod().NamespacedName.String(), true + } + + return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) +} + +// PostSchedule implements the PostSchedulePlugin interface. +// It adds the prefix to the PrefixStore for the given pod. +// TODO: switch to PostResponse. +func (s *PrefixAwareScorer) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { + pod := res.TargetPod + + debugLogger := log.FromContext(ctx).WithName(prefixAwareScorerName) + debugLogger.Info("PostResponse called", "req", ctx.Req, "pod", pod) + + if ctx.Req == nil { + debugLogger.Info("Request is nil, skipping PostResponse") + return + } + + if pod.GetPod() == nil { + debugLogger.Info("Pod is nil, skipping PostResponse", "req", ctx.Req, "pod", pod) + return + } + + if err := s.prefixStore.AddEntry(ctx.Req.Model, ctx.Req.Prompt, &pod.GetPod().NamespacedName); err != nil { + debugLogger.Error(err, "Failed to add entry to prefix store", "req", ctx.Req, "pod", pod) + return + } +} + +// GetPrefixStore returns the scorer's PrefixStore. +func (s *PrefixAwareScorer) GetPrefixStore() *PrefixStore { + return s.prefixStore +} + +// podToKey is a function type that converts a Pod to a string key. +// It returns the key and a boolean indicating success. 
+type podToKeyFunc func(pod types.Pod) (string, bool) + +func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, + scores map[string]int) map[types.Pod]float64 { + scoredPods := make(map[types.Pod]float64) + minScore, maxScore := getMinMax(scores) + + for _, pod := range pods { + key, ok := podToKey(pod) + if !ok { + continue + } + + if score, ok := scores[key]; ok { + if minScore == maxScore { + scoredPods[pod] = 1.0 + continue + } + + scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore) + } else { + scoredPods[pod] = 0.0 + } + } + + return scoredPods +} diff --git a/pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer_test.go b/pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer_test.go new file mode 100644 index 000000000..49318fa47 --- /dev/null +++ b/pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer_test.go @@ -0,0 +1,156 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scorer_test + +import ( + "context" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "testing" +) + +func TestPrefixAwareScorer(t *testing.T) { + ctx := context.Background() + logger := log.FromContext(ctx) + ctx = log.IntoContext(ctx, logger) + + // Create test pods + pod1 := &types.PodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{ + Name: "pod1", + Namespace: "default", + }, + }, + Metrics: &backendmetrics.Metrics{}, + } + pod2 := &types.PodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{ + Name: "pod2", + Namespace: "default", + }, + }, + Metrics: &backendmetrics.Metrics{}, + } + + tests := []struct { + name string + weight float64 + prompt string + modelName string + prefixToAdd string + podToAdd k8stypes.NamespacedName + prefixModel string // Model name to use when adding the prefix + expectedScores map[types.Pod]float64 + }{ + { + name: "no prompt", + weight: 1.0, + prompt: "", + modelName: "model1", + prefixToAdd: "hello", + podToAdd: pod1.Pod.NamespacedName, + prefixModel: "model1", + expectedScores: map[types.Pod]float64{}, // No prompt means zero scores + }, + { + name: "exact prefix match", + weight: 1.0, + prompt: "hello world", + modelName: "model1", + prefixToAdd: "hello", + podToAdd: pod1.Pod.NamespacedName, + prefixModel: "model1", + expectedScores: map[types.Pod]float64{ + pod1: 1.0, + pod2: 0.0, + }, // pod1 matches, pod2 doesn't + }, + { + name: "no prefix match", + weight: 1.0, + prompt: "goodbye", + modelName: "model1", + prefixToAdd: "hello", + podToAdd: pod1.Pod.NamespacedName, + prefixModel: "model1", + expectedScores: map[types.Pod]float64{}, // No matching prefix + }, + { + name: 
"different model name", + weight: 1.0, + prompt: "hello world", + modelName: "model2", // Try to find with model2 + prefixToAdd: "hello", + podToAdd: pod1.Pod.NamespacedName, + prefixModel: "model1", // But prefix was added with model1 + expectedScores: map[types.Pod]float64{}, // Model name mismatch should result in no match + }, + { + name: "custom weight", + weight: 0.5, + prompt: "hello world", + modelName: "model1", + prefixToAdd: "hello", + podToAdd: pod1.Pod.NamespacedName, + prefixModel: "model1", + expectedScores: map[types.Pod]float64{ + pod1: 0.5, // Pod1 matches with weight + pod2: 0.0, // Pod2 doesn't match + }, // Weight affects score + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset prefix store for each test + config := scorer.DefaultPrefixStoreConfig() + config.BlockSize = 5 // set small chunking for testing + + s := scorer.NewPrefixAwareScorer(config) + + // Add prefix if specified + if tt.prefixToAdd != "" { + err := s.GetPrefixStore().AddEntry(tt.prefixModel, + tt.prefixToAdd, &tt.podToAdd) + if err != nil { + t.Fatalf("Failed to add prefix: %v", err) + } + } + + // Create test context + sCtx := types.NewSchedulingContext(ctx, &types.LLMRequest{ + Prompt: tt.prompt, + ResolvedTargetModel: tt.modelName, + }, []types.Pod{}, 0) + + // Score pods + pods := []types.Pod{pod1, pod2} + scores := s.Score(sCtx, pods) + + for p, score := range scores { + if score != tt.expectedScores[p] { + t.Errorf("Pod %v: expected score %v, got %v", p, tt.expectedScores[p], score) + } + } + }) + } +} diff --git a/pkg/epp/scheduling/plugins/scorer/prefix_store.go b/pkg/epp/scheduling/plugins/scorer/prefix_store.go new file mode 100644 index 000000000..8c6961647 --- /dev/null +++ b/pkg/epp/scheduling/plugins/scorer/prefix_store.go @@ -0,0 +1,181 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scorer + +import ( + "fmt" + "k8s.io/apimachinery/pkg/types" + "sync" + "time" + + "github.com/cespare/xxhash/v2" + lru "github.com/hashicorp/golang-lru/v2" +) + +const ( + // defaultMaxCacheSize sets the maximum number of blocks the LRU cache can store. + defaultMaxCacheSize = 500000 + // defaultBlockSize defines how many runes each block contains in the prefix cache. + defaultBlockSize = 256 + // defaultMaxBlockCacheSize sets the maximum number of pods a block can store. + defaultMaxBlockCacheSize = 100 +) + +// PrefixStoreConfig contains initialization configuration for PrefixStore. +type PrefixStoreConfig struct { + // CacheSize sets the maximum number of blocks the LRU cache can store. + CacheSize int + // BlockSize defines how many runes each block contains in the prefix cache. + BlockSize int + // BlockCacheSize sets the maximum number of pods a block can store. + BlockCacheSize int +} + +// DefaultPrefixStoreConfig returns an PrefixStoreConfig instance with default +// configuration. +func DefaultPrefixStoreConfig() *PrefixStoreConfig { + return &PrefixStoreConfig{ + CacheSize: defaultMaxCacheSize, + BlockSize: defaultBlockSize, + BlockCacheSize: defaultMaxBlockCacheSize, + } +} + +// block holds the tokens contained in the block. +type block struct { + Pods *lru.Cache[types.NamespacedName, time.Time] //TODO: implement Pod eviction based on staleness +} + +// PrefixStore is an in-memory prefix-to-block cache with xxhash keys and LRU +// eviction. 
+type PrefixStore struct {
+	sync.RWMutex
+
+	cacheSize      int
+	blockSize      int
+	blockCacheSize int
+
+	store map[string]*lru.Cache[uint64, *block]
+}
+
+// NewPrefixStore initializes the PrefixStore with LRU cache.
+// If the configuration is nil, default is used.
+func NewPrefixStore(config *PrefixStoreConfig) *PrefixStore {
+	if config == nil {
+		config = DefaultPrefixStoreConfig()
+	}
+
+	return &PrefixStore{
+		cacheSize:      config.CacheSize,
+		blockSize:      config.BlockSize,
+		blockCacheSize: config.BlockCacheSize,
+		store:          make(map[string]*lru.Cache[uint64, *block]),
+	}
+}
+
+// AddEntry adds a new entry to the prefix store.
+func (s *PrefixStore) AddEntry(modelName string, prompt string, pod *types.NamespacedName) error {
+	if prompt == "" || pod == nil || len(prompt) < s.blockSize /* skip if prompt is too short */ {
+		return nil
+	}
+
+	s.Lock()
+	// Get or create the LRU cache for the model
+	cache, ok := s.store[modelName]
+	if !ok {
+		var err error
+		cache, err = lru.New[uint64, *block](s.cacheSize)
+		if err != nil {
+			s.Unlock() // BUGFIX: must release the lock on this error path, otherwise the store deadlocks
+			return fmt.Errorf("failed to create LRU cache for model %s: %w", modelName, err)
+		}
+		s.store[modelName] = cache
+	}
+	s.Unlock()
+
+	// Chunk the text into blocks and populate the cache
+	for start := 0; start < len(prompt); start += s.blockSize {
+		end := start + s.blockSize
+		if end > len(prompt) {
+			break // skip partial blocks
+		}
+
+		// Compute the hash for the current block
+		digest := xxhash.New()
+		if _, err := digest.WriteString(prompt[start:end]); err != nil {
+			return fmt.Errorf("failed to compute chunk hash: %w", err)
+		}
+
+		blockHash := digest.Sum64()
+
+		b, ok := cache.Get(blockHash)
+		if !ok {
+			pods, err := lru.New[types.NamespacedName, time.Time](s.blockCacheSize)
+			if err != nil {
+				return fmt.Errorf("failed to create LRU cache for block: %w", err)
+			}
+
+			b = &block{Pods: pods}
+			cache.Add(blockHash, b)
+		}
+
+		b.Pods.Add(*pod, time.Now()) // thread-safe
+	}
+
+	return nil
+}
+
+// FindMatchingPods finds all pods that match
the given prompt and model name. +// It returns a map of pods and the number of blocks they match. +func (s *PrefixStore) FindMatchingPods(prompt, modelName string) map[string]int { + if prompt == "" || modelName == "" || len(prompt) < s.blockSize /* skip if prompt is too short */ { + return nil + } + + s.RLock() + cache, ok := s.store[modelName] // cache is thread-safe + s.RUnlock() + + if !ok { + return nil + } + + matchedPods := make(map[string]int) + for start := 0; start < len(prompt); start += s.blockSize { + end := start + s.blockSize + if end > len(prompt) { + end = len(prompt) + } + + digest := xxhash.New() + if _, err := digest.WriteString(prompt[start:end]); err != nil { + return nil + } + blockHash := digest.Sum64() + + b, ok := cache.Get(blockHash) + if !ok { + break // match consecutive blocks + } + + for _, pod := range b.Pods.Keys() { + matchedPods[pod.String()]++ + } + } + + return matchedPods +} diff --git a/pkg/epp/scheduling/plugins/scorer/prefix_store_test.go b/pkg/epp/scheduling/plugins/scorer/prefix_store_test.go new file mode 100644 index 000000000..c0765b845 --- /dev/null +++ b/pkg/epp/scheduling/plugins/scorer/prefix_store_test.go @@ -0,0 +1,59 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scorer_test + +import ( + "context" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" + "testing" +) + +// TestBasicPrefixOperations tests the basic functionality of adding and finding prefixes +func TestBasicPrefixOperations(t *testing.T) { + ctx := context.Background() + logger := log.FromContext(ctx) + ctx = log.IntoContext(ctx, logger) + + config := scorer.DefaultPrefixStoreConfig() + config.BlockSize = 5 // set small chunking for testing + store := scorer.NewPrefixStore(config) + + podName := k8stypes.NamespacedName{ + Name: "pod1", + Namespace: "default", + } + + // Test adding a prefix + err := store.AddEntry("model1", "hello", &podName) + if err != nil { + t.Errorf("Failed to add prefix: %v", err) + } + + // Test finding the exact prefix + scores := store.FindMatchingPods("hello", "model1") + if _, ok := scores[podName.String()]; !ok { + t.Errorf("Expected pod %v, scores %v", podName, scores) + } + + // Test finding with a longer prefix + scores = store.FindMatchingPods("hello world", "model1") + if _, ok := scores[podName.String()]; !ok { + t.Errorf("Expected pod %v, scores %v", podName, scores) + } +} diff --git a/pkg/epp/scheduling/plugins/scorer/session-affinity-scorer.go b/pkg/epp/scheduling/plugins/scorer/session-affinity-scorer.go new file mode 100644 index 000000000..2431b95a2 --- /dev/null +++ b/pkg/epp/scheduling/plugins/scorer/session-affinity-scorer.go @@ -0,0 +1,79 @@ +/* +Copyright 2025 The Kubernetes Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package scorer
+
+import (
+	"encoding/base64"
+	"time"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+const (
+	sessionKeepAliveTime           = 60 * time.Minute  // How long should an idle session be kept alive
+	sessionKeepAliveCheckFrequency = 15 * time.Minute  // How often to check for overly idle sessions
+	sessionTokenHeader             = "x-session-token" // name of the session header in request
+)
+
+// sessionAffinity is a routing scorer that routes subsequent
+// requests in a session to the same pod as the first request in the
+// session was sent to, by giving that pod the specified weight and assigning
+// zero score to the rest of the targets
+type SessionAffinity struct {
+}
+
+func NewSessionAffinity() *SessionAffinity {
+	return &SessionAffinity{}
+}
+
+func (s *SessionAffinity) Name() string {
+	return "session affinity scorer"
+}
+
+func (s *SessionAffinity) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
+	scoredPods := make(map[types.Pod]float64)
+
+	reqHeaders := ctx.Req.Headers
+
+	var sessionToken = ""
+	v, ok := reqHeaders[sessionTokenHeader]
+	if ok {
+		sessionToken = v
+	}
+
+	podName := ""
+	if sessionToken != "" {
+		decodedBytes, err := base64.StdEncoding.DecodeString(sessionToken)
+		if err != nil {
+			ctx.Logger.Error(err, "Error decoding")
+		} else {
+			podName = string(decodedBytes)
+		}
+	}
+	for _, pod := range pods {
+		// BUGFIX: the session pod gets 1.0 and every other pod gets an explicit 0.0,
+		// matching the documented contract ("zero score to the rest of the targets").
+		if podName != "" && pod.GetPod().NamespacedName.String() == podName {
+			scoredPods[pod] = 1.0
+		} else {
+			scoredPods[pod] = 0.0
+		}
+	}
+
+	return scoredPods
+}
+
+func (s
*SessionAffinity) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { + ctx.MutatedHeaders[sessionTokenHeader] = base64.StdEncoding.EncodeToString([]byte(pod.GetPod().NamespacedName.String())) +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 9bad61316..b56d20ca7 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -80,6 +81,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched scorers: config.scorers, picker: config.picker, postSchedulePlugins: config.postSchedulePlugins, + postResponsePlugins: config.postResponsePlugins, } } @@ -90,6 +92,7 @@ type Scheduler struct { scorers map[plugins.Scorer]int // map from scorer to its weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } type Datastore interface { @@ -97,20 +100,32 @@ type Datastore interface { PodGetAll() []backendmetrics.PodMetrics } +func createSchedulerContext(ctx context.Context, req *types.LLMRequest, datastore Datastore) (*types.SchedulingContext, error) { + pool, err := datastore.PoolGet() + if err != nil { + return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + } + + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. + return types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(datastore.PodGetAll()), pool.Spec.TargetPortNumber), nil +} + // Schedule finds the target pod based on metrics and the requested lora adapter. 
func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("request", req) loggerDebug := logger.V(logutil.DEBUG) - pool, err := s.datastore.PoolGet() + sCtx, err := createSchedulerContext(ctx, req, s.datastore) if err != nil { - return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + return nil, err } - // Snapshot pod metrics from the datastore to: - // 1. Reduce concurrent access to the datastore. - // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber) + return s.scheduleWithContext(ctx, sCtx, req, loggerDebug) +} + +func (s *Scheduler) scheduleWithContext(ctx context.Context, sCtx *types.SchedulingContext, req *types.LLMRequest, loggerDebug logr.Logger) (*types.Result, error) { loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot)) s.runPreSchedulePlugins(sCtx) @@ -210,6 +225,38 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty } } +func (s *Scheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { + logger := log.FromContext(ctx) + + pool, err := s.datastore.PoolGet() + if err != nil { + return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + } + + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. 
+ pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll()) + var targetPod types.Pod + for _, pod := range pods { + if pod.GetPod().NamespacedName.String() == targetPodName { + targetPod = pod + break + } + } + + sCtx := types.NewSchedulingContext(ctx, req, pods, pool.Spec.TargetPortNumber) + + for _, plugin := range s.postResponsePlugins { + logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostResponse(sCtx, targetPod) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before)) + } + + return &types.Result{TargetPod: nil, MutatedHeaders: sCtx.MutatedHeaders}, nil +} + type defaultPlugin struct { picker.RandomPicker } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index e6d229aee..eafa8d681 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -483,6 +483,56 @@ func TestSchedulePlugins(t *testing.T) { } } +func TestPostResponse(t *testing.T) { + pr1 := &testPostResponse{ + NameRes: "pr1", + ExtraHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + ReceivedResponseHeaders: make(map[string]string), + } + + tests := []struct { + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + responseHeaders map[string]string + wantMutatedHeaders map[string]string + }{ + { + name: "Simple postResponse test", + config: SchedulerConfig{ + postResponsePlugins: []plugins.PostResponse{pr1}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + }, + responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"}, + wantMutatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + }, + } + + for _, test := range tests { + scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) + + req := 
&types.LLMRequest{ + Model: "test-model", + Headers: test.responseHeaders, + } + + result, err := scheduler.RunPostResponsePlugins(context.Background(), req, test.input[0].Pod.NamespacedName.String()) + if err != nil { + t.Errorf("Received an error. Error: %s", err) + } + + if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" { + t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff) + } + + if diff := cmp.Diff(test.wantMutatedHeaders, result.MutatedHeaders); diff != "" { + t.Errorf("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v", diff) + } + } +} + type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -571,6 +621,23 @@ func (tp *TestPlugin) reset() { tp.NumOfPickerCandidates = 0 } +type testPostResponse struct { + NameRes string + ReceivedResponseHeaders map[string]string + ExtraHeaders map[string]string +} + +func (pr *testPostResponse) Name() string { return pr.NameRes } + +func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { + for key, value := range ctx.Req.Headers { + pr.ReceivedResponseHeaders[key] = value + } + for key, value := range pr.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } +} + func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod { res := []types.Pod{} for _, pod := range ctx.PodsSnapshot { diff --git a/pkg/epp/scheduling/scorers_test.go b/pkg/epp/scheduling/scorers_test.go index 365b2375b..640143bf1 100644 --- a/pkg/epp/scheduling/scorers_test.go +++ b/pkg/epp/scheduling/scorers_test.go @@ -25,7 +25,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorers" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -40,7 +40,7 @@ func TestScorers(t *testing.T) { }{ { name: "load based scorer", - scorer: &scorers.LoadBasedScorer{}, + scorer: &scorer.LoadAwareScorer{}, req: &types.LLMRequest{ Model: "critical", ResolvedTargetModel: "critical", @@ -86,19 +86,23 @@ func TestScorers(t *testing.T) { }, }, wantRes: &types.Result{ - TargetPod: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + TargetPod: &types.ScoredPod{ + Pod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, + Score: 0.5, }, + MutatedHeaders: map[string]string{}, }, }, } diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index e5896dbc8..d46b9d063 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -27,10 +27,10 @@ import ( // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. type LLMRequest struct { - Model string + Model string + Prompt string // Target models is a map of target model name to weight. TargetModels map[string]int - Prompt string Headers map[string]string // Resolved target model is the final target model after traffic split. 
ResolvedTargetModel string diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 0c0a6a6dc..9b8ea4177 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -137,7 +137,14 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { } else { srv = grpc.NewServer() } - extProcServer := handlers.NewStreamingServer(scheduling.NewScheduler(r.Datastore), r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) + + var scheduler handlers.Scheduler + if scheduling.PDEnabled { + scheduler = scheduling.NewPDScheduler(r.Datastore) + } else { + scheduler = scheduling.NewScheduler(r.Datastore) + } + extProcServer := handlers.NewStreamingServer(scheduler, r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) extProcPb.RegisterExternalProcessorServer( srv, extProcServer, diff --git a/pkg/epp/server/runserver_test.go b/pkg/epp/server/runserver_test.go index b02688c58..0cb52d6d2 100644 --- a/pkg/epp/server/runserver_test.go +++ b/pkg/epp/server/runserver_test.go @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - package server_test import ( @@ -25,6 +24,9 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +// Define a variable with the manager package type to explicitly show usage to linter +var _ manager.LeaderElectionRunnable = nil + func TestRunnable(t *testing.T) { // Make sure AsRunnable() does not use leader election. runner := server.NewDefaultExtProcServerRunner().AsRunnable(logutil.NewTestLogger()) diff --git a/scripts/kubernetes-dev-env.sh b/scripts/kubernetes-dev-env.sh index 21564e9cc..e9d92c174 100755 --- a/scripts/kubernetes-dev-env.sh +++ b/scripts/kubernetes-dev-env.sh @@ -65,10 +65,10 @@ case "${VLLM_MODE}" in export LORA_ADAPTER_SYNCER_TAG="${LORA_ADAPTER_SYNCER_TAG:-v20250425-ddc3d69}" elif [[ "$VLLM_MODE" == "vllm-p2p" ]]; then - export VLLM_IMAGE="${VLLM_IMAGE:-lmcache/vllm-openai}" - export VLLM_TAG="${VLLM_TAG:-2025-03-10}" - export EPP_IMAGE="${EPP_IMAGE:- quay.io/vmaroon/gateway-api-inference-extension/epp}" - export EPP_TAG="${EPP_TAG:-kv-aware}" + export VLLM_IMAGE="${VLLM_IMAGE:-quay.io/llm-d/llm-d-dev}" + export VLLM_TAG="${VLLM_TAG:-lmcache-0.0.6-amd64}" + export EPP_IMAGE="${EPP_IMAGE:-quay.io/llm-d/llm-d-gateway-api-inference-extension-dev}" + export EPP_TAG="${EPP_TAG:-0.0.5-amd64}" export MAX_MODEL_LEN="${MAX_MODEL_LEN:-32768}" export PVC_NAME="${PVC_NAME:-vllm-p2p-storage-claim}" export PVC_ACCESS_MODE="${PVC_ACCESS_MODE:-ReadWriteOnce}"