diff --git a/Makefile b/Makefile index 7f613ea689..6c1fb44a25 100644 --- a/Makefile +++ b/Makefile @@ -313,12 +313,13 @@ cos-gpu-smoke-tests: gpu-smoke-images $(RUNTIME_BIN) gpu-images: gpu-smoke-images load-gpu_pytorch load-gpu_ollama load-gpu_ollama_client load-basic_busybox load-basic_alpine load-basic_python load-gpu_stable-diffusion-xl load-gpu_vllm load-gpu_nccl-tests load-benchmarks_ffmpeg .PHONY: gpu-images -l4-gpu-images: load-gpu_sglang load-gpu_sglang_client +l4-gpu-images: load-gpu_sglang load-gpu_sglang_client load-gpu_triton load-gpu_triton_client .PHONY: l4-gpu-images l4-gpu-tests: l4-gpu-images $(RUNTIME_BIN) @$(call install_runtime,$(RUNTIME),--nvproxy=true --nvproxy-docker=true --nvproxy-allowed-driver-capabilities=all) @$(call sudo,test/gpu:sglang_test,--runtime=$(RUNTIME) -test.v $(ARGS)) + @$(call sudo,test/gpu:triton_test,--runtime=$(RUNTIME) -test.v $(ARGS)) .PHONY: l4-gpu-tests gpu-all-tests: gpu-images gpu-smoke-tests $(RUNTIME_BIN) diff --git a/images/gpu/triton/Dockerfile.x86_64 b/images/gpu/triton/Dockerfile.x86_64 new file mode 100644 index 0000000000..ae3d578396 --- /dev/null +++ b/images/gpu/triton/Dockerfile.x86_64 @@ -0,0 +1,79 @@ +# --- Downloader Stage --- +# Fetches model/tokenizer assets from GCS +FROM google/cloud-sdk:541.0.0-slim AS downloader +RUN gcloud config set auth/disable_credentials true +RUN gsutil -m cp -r gs://gvisor/tests/models/llama-2-7b-chat-hf / +RUN mkdir -p /engines +RUN gsutil -m cp -r gs://gvisor/tests/l4/engines/llama-2-7b-chat-hf /engines/ + +# --- Builder Stage for TensorRT-LLM --- +# This stage uses 'git sparse-checkout' to download *only* the +# files we need, which is much faster than 'git clone' and avoids svn. +FROM nvcr.io/nvidia/tritonserver:25.08-trtllm-python-py3 AS trtllm_builder + +WORKDIR / + +# 1. Clone an empty "blob-less" repo. This is very fast. +RUN git clone --filter=blob:none --no-checkout --depth 1 \ + https://github.com/NVIDIA/TensorRT-LLM.git /TensorRT-LLM +WORKDIR /TensorRT-LLM + +# 2. Set up sparse checkout to define *only* the paths we need +RUN git sparse-checkout init --cone && \ + git sparse-checkout set \ + "triton_backend/all_models/inflight_batcher_llm/" \ + "triton_backend/tools/" + +# 3. Now, check out the v1.2.0rc1 tag. +# This will download *only* the files in the two directories above. +RUN git checkout 796891ba2a6959bad58c0da9645416c7264349e9 + +# --- Final Stage --- +# This is our final runtime image. +# NO CHANGES are needed here. The COPY commands work perfectly +# because the builder stage created the identical paths. +FROM nvcr.io/nvidia/tritonserver:25.08-trtllm-python-py3 + +# --- Build Arguments --- +ARG TOKENIZER_DIR=/llama-2-7b-chat-hf +ARG ENGINE_DIR=/engines/llama-2-7b-chat-hf/fp8/1-gpu +ARG MAX_BATCH_SIZE=1 +ARG INSTANCE_COUNT=1 +ARG TOKENIZER_TYPE=auto +ARG DECOUPLED_MODE=true +ARG MODEL_FOLDER=/models/ +ARG MAX_QUEUE_DELAY_MS=10000 +ARG TRITON_BACKEND=tensorrtllm +ARG LOGITS_DATATYPE="TYPE_FP32" +ARG FILL_TEMPLATE_SCRIPT=/TensorRT-LLM/triton_backend/tools/fill_template.py + +# --- Asset Copying --- + +# Copy only the tokenizer (needed for config) +COPY --from=downloader ${TOKENIZER_DIR} ${TOKENIZER_DIR} + +# Copy *only* the model templates from the trtllm_builder stage +COPY --from=trtllm_builder /TensorRT-LLM/triton_backend/all_models/inflight_batcher_llm ${MODEL_FOLDER} + +# Copy *only* the build script we need from the trtllm_builder stage +COPY --from=trtllm_builder ${FILL_TEMPLATE_SCRIPT} /usr/local/bin/fill_template.py +ARG FILL_TEMPLATE_SCRIPT=/usr/local/bin/fill_template.py # Update ARG to new path + +# Copy *only* the specific engine directory we need, directly +# from the downloader into the final model repository path. +COPY --from=downloader ${ENGINE_DIR} ${MODEL_FOLDER}/tensorrt_llm/1/ + +# --- Model Configuration --- +# Run the template-filling commands and clean up the script +RUN python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt \ + tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${MAX_BATCH_SIZE},preprocessing_instance_count:${INSTANCE_COUNT} && \ + python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt \ + tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${MAX_BATCH_SIZE},postprocessing_instance_count:${INSTANCE_COUNT} && \ + python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt \ + prompt_embedding_table_data_type:TYPE_FP16,triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT},logits_datatype:${LOGITS_DATATYPE} && \ + python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt \ + triton_max_batch_size:${MAX_BATCH_SIZE},logits_datatype:${LOGITS_DATATYPE} && \ + python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt \ + prompt_embedding_table_data_type:TYPE_FP16,triton_backend:${TRITON_BACKEND},triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},engine_dir:${MODEL_FOLDER}/tensorrt_llm/1,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},batching_strategy:inflight_fused_batching,encoder_input_features_data_type:TYPE_FP16,logits_datatype:${LOGITS_DATATYPE} + +CMD ["tritonserver", "--model-repository=/models/"] \ No newline at end of file diff --git a/images/gpu/triton/client/BUILD b/images/gpu/triton/client/BUILD new file mode 100644 index 0000000000..c322daa431 --- /dev/null +++ b/images/gpu/triton/client/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_binary") + +package( + default_applicable_licenses = ["//:license"], + licenses = ["notice"], +) + +go_binary( + name = "client", + srcs = ["client.go"], +) diff --git a/images/gpu/triton/client/Dockerfile b/images/gpu/triton/client/Dockerfile new file mode 100644 index 0000000000..e4429a295f --- /dev/null +++ b/images/gpu/triton/client/Dockerfile @@ -0,0 +1,8 @@ +FROM golang:1.22 AS builder + +COPY client.go /client.go +RUN CGO_ENABLED=0 go build -o /httpclient /client.go + +FROM alpine:latest +COPY --from=builder /httpclient /usr/bin/ +CMD ["/usr/bin/httpclient"] \ No newline at end of file diff --git a/images/gpu/triton/client/client.go b/images/gpu/triton/client/client.go new file mode 100644 index 0000000000..5a71674f77 --- /dev/null +++ b/images/gpu/triton/client/client.go @@ -0,0 +1,155 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A simple `curl`-like HTTP client that prints metrics after the request. +// All of its output is structured to be unambiguous even if stdout/stderr +// is combined, as is the case for Kubernetes logs. +// Useful for communicating with SGLang. +package main + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "net/http" + "os" + "sort" + "strings" + "time" +) + +// LINT.IfChange + +// Flags. +var ( + url = flag.String("url", "", "HTTP request URL.") + method = flag.String("method", "GET", "HTTP request method (GET or POST).") + postDataBase64 = flag.String("post_base64", "", "HTTP request POST data in base64 format; ignored for GET requests.") + timeout = flag.Duration("timeout", 0, "HTTP request timeout; 0 for no timeout.") +) + +// bufSize is the size of buffers used for HTTP requests and responses. +const bufSize = 1024 * 1024 // 1MiB + +// fatalf crashes the program with a given error message. +func fatalf(format string, values ...any) { + fmt.Fprintf(os.Stderr, "FATAL: "+format+"\n", values...) + os.Exit(1) +} + +// Metrics contains the request metrics to export to JSON. +// This is parsed by the sglang library at `test/gpu/sglang/sglang.go`. +type Metrics struct { + // ProgramStarted is the time when the program started. + ProgramStarted time.Time `json:"program_started"` + // RequestSent is the time when the HTTP request was sent. + RequestSent time.Time `json:"request_sent"` + // ResponseReceived is the time when the HTTP response headers were received. + ResponseReceived time.Time `json:"response_received"` + // FirstByteRead is the time when the first HTTP response body byte was read. + FirstByteRead time.Time `json:"first_byte_read"` + // LastByteRead is the time when the last HTTP response body byte was read. + LastByteRead time.Time `json:"last_byte_read"` +} + +func main() { + var metrics Metrics + metrics.ProgramStarted = time.Now() + flag.Parse() + if *url == "" { + fatalf("--url is required") + } + client := http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 1, + IdleConnTimeout: *timeout, + ReadBufferSize: bufSize, + WriteBufferSize: bufSize, + }, + Timeout: *timeout, + } + var request *http.Request + var err error + switch *method { + case "GET": + request, err = http.NewRequest("GET", *url, nil) + case "POST": + postData, postDataErr := base64.StdEncoding.DecodeString(*postDataBase64) + if postDataErr != nil { + fatalf("cannot decode POST data: %v", postDataErr) + } + request, err = http.NewRequest("POST", *url, bytes.NewBuffer(postData)) + default: + err = fmt.Errorf("unknown method %q", *method) + } + if err != nil { + fatalf("cannot create request: %v", err) + } + orderedReqHeaders := make([]string, 0, len(request.Header)) + for k := range request.Header { + orderedReqHeaders = append(orderedReqHeaders, k) + } + sort.Strings(orderedReqHeaders) + for _, k := range orderedReqHeaders { + for _, v := range request.Header[k] { + fmt.Fprintf(os.Stderr, "REQHEADER: %s: %s\n", k, v) + } + } + metrics.RequestSent = time.Now() + resp, err := client.Do(request) + metrics.ResponseReceived = time.Now() + if err != nil { + fatalf("cannot make request: %v", err) + } + gotFirstByte := false + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + if !gotFirstByte { + metrics.FirstByteRead = time.Now() + gotFirstByte = true + } + if scanner.Text() == "" { + continue + } + fmt.Printf("BODY: %q\n", strings.TrimPrefix(scanner.Text(), "data: ")) + } + // Check for any errors that may have occurred during scanning + if err := scanner.Err(); err != nil { + fatalf("error reading response body: %v", err) + } + metrics.LastByteRead = time.Now() + if err := resp.Body.Close(); err != nil { + fatalf("cannot close response body: %v", err) + } + orderedRespHeaders := make([]string, 0, len(resp.Header)) + for k := range resp.Header { + orderedRespHeaders = append(orderedRespHeaders, k) + } + sort.Strings(orderedRespHeaders) + for _, k := range orderedRespHeaders { + for _, v := range resp.Header[k] { + fmt.Fprintf(os.Stderr, "RESPHEADER: %s: %s\n", k, v) + } + } + metricsBytes, err := json.Marshal(&metrics) + if err != nil { + fatalf("cannot marshal metrics: %v", err) + } + fmt.Fprintf(os.Stderr, "STATS: %s\n", string(metricsBytes)) +} + +// LINT.ThenChange(../../ollama/client/client.go) diff --git a/images/gpu/triton/tensorrt/Dockerfile.llama-2-7b-chat-hf b/images/gpu/triton/tensorrt/Dockerfile.llama-2-7b-chat-hf new file mode 100644 index 0000000000..01b8090218 --- /dev/null +++ b/images/gpu/triton/tensorrt/Dockerfile.llama-2-7b-chat-hf @@ -0,0 +1,53 @@ +# Use the official NVIDIA CUDA image as the base. +FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 + +# Set the default shell to bash. +SHELL ["/bin/bash", "-c"] + +# Consolidate system dependency installation into a single RUN command +# to reduce the number of layers in the final image. +RUN apt-get update && apt-get install -y \ + neovim \ + git \ + openmpi-bin \ + libopenmpi-dev \ + python3.10 \ + python3.10-dev \ + python3-pip \ + python3-venv \ + python-is-python3 && \ + # Clean up the apt cache to reduce image size. + rm -rf /var/lib/apt/lists/* + +# Download TensorRT-LLM from the specified version tag. +ARG TENSORRT_LLM_VERSION="1.0.0" +ARG TENSORRT_LLM_DIR="/TensorRT-LLM-${TENSORRT_LLM_VERSION}" +RUN git clone --depth 1 --branch "v${TENSORRT_LLM_VERSION}" https://github.com/NVIDIA/TensorRT-LLM.git "${TENSORRT_LLM_DIR}" + +# Create a Python virtual environment and add its bin directory to the system's PATH. +# This makes commands from the venv (like pip, huggingface-cli) available in all subsequent layers. +ENV VENV_PATH="/opt/venv" +RUN python3 -m venv "${VENV_PATH}" +ENV PATH="${VENV_PATH}/bin:${PATH}" + +# Upgrade pip and install the huggingface_hub library. +RUN pip install --upgrade pip +RUN pip install huggingface_hub + +# Download the model from Hugging Face. +# The HF_TOKEN should be passed as a build argument for security. +ARG HF_TOKEN="" +ARG REPO_ID="meta-llama/Llama-2-7b-chat-hf" +ARG MODEL_DIR="/llama-2-7b-chat-hf" +RUN huggingface-cli download \ + "${REPO_ID}" \ + --local-dir "${MODEL_DIR}" \ + --local-dir-use-symlinks False \ + --token "${HF_TOKEN}" + +# Set the working directory to the Llama example within the TensorRT-LLM repository. +WORKDIR "${TENSORRT_LLM_DIR}/examples/models/core/llama" + +# Install the Python dependencies required for the Llama example. +# This command will use the pip from the virtual environment we added to the PATH. +RUN pip install -r requirements.txt \ No newline at end of file diff --git a/images/gpu/triton/tensorrt/README.md b/images/gpu/triton/tensorrt/README.md new file mode 100644 index 0000000000..57f89f2f06 --- /dev/null +++ b/images/gpu/triton/tensorrt/README.md @@ -0,0 +1,94 @@ +# Building TensorRT engine for Llama2-7B-Chat-HF model + +This guide provides instructions for building TensorRT engine files for the +Llama2-7B-Chat-HF model. + +## 1. Create a Google Cloud VM with GPU + +First, create a Google Cloud VM with the necessary accelerator. + +Set the following environment variables for your project: + +```bash +export IMAGE="common-cu128-ubuntu-2204-nvidia-570-v20251009" +export ZONE="us-central1-a" +export INSTANCE_NAME="model-prep" +export MACHINE_TYPE="g2-standard-32" +export ACCELERATOR="type=nvidia-l4,count=1" +export PROJECT="" +``` + +Then, create the instance using the following command: + +```bash +gcloud compute instances create $INSTANCE_NAME \ + --zone=$ZONE \ + --image=$IMAGE \ + --machine-type=$MACHINE_TYPE \ + --image-project=deeplearning-platform-release \ + --maintenance-policy=TERMINATE \ + --accelerator=$ACCELERATOR \ + --metadata="install-nvidia-driver=True" \ + --boot-disk-size=4TB \ + --boot-disk-type=pd-ssd \ + --boot-disk-device-name=boot-disk \ + --no-shielded-secure-boot \ + --project=$PROJECT +``` + +## 2. Install Dependencies + +Connect to the newly created VM and install the required packages: + +```bash +sudo apt-get update +sudo apt-get install -y ca-certificates curl gnupg neovim python3-dev +``` + +## 3. Install Docker + +Install Docker on the VM: + +Restart Docker and add your user to the `docker` group to run Docker commands +without `sudo`: + +```bash +sudo systemctl restart docker +sudo usermod -a -G docker $USER +newgrp docker +``` + +## 4. Build and Run Docker Container + +In your VM, create `Dockerfile.llama2-7b-chat-hf` with the content of +`images/gpu/triton/tensorrt/Dockerfile.llama-2-7b-chat-hf`. Don't forget to +provide your own HF token. + +Build the Docker image: + +```bash +docker build . -f Dockerfile.llama2-7b-chat-hf -t tensorrt:llama2-7b-chat-hf +``` + +Run the Docker container: + +```bash +docker run --rm -it --net host --shm-size=25g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -p 8000:8000 tensorrt:llama2-7b-chat-hf bash +``` + +## 5. Convert Model and Build TensorRT Engine + +Inside the container, run the following commands to generate the model +checkpoint, quantize it, and build the TensorRT engine. + +Convert the checkpoint: `bash python convert_checkpoint.py --model_dir +/llama-2-7b-chat-hf \ --output_dir /tllm_checkpoint_1gpu_tp1 \ --dtype float16 \ +--tp_size 1` + +Quantize the model: `bash python3 ../../../quantization/quantize.py +--dtype=float16 --output_dir /tllm_checkpoint_1gpu_tp1 --model_dir +/llama-2-7b-chat-hf --qformat=fp8 --kv_cache_dtype=fp8 --tp_size 1` + +Build the TensorRT engine: `bash trtllm-build --checkpoint_dir +/tllm_checkpoint_1gpu_tp1 \ --output_dir /engines/llama-2-7b-chat-hf/fp8/1-gpu/ +\ --gemm_plugin auto \ --max_batch_size 1` diff --git a/test/gpu/BUILD b/test/gpu/BUILD index 9dbb7c04d0..9a3b2fcdc9 100644 --- a/test/gpu/BUILD +++ b/test/gpu/BUILD @@ -175,6 +175,23 @@ go_test( deps = ["//pkg/test/dockerutil"], ) +go_test( + name = "triton_test", + srcs = ["triton_test.go"], + # runsc is needed to invalidate the bazel cache in case of any code changes. + data = ["//runsc"], + tags = [ + "manual", + "noguitar", + "notap", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/gpu/triton", + ], +) + proto_library( name = "gpu_driver_versions", testonly = True, diff --git a/test/gpu/triton/BUILD b/test/gpu/triton/BUILD new file mode 100644 index 0000000000..2ea2bdd796 --- /dev/null +++ b/test/gpu/triton/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +package( + default_applicable_licenses = ["//:license"], + licenses = ["notice"], +) + +go_library( + name = "triton", + testonly = 1, + srcs = ["triton.go"], + stateify = False, # Does not support some generics methods. + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) diff --git a/test/gpu/triton/triton.go b/test/gpu/triton/triton.go new file mode 100644 index 0000000000..4a04c994d1 --- /dev/null +++ b/test/gpu/triton/triton.go @@ -0,0 +1,575 @@ +// Copyright 2023 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package triton provides a Triton API client. +package triton + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +const ( + // Port is the port used by the triton server. + Port = 8000 + + // curtQuery is a query that should result in a very curt response. + curtQuery = `Reply with the single word: "Hello". Do not reply with any other word.` +) + +// Triton is a triton client. +type Triton struct { + // server is used to perform requests against the server. + server Server + + // logger is used to log. + logger testutil.Logger +} + +// Server performs requests against a triton server. +type Server interface { + // InstrumentedRequest performs an instrumented HTTP request against the + // triton server, using the `gpu/triton_client` triton image. + // `argvFn` takes in a `protocol://host:port` string and returns a + // command-line to use for making an instrumented HTTP request against the + // triton server. + // InstrumentedRequest should return the logs from the request container. + InstrumentedRequest(ctx context.Context, argvFn func(hostPort string) []string) ([]byte, error) + + // Logs retrieves logs from the server. + Logs(ctx context.Context) (string, error) +} + +// New starts a new Triton server in the given container, +// then waits for it to serve and returns the client. +func New(ctx context.Context, server Server, logger testutil.Logger) (*Triton, error) { + started := time.Now() + llm := &Triton{ + logger: logger, + server: server, + } + + // Wait until serving. + if err := llm.WaitUntilServing(ctx); err != nil { + return nil, fmt.Errorf("triton did not come up for serving: %w", err) + } + + logger.Logf("Triton serving API requests after %v", time.Since(started)) + + // Run a warmup query to force the model to load. + _, err := llm.WarmModel(ctx) + if err != nil { + return nil, fmt.Errorf("could not warmup the model: %w", err) + } + logger.Logf("Loaded triton model. (%v since container start)", time.Since(started)) + + logger.Logf("Triton successfully initialized in a total of %v", time.Since(started)) + return llm, nil +} + +// ModelLoadStats holds metrics about the model loading process. +type ModelLoadStats struct { + // ClientReportedDuration is the duration to load the model as perceived + // by the client, measured by HTTP client metrics. + ClientReportedDuration time.Duration +} + +// WarmModel pre-warms a model in memory and keeps it warm for `keepWarmFor`. +// If `unloadFirst` is true, another model will be loaded before loading the +// requested model. This ensures that the model was loaded from a cold state. +func (llm *Triton) WarmModel(ctx context.Context) (*ModelLoadStats, error) { + prompt := ZeroTemperaturePrompt(curtQuery, 10) + resp, err := llm.Prompt(ctx, prompt) + if err != nil { + return nil, llm.withServerLogsErr(ctx, fmt.Errorf("warmup prompt (%s) failed: %w", prompt.TextInput, err)) + } + return &ModelLoadStats{ + ClientReportedDuration: resp.metrics.TimeToFirstByte(), + }, nil +} + +// dockerServer implements `Server`. It interfaces with a triton server +// running in a local Docker container. +type dockerServer struct { + container *dockerutil.Container + logger testutil.Logger +} + +// NewDocker returns a new Triton client talking to a Triton server that runs +// in a local Docker container. +func NewDocker(ctx context.Context, cont *dockerutil.Container, logger testutil.Logger) (*Triton, error) { + opts, err := dockerutil.GPURunOpts(dockerutil.SniffGPUOpts{}) + if err != nil { + return nil, fmt.Errorf("failed to get GPU run options: %w", err) + } + opts.Image = "gpu/triton" + started := time.Now() + if err := cont.Spawn(ctx, opts); err != nil { + return nil, fmt.Errorf("could not start triton: %v", err) + } + logger.Logf("Triton container started after %v", time.Since(started)) + ds := &dockerServer{ + container: cont, + logger: logger, + } + return New(ctx, ds, logger) +} + +// InstrumentedRequest implements `Server.InstrumentedRequest`. +func (ds *dockerServer) InstrumentedRequest(ctx context.Context, argvFn func(hostPort string) []string) ([]byte, error) { + const tritonHost = "llm" + cmd := argvFn(fmt.Sprintf("http://%s:%d", tritonHost, Port)) + out, err := dockerutil.MakeContainer(ctx, ds.logger).Run(ctx, dockerutil.RunOpts{ + Image: "gpu/triton/client", + Links: []string{ds.container.MakeLink(tritonHost)}, + }, cmd...) + if err != nil { + if out != "" { + return []byte(out), fmt.Errorf("command %q failed (%w): %v", strings.Join(cmd, " "), err, out) + } + return nil, fmt.Errorf("could not run command %q: %w", strings.Join(cmd, " "), err) + } + return []byte(out), nil +} + +// Logs implements `Server.Logs`. +func (ds *dockerServer) Logs(ctx context.Context) (string, error) { + return ds.container.Logs(ctx) +} + +// ResponseMetrics are HTTP request metrics from a triton API query. +// These is the same JSON struct as defined in +// `images/gpu/triton/client/client.go`. +type ResponseMetrics struct { + // ProgramStarted is the time when the program started. + ProgramStarted time.Time `json:"program_started"` + // RequestSent is the time when the HTTP request was sent. + RequestSent time.Time `json:"request_sent"` + // ResponseReceived is the time when the HTTP response headers were received. + ResponseReceived time.Time `json:"response_received"` + // FirstByteRead is the time when the first HTTP response body byte was read. + FirstByteRead time.Time `json:"first_byte_read"` + // LastByteRead is the time when the last HTTP response body byte was read. + LastByteRead time.Time `json:"last_byte_read"` +} + +// TimeToFirstByte returns the duration it took between the request being sent +// and the first byte of the response being read. +func (rm *ResponseMetrics) TimeToFirstByte() time.Duration { + return rm.FirstByteRead.Sub(rm.RequestSent) +} + +// TimeToLastByte returns the duration it took between the request being sent +// and the last byte of the response being read. +func (rm *ResponseMetrics) TimeToLastByte() time.Duration { + return rm.LastByteRead.Sub(rm.RequestSent) +} + +// apiResponse represents a JSON response from the triton API. +type apiResponse[T any] struct { + // Objects is the list of JSON objects in the response. + Objects []*T + // Metrics contains HTTP response metrics. + Metrics ResponseMetrics +} + +// Obj returns the first object in the response, if there is a singular +// object in the response. +func (ar *apiResponse[T]) Obj() (*T, error) { + if len(ar.Objects) == 0 { + return nil, fmt.Errorf("no objects in response") + } + if len(ar.Objects) > 1 { + return nil, fmt.Errorf("multiple objects in response") + } + return ar.Objects[0], nil +} + +// makeAPIResponse decodes a raw response from an instrumented HTTP request +// into an `apiResponse` with deserialized JSON objects. +func makeAPIResponse[T any](rawResponse []byte) (*apiResponse[T], error) { + var respBytes bytes.Buffer + var resp apiResponse[T] + for _, line := range strings.Split(string(rawResponse), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + colonIndex := strings.Index(line, ":") + if colonIndex == -1 { + return nil, fmt.Errorf("malformed line: %q", line) + } + data := strings.TrimSpace(line[colonIndex+1:]) + switch line[:colonIndex] { + case "FATAL": + return nil, fmt.Errorf("request failed: %s", data) + case "REQHEADER", "RESPHEADER": + // Do nothing with these. + case "BODY": + unquoted, err := strconv.Unquote(data) + if err != nil { + return nil, fmt.Errorf("malformed body line: %q", data) + } + respBytes.WriteString(unquoted) + case "STATS": + if err := json.Unmarshal([]byte(data), &resp.Metrics); err != nil { + return nil, fmt.Errorf("malformed stats line: %q", data) + } + default: + return nil, fmt.Errorf("malformed line: %q", line) + } + } + decoder := json.NewDecoder(&respBytes) + for { + var obj T + err := decoder.Decode(&obj) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("malformed JSON response: %w", err) + } + resp.Objects = append(resp.Objects, &obj) + } + if len(resp.Objects) == 0 { + return nil, fmt.Errorf("response is empty") + } + leftoverBytes, err := io.ReadAll(decoder.Buffered()) + if err != nil && err != io.EOF { + return nil, fmt.Errorf("could not read leftover bytes: %w", err) + } + if leftover := strings.TrimSpace(string(leftoverBytes)); leftover != "" { + return nil, fmt.Errorf("unprocessed bytes in response: %q", leftover) + } + return &resp, nil +} + +// instrumentedRequest makes an HTTP request to the triton API. +// It returns the raw bytestream from the instrumented request logs. +func (llm *Triton) instrumentedRequest(ctx context.Context, method, endpoint string, data []byte) ([]byte, error) { + if endpoint != "" && !strings.HasPrefix(endpoint, "/") { + return nil, fmt.Errorf("endpoint must be empty or start with '/', got %q", endpoint) + } + argvFn := func(hostPort string) []string { + argv := []string{ + "httpclient", + fmt.Sprintf("--method=%s", method), + fmt.Sprintf("--url=%s%s", hostPort, endpoint), + } + if data != nil { + argv = append(argv, fmt.Sprintf("--post_base64=%s", base64.StdEncoding.EncodeToString(data))) + } + if ctxDeadline, hasDeadline := ctx.Deadline(); hasDeadline { + argv = append(argv, fmt.Sprintf("--timeout=%v", time.Until(ctxDeadline))) + } + return argv + } + rawResponse, err := llm.server.InstrumentedRequest(ctx, argvFn) + if err != nil { + return nil, fmt.Errorf("%s: %w", endpoint, err) + } + return rawResponse, nil +} + +// jsonGet performs a JSON HTTP GET request. +func jsonGet[Out any](ctx context.Context, llm *Triton, endpoint string) (*apiResponse[Out], error) { + out, err := llm.instrumentedRequest(ctx, "GET", endpoint, nil) + if err != nil { + return nil, fmt.Errorf("GET %q failed: %w", endpoint, err) + } + return makeAPIResponse[Out](out) +} + +// jsonPost performs a JSON HTTP POST request. +func jsonPost[In, Out any](ctx context.Context, llm *Triton, endpoint string, input In) (*apiResponse[Out], error) { + query, err := json.Marshal(input) + if err != nil { + return nil, fmt.Errorf("could not marshal input %v: %w", input, err) + } + out, err := llm.instrumentedRequest(ctx, "POST", endpoint, query) + if err != nil { + return nil, fmt.Errorf("POST %q %v failed: %w", endpoint, string(query), err) + } + + return makeAPIResponse[Out](out) +} + +// WaitUntilServing waits until triton is serving, or the context expires. +func (llm *Triton) WaitUntilServing(ctx context.Context) error { + for ctx.Err() == nil { + _, err := llm.instrumentedRequest(ctx, "GET", "/v2/health/ready", nil) + if err != nil { + continue + } + return nil + } + return fmt.Errorf("triton did not respond: %w", ctx.Err()) +} + +// temperatureOption is the temperature option that most models have +// which controls how free they are from deviating from their most-likely +// token chain. +const temperatureOption = "temperature" +const streamOption = "stream" + +// RaiseTemperature increases the "temperature" option of the model, +// if any. +func (p *Prompt) RaiseTemperature() { + temp, ok := p.Options[temperatureOption] + if !ok { + temp = float64(0.0) + } + if p.Options == nil { + p.Options = map[string]any{} + } + p.Options[temperatureOption] = min(1.0, temp.(float64)*2+.025) +} + +// Copy returns a copy of the prompt. +func (p *Prompt) Copy() *Prompt { + promptCopy := *p + promptCopy.Options = make(map[string]any, len(p.Options)) + for k, v := range p.Options { + promptCopy.Options[k] = v + } + return &promptCopy +} + +// SetTemperature sets the "temperature" option of the prompt to the given +// value. +func (p *Prompt) SetTemperature(temperature float64) { + if p.Options == nil { + p.Options = map[string]any{} + } + p.Options[temperatureOption] = temperature +} + +// ZeroTemperaturePrompt returns a Prompt with the given text and an initial +// temperature setting of zero. This setting allows for consistent settings. +func ZeroTemperaturePrompt(text string, maxTokens int) *Prompt { + return &Prompt{ + TextInput: text, + MaxTokens: maxTokens, + Options: map[string]any{ + temperatureOption: 0.0, + streamOption: true, + }, + } +} + +// Prompt is a triton prompt. +type Prompt struct { + + // Text is the prompt string. + // Common leading whitespace will be removed. + TextInput string + + // MaxTokens is the maximum number of tokens to generate. + MaxTokens int + + // Options maps parameter names to JSON-compatible values. + Options map[string]any +} + +// CleanQuery removes common whitespace from query lines, and all +// leading/ending whitespace-only lines. +// It is useful to be able to specify query string as indented strings +// without breaking visual continuity in Go code. +// For example (where dots are spaces): +// +// """\n +// ..The Quick Brown Fox\n +// ..Jumps Over\n +// ....The Lazy Dog\n +// .""" +// +// becomes: +// Jumps Over\n +// ..The Lazy Dog""" +func (p *Prompt) CleanQuery() string { + lines := strings.Split(p.TextInput, "\n") + + // Trim lines at the beginning and end that are only whitespace. + trimmedLines := make([]string, 0, len(lines)) + startedNonWhitespace := false + var block []string + for _, line := range lines { + trimmedLine := strings.TrimSpace(line) + if !startedNonWhitespace && trimmedLine != "" { + startedNonWhitespace = true + } + if startedNonWhitespace { + block = append(block, line) + } + if trimmedLine != "" { + trimmedLines = append(trimmedLines, block...) + block = block[:0] + } + } + + // Find longest common whitespace prefix. + if len(trimmedLines) == 0 { + return "" + } + trimmedFirstLine := strings.TrimSpace(trimmedLines[0]) + common := []rune(trimmedLines[0][:strings.Index(trimmedLines[0], trimmedFirstLine)]) + for ; len(common) > 0; common = common[:len(common)-1] { + allMatch := true + for _, line := range trimmedLines[1:] { + if strings.TrimSpace(line) == "" { + continue // Ignore whitespace-only or empty lines. + } + if !strings.HasPrefix(line, string(common)) { + allMatch = false + break + } + } + if allMatch { + break + } + } + + // Remove it. + if len(common) > 0 { + for i, line := range trimmedLines { + trimmedLines[i] = strings.TrimPrefix(line, string(common)) + } + } + + return strings.Join(trimmedLines, "\n") +} + +// WithHotterModel returns a copy of this prompt with the same model having +// a higher temperature. +func (p *Prompt) WithHotterModel() *Prompt { + promptCopy := p.Copy() + promptCopy.RaiseTemperature() + return promptCopy +} + +// Request defines the structure for the JSON payload. +// https://docs.sglang.ai/basic_usage/sampling_params.html +type promptJSON struct { + TextInput string `json:"text_input"` + MaxTokens int `json:"max_tokens"` + Options map[string]any `json:"parameters"` +} + +// json encodes this prompt to the JSON format expected by Triton. +func (p *Prompt) json() promptJSON { + return promptJSON{ + TextInput: p.CleanQuery(), + MaxTokens: p.MaxTokens, + Options: p.Options, + } +} + +// responseJSON is the JSON-format response from triton about a prompt. +// Note that in `streamed` mode, the `Response` field contains a single token. +// To recover the whole response, all `Response` fields must be concatenated +// until the last `responseJSON`, identified as such by the `Done` field. +type responseJSON struct { + Text string `json:"text_output"` +} + +// Response represents a response to a query from Triton. +type Response struct { + data []*responseJSON + metrics ResponseMetrics +} + +// NumTokens returns the number of tokens in the response. +func (r *Response) NumTokens() int { + return len(r.data) +} + +// String returns the response text, if it is done. +func (r *Response) String() string { + if len(r.data) == 0 { + return "" + } + var fullResponse strings.Builder + for _, token := range r.data { + fullResponse.WriteString(token.Text) + } + return fullResponse.String() +} + +// Text returns the body of the response. +func (r *Response) Text() string { + return r.String() +} + +// withServerLogsErr adds server logs to `err` if possible. +func (llm *Triton) withServerLogsErr(ctx context.Context, err error) error { + if err == nil { + return nil + } + if ctx.Err() != nil { + return fmt.Errorf("%w (+ context err: %v)", err, ctx.Err()) + } + serverLogs, logsErr := llm.server.Logs(ctx) + if logsErr != nil { + return fmt.Errorf("%w (could not get server logs: %v)", err, logsErr) + } + if serverLogs != "" { + return fmt.Errorf("%w; triton server logs:\n%v\n(end of triton server logs)", err, serverLogs) + } + return fmt.Errorf("%w (server logs are empty)", err) +} + +// Prompt returns the result of prompting the given `model` with `prompt`. +func (llm *Triton) Prompt(ctx context.Context, prompt *Prompt) (*Response, error) { + resp, err := jsonPost[promptJSON, responseJSON](ctx, llm, "/v2/models/ensemble/generate_stream", prompt.json()) + if err != nil { + return nil, llm.withServerLogsErr(ctx, fmt.Errorf("prompt (%q) request failed: %w", prompt.CleanQuery(), err)) + } + return &Response{data: resp.Objects, metrics: resp.Metrics}, nil +} + +// PromptUntil repeatedly issues a prompt until `iterate` returns a nil error. +// `iterate` may optionally return an updated `Prompt` which will be used to +// follow up. This is useful to work around the flakiness of LLMs in tests. +func (llm *Triton) PromptUntil(ctx context.Context, prompt *Prompt, iterate func(*Prompt, *Response) (*Prompt, error)) (*Response, error) { + var lastResponse *Response + var lastError error + attempts := 0 + for ctx.Err() == nil { + response, err := llm.Prompt(ctx, prompt) + if err != nil { + return nil, fmt.Errorf("prompt request failed: %w", err) + } + attempts++ + newPrompt, err := iterate(prompt, response) + if err == nil { + return response, nil + } + if newPrompt != nil { + prompt = newPrompt + } + lastResponse = response + lastError = err + } + return nil, fmt.Errorf("response %q (attempt #%d with prompt %v) did not match predicate: %v", lastResponse, attempts, prompt, lastError) +} diff --git a/test/gpu/triton_test.go b/test/gpu/triton_test.go new file mode 100644 index 0000000000..d706797dbb --- /dev/null +++ b/test/gpu/triton_test.go @@ -0,0 +1,98 @@ +// Copyright 2023 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package triton_test runs triton and generates some text with it. +package triton_test + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/gpu/triton" +) + +// TestLLM tests an LLM running in a sandboxed container. +// It first asks the capital of Turkey. +// Then it asks it to write a unit test that verifies that +// given text contains the "Hello World" in Chinese. +func TestLLM(t *testing.T) { + ctx := context.Background() + // Run the LLM. + llmContainer := dockerutil.MakeContainer(ctx, t) + defer llmContainer.CleanUp(ctx) + startCtx, startCancel := context.WithTimeout(ctx, 5*time.Minute) + llm, err := triton.NewDocker(startCtx, llmContainer, t) + startCancel() + if err != nil { + t.Fatalf("Failed to start triton: %v", err) + } + + // Query it. + t.Run("knowledge test", func(t *testing.T) { + prompt := triton.ZeroTemperaturePrompt("How many legs do cats have?", 500) + promptCtx, promptCancel := context.WithTimeout(ctx, 5*time.Minute) + response, err := llm.PromptUntil(promptCtx, prompt, func(prompt *triton.Prompt, response *triton.Response) (*triton.Prompt, error) { + defer prompt.RaiseTemperature() + text := strings.TrimSpace(response.Text()) + t.Logf("The response is: %q", text) + for _, acceptableWord := range []string{ + "4", + } { + if strings.Contains(text, acceptableWord) { + return prompt, nil + } + } + return prompt, errors.New("text does not contain any of the expected words") + }) + promptCancel() + if err != nil { + t.Fatalf("request failed: %v", err) + } + answer := strings.TrimSpace(response.Text()) + t.Logf("The response to %q is: %q", prompt.TextInput, answer) + }) + if t.Failed() { + return + } + t.Run("math test", func(t *testing.T) { + prompt := triton.ZeroTemperaturePrompt("What is 9 times 10?", 500) + promptCtx, promptCancel := context.WithTimeout(ctx, 5*time.Minute) + response, err := llm.PromptUntil(promptCtx, prompt, func(prompt *triton.Prompt, response *triton.Response) (*triton.Prompt, error) { + defer prompt.RaiseTemperature() + text := strings.TrimSpace(response.Text()) + t.Logf("The response is: %q", text) + for _, acceptableWord := range []string{ + "90", + } { + if strings.Contains(text, acceptableWord) { + return prompt, nil + } + } + return prompt, errors.New("text does not contain any of the expected words") + }) + promptCancel() + if err != nil { + t.Fatalf("request failed: %v", err) + } + answer := strings.TrimSpace(response.Text()) + t.Logf("The response to %q is: %q", prompt.TextInput, answer) + }) + if t.Failed() { + return + } +} diff --git a/tools/images.mk b/tools/images.mk index edd08f95ab..4caac17ff2 100644 --- a/tools/images.mk +++ b/tools/images.mk @@ -45,7 +45,7 @@ DOCKER_BUILD_ARGS ?= REMOTE_IMAGE_PREFIX ?= us-central1-docker.pkg.dev/gvisor-presubmit/gvisor-presubmit-images LOCAL_IMAGE_PREFIX ?= gvisor.dev/images ALL_IMAGES := $(subst /,_,$(subst images/,,$(shell find images/ -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq))) -NON_TEST_IMAGES := gpu/ollama/bench\|gpu/vllm +NON_TEST_IMAGES := gpu/ollama/bench\|gpu/vllm\|gpu/triton TEST_IMAGES := $(subst /,_,$(subst images/,,$(shell find images/ -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq | grep -v "$(NON_TEST_IMAGES)"))) SUB_IMAGES := $(foreach image,$(ALL_IMAGES),$(if $(findstring _,$(image)),$(image),)) IMAGE_GROUPS := $(sort $(foreach image,$(SUB_IMAGES),$(firstword $(subst _, ,$(image)))))