
Commit c340786

AnilAltinay authored and gvisor-bot committed

Docker tests for Triton+TensorRT

Presubmit test: https://buildkite.com/gvisor/pipeline/builds/38505/steps/canvas
PiperOrigin-RevId: 822667497

1 parent feddee9 · commit c340786

File tree

12 files changed: +1112 -2 lines changed


Makefile

Lines changed: 2 additions & 1 deletion

@@ -313,12 +313,13 @@ cos-gpu-smoke-tests: gpu-smoke-images $(RUNTIME_BIN)
 gpu-images: gpu-smoke-images load-gpu_pytorch load-gpu_ollama load-gpu_ollama_client load-basic_busybox load-basic_alpine load-basic_python load-gpu_stable-diffusion-xl load-gpu_vllm load-gpu_nccl-tests load-benchmarks_ffmpeg
 .PHONY: gpu-images
 
-l4-gpu-images: load-gpu_sglang load-gpu_sglang_client
+l4-gpu-images: load-gpu_sglang load-gpu_sglang_client load-gpu_triton load-gpu_triton_client
 .PHONY: l4-gpu-images
 
 l4-gpu-tests: l4-gpu-images $(RUNTIME_BIN)
 	@$(call install_runtime,$(RUNTIME),--nvproxy=true --nvproxy-docker=true --nvproxy-allowed-driver-capabilities=all)
 	@$(call sudo,test/gpu:sglang_test,--runtime=$(RUNTIME) -test.v $(ARGS))
+	@$(call sudo,test/gpu:triton_test,--runtime=$(RUNTIME) -test.v $(ARGS))
 .PHONY: l4-gpu-tests
 
 gpu-all-tests: gpu-images gpu-smoke-tests $(RUNTIME_BIN)
Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
# --- Downloader Stage ---
# Fetches model/tokenizer assets from GCS.
FROM google/cloud-sdk:541.0.0-slim AS downloader
RUN gcloud config set auth/disable_credentials true
RUN gsutil -m cp -r gs://gvisor/tests/models/llama-2-7b-chat-hf /
RUN mkdir -p /engines
RUN gsutil -m cp -r gs://gvisor/tests/l4/engines/llama-2-7b-chat-hf /engines/

# --- Builder Stage for TensorRT-LLM ---
# This stage uses 'git sparse-checkout' to download *only* the
# files we need, which is much faster than 'git clone' and avoids svn.
FROM nvcr.io/nvidia/tritonserver:25.08-trtllm-python-py3 AS trtllm_builder

WORKDIR /

# 1. Clone an empty "blob-less" repo. This is very fast.
RUN git clone --filter=blob:none --no-checkout --depth 1 \
    https://github.com/NVIDIA/TensorRT-LLM.git /TensorRT-LLM

WORKDIR /TensorRT-LLM

# 2. Set up sparse checkout to define *only* the paths we need.
RUN git sparse-checkout init --cone && \
    git sparse-checkout set \
    "triton_backend/all_models/inflight_batcher_llm/" \
    "triton_backend/tools/"

# 3. Check out the commit pinned to the v1.2.0rc1 tag.
# This downloads *only* the files in the two directories above.
RUN git checkout 796891ba2a6959bad58c0da9645416c7264349e9

# --- Final Stage ---
# This is our final runtime image.
# No changes are needed here: the COPY commands work as expected
# because the builder stage created the identical paths.
FROM nvcr.io/nvidia/tritonserver:25.08-trtllm-python-py3

# --- Build Arguments ---
ARG TOKENIZER_DIR=/llama-2-7b-chat-hf
ARG ENGINE_DIR=/engines/llama-2-7b-chat-hf/fp8/1-gpu
ARG MAX_BATCH_SIZE=1
ARG INSTANCE_COUNT=1
ARG TOKENIZER_TYPE=auto
ARG DECOUPLED_MODE=true
ARG MODEL_FOLDER=/models/
ARG MAX_QUEUE_DELAY_MS=10000
ARG TRITON_BACKEND=tensorrtllm
ARG LOGITS_DATATYPE="TYPE_FP32"
ARG FILL_TEMPLATE_SCRIPT=/TensorRT-LLM/triton_backend/tools/fill_template.py

# --- Asset Copying ---

# Copy only the tokenizer (needed for config).
COPY --from=downloader ${TOKENIZER_DIR} ${TOKENIZER_DIR}

# Copy *only* the model templates from the trtllm_builder stage.
COPY --from=trtllm_builder /TensorRT-LLM/triton_backend/all_models/inflight_batcher_llm ${MODEL_FOLDER}

# Copy *only* the build script we need, then update the ARG to its new path.
COPY --from=trtllm_builder ${FILL_TEMPLATE_SCRIPT} /usr/local/bin/fill_template.py
ARG FILL_TEMPLATE_SCRIPT=/usr/local/bin/fill_template.py

# Copy *only* the specific engine directory we need, directly
# from the downloader into the final model repository path.
COPY --from=downloader ${ENGINE_DIR} ${MODEL_FOLDER}/tensorrt_llm/1/

# --- Model Configuration ---
# Run the template-filling commands.
RUN python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt \
    tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${MAX_BATCH_SIZE},preprocessing_instance_count:${INSTANCE_COUNT} && \
    python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt \
    tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${MAX_BATCH_SIZE},postprocessing_instance_count:${INSTANCE_COUNT} && \
    python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt \
    prompt_embedding_table_data_type:TYPE_FP16,triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT},logits_datatype:${LOGITS_DATATYPE} && \
    python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt \
    triton_max_batch_size:${MAX_BATCH_SIZE},logits_datatype:${LOGITS_DATATYPE} && \
    python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt \
    prompt_embedding_table_data_type:TYPE_FP16,triton_backend:${TRITON_BACKEND},triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},engine_dir:${MODEL_FOLDER}/tensorrt_llm/1,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},batching_strategy:inflight_fused_batching,encoder_input_features_data_type:TYPE_FP16,logits_datatype:${LOGITS_DATATYPE}

CMD ["tritonserver", "--model-repository=/models/"]

images/gpu/triton/client/BUILD

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
load("//tools:defs.bzl", "go_binary")

package(
    default_applicable_licenses = ["//:license"],
    licenses = ["notice"],
)

go_binary(
    name = "client",
    srcs = ["client.go"],
)
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
FROM golang:1.22 AS builder

COPY client.go /client.go
RUN CGO_ENABLED=0 go build -o /httpclient /client.go

FROM alpine:latest
COPY --from=builder /httpclient /usr/bin/
CMD ["/usr/bin/httpclient"]

images/gpu/triton/client/client.go

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// A simple `curl`-like HTTP client that prints metrics after the request.
// All of its output is structured to be unambiguous even if stdout/stderr
// is combined, as is the case for Kubernetes logs.
// Useful for communicating with SGLang.
package main

import (
	"bufio"
	"bytes"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"
)

// LINT.IfChange

// Flags.
var (
	url            = flag.String("url", "", "HTTP request URL.")
	method         = flag.String("method", "GET", "HTTP request method (GET or POST).")
	postDataBase64 = flag.String("post_base64", "", "HTTP request POST data in base64 format; ignored for GET requests.")
	timeout        = flag.Duration("timeout", 0, "HTTP request timeout; 0 for no timeout.")
)

// bufSize is the size of buffers used for HTTP requests and responses.
const bufSize = 1024 * 1024 // 1MiB

// fatalf crashes the program with a given error message.
func fatalf(format string, values ...any) {
	fmt.Fprintf(os.Stderr, "FATAL: "+format+"\n", values...)
	os.Exit(1)
}

// Metrics contains the request metrics to export to JSON.
// This is parsed by the sglang library at `test/gpu/sglang/sglang.go`.
type Metrics struct {
	// ProgramStarted is the time when the program started.
	ProgramStarted time.Time `json:"program_started"`
	// RequestSent is the time when the HTTP request was sent.
	RequestSent time.Time `json:"request_sent"`
	// ResponseReceived is the time when the HTTP response headers were received.
	ResponseReceived time.Time `json:"response_received"`
	// FirstByteRead is the time when the first HTTP response body byte was read.
	FirstByteRead time.Time `json:"first_byte_read"`
	// LastByteRead is the time when the last HTTP response body byte was read.
	LastByteRead time.Time `json:"last_byte_read"`
}

func main() {
	var metrics Metrics
	metrics.ProgramStarted = time.Now()
	flag.Parse()
	if *url == "" {
		fatalf("--url is required")
	}
	client := http.Client{
		Transport: &http.Transport{
			MaxIdleConns:    1,
			IdleConnTimeout: *timeout,
			ReadBufferSize:  bufSize,
			WriteBufferSize: bufSize,
		},
		Timeout: *timeout,
	}
	var request *http.Request
	var err error
	switch *method {
	case "GET":
		request, err = http.NewRequest("GET", *url, nil)
	case "POST":
		postData, postDataErr := base64.StdEncoding.DecodeString(*postDataBase64)
		if postDataErr != nil {
			fatalf("cannot decode POST data: %v", postDataErr)
		}
		request, err = http.NewRequest("POST", *url, bytes.NewBuffer(postData))
	default:
		err = fmt.Errorf("unknown method %q", *method)
	}
	if err != nil {
		fatalf("cannot create request: %v", err)
	}
	orderedReqHeaders := make([]string, 0, len(request.Header))
	for k := range request.Header {
		orderedReqHeaders = append(orderedReqHeaders, k)
	}
	sort.Strings(orderedReqHeaders)
	for _, k := range orderedReqHeaders {
		for _, v := range request.Header[k] {
			fmt.Fprintf(os.Stderr, "REQHEADER: %s: %s\n", k, v)
		}
	}
	metrics.RequestSent = time.Now()
	resp, err := client.Do(request)
	metrics.ResponseReceived = time.Now()
	if err != nil {
		fatalf("cannot make request: %v", err)
	}
	gotFirstByte := false
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if !gotFirstByte {
			metrics.FirstByteRead = time.Now()
			gotFirstByte = true
		}
		if scanner.Text() == "" {
			continue
		}
		fmt.Printf("BODY: %q\n", strings.TrimPrefix(scanner.Text(), "data: "))
	}
	// Check for any errors that may have occurred during scanning.
	if err := scanner.Err(); err != nil {
		fatalf("error reading response body: %v", err)
	}
	metrics.LastByteRead = time.Now()
	if err := resp.Body.Close(); err != nil {
		fatalf("cannot close response body: %v", err)
	}
	orderedRespHeaders := make([]string, 0, len(resp.Header))
	for k := range resp.Header {
		orderedRespHeaders = append(orderedRespHeaders, k)
	}
	sort.Strings(orderedRespHeaders)
	for _, k := range orderedRespHeaders {
		for _, v := range resp.Header[k] {
			fmt.Fprintf(os.Stderr, "RESPHEADER: %s: %s\n", k, v)
		}
	}
	metricsBytes, err := json.Marshal(&metrics)
	if err != nil {
		fatalf("cannot marshal metrics: %v", err)
	}
	fmt.Fprintf(os.Stderr, "STATS: %s\n", string(metricsBytes))
}

// LINT.ThenChange(../../ollama/client/client.go)
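
The STATS line is the machine-readable part of the client's output. Below is a minimal sketch, not part of this commit, of recovering the Metrics struct from combined container logs and deriving a latency from it; per the comment above, the real parsing lives in `test/gpu/sglang/sglang.go`, and the parseStats helper here is hypothetical.

// parsestats.go: parse the client's "STATS: " line back into Metrics.
// Illustrative only; the actual test library may differ.
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
	"time"
)

// Metrics mirrors the struct emitted by client.go above.
type Metrics struct {
	ProgramStarted   time.Time `json:"program_started"`
	RequestSent      time.Time `json:"request_sent"`
	ResponseReceived time.Time `json:"response_received"`
	FirstByteRead    time.Time `json:"first_byte_read"`
	LastByteRead     time.Time `json:"last_byte_read"`
}

// parseStats scans combined stdout/stderr logs for the single STATS line.
func parseStats(logs string) (*Metrics, error) {
	scanner := bufio.NewScanner(strings.NewReader(logs))
	for scanner.Scan() {
		line := scanner.Text()
		if rest, ok := strings.CutPrefix(line, "STATS: "); ok {
			var m Metrics
			if err := json.Unmarshal([]byte(rest), &m); err != nil {
				return nil, fmt.Errorf("bad STATS line %q: %w", line, err)
			}
			return &m, nil
		}
	}
	return nil, fmt.Errorf("no STATS line found")
}

func main() {
	logs := `BODY: "hello"
STATS: {"program_started":"2024-01-01T00:00:00Z","request_sent":"2024-01-01T00:00:01Z","response_received":"2024-01-01T00:00:02Z","first_byte_read":"2024-01-01T00:00:02.5Z","last_byte_read":"2024-01-01T00:00:03Z"}`
	m, err := parseStats(logs)
	if err != nil {
		panic(err)
	}
	fmt.Println("time to first byte:", m.FirstByteRead.Sub(m.RequestSent))
}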
Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
# Use the official NVIDIA CUDA image as the base.
FROM nvidia/cuda:12.8.1-devel-ubuntu22.04

# Set the default shell to bash.
SHELL ["/bin/bash", "-c"]

# Consolidate system dependency installation into a single RUN command
# to reduce the number of layers in the final image.
RUN apt-get update && apt-get install -y \
    neovim \
    git \
    openmpi-bin \
    libopenmpi-dev \
    python3.10 \
    python3.10-dev \
    python3-pip \
    python3-venv \
    python-is-python3 && \
    # Clean up the apt cache to reduce image size.
    rm -rf /var/lib/apt/lists/*

# Download TensorRT-LLM at the specified version tag.
ARG TENSORRT_LLM_VERSION="1.0.0"
ARG TENSORRT_LLM_DIR="/TensorRT-LLM-${TENSORRT_LLM_VERSION}"
RUN git clone --depth 1 --branch "v${TENSORRT_LLM_VERSION}" https://github.com/NVIDIA/TensorRT-LLM.git "${TENSORRT_LLM_DIR}"

# Create a Python virtual environment and add its bin directory to the system's PATH.
# This makes commands from the venv (like pip, huggingface-cli) available in all subsequent layers.
ENV VENV_PATH="/opt/venv"
RUN python3 -m venv "${VENV_PATH}"
ENV PATH="${VENV_PATH}/bin:${PATH}"

# Upgrade pip and install the huggingface_hub library.
RUN pip install --upgrade pip
RUN pip install huggingface_hub

# Download the model from Hugging Face.
# The HF_TOKEN should be passed as a build argument for security.
ARG HF_TOKEN=""
ARG REPO_ID="meta-llama/Llama-2-7b-chat-hf"
ARG MODEL_DIR="/llama-2-7b-chat-hf"
RUN huggingface-cli download \
    "${REPO_ID}" \
    --local-dir "${MODEL_DIR}" \
    --local-dir-use-symlinks False \
    --token "${HF_TOKEN}"

# Set the working directory to the Llama example within the TensorRT-LLM repository.
WORKDIR "${TENSORRT_LLM_DIR}/examples/models/core/llama"

# Install the Python dependencies required for the Llama example.
# This command will use the pip from the virtual environment we added to the PATH.
RUN pip install -r requirements.txt
0 commit comments

Comments
 (0)