Skip to content

DeepSeek V3/R1 inference recipe on Trillium with JetStream + MaxText + Pathways #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 2, 2025
533 changes: 533 additions & 0 deletions inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive

# Install dependencies
RUN apt -y update && apt install -y --no-install-recommends \
apt-transport-https ca-certificates gnupg git wget \
python3.10 python3-pip curl nano vim

RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

# Install google cloud sdk
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
| tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \
&& curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \
| gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg \
&& apt-get update -y \
&& apt-get install google-cloud-sdk -y

# Install pip
RUN python3 -m pip install --upgrade pip

RUN pip install "huggingface_hub[cli]" hf_transfer

# Set environment variables
ENV JAX_PLATFORMS=proxy
ENV JAX_BACKEND_TARGET=grpc://localhost:38681
ENV XCLOUD_ENVIRONMENT=GCP

# Install JetStream and MaxText

RUN git clone https://github.com/AI-Hypercomputer/JetStream.git && \
git clone https://github.com/AI-Hypercomputer/maxtext.git && \
git clone https://github.com/google/aqt.git

RUN cd /maxtext && bash setup.sh && pip install torch --index-url https://download.pytorch.org/whl/cpu

RUN pip install safetensors setuptools fastapi uvicorn rouge_score scikit-learn

RUN cd /JetStream && pip install -e .

RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential
RUN cp -r /aqt/aqt/* /usr/local/lib/python3.10/dist-packages/aqt/

ENTRYPOINT [ "/bin/bash" ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

steps:
- name: 'gcr.io/cloud-builders/docker'
args:
- 'build'
- '--tag=${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}'
- '--file=Dockerfile'
- '.'
automapSubstitutions: true

images:
- ${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

taskGroups:
- taskSpec:
runnables:
- container:
imageUri: ${ARTIFACT_REGISTRY}/${JETSTREAM_MAXTEXT_IMAGE}:${JETSTREAM_MAXTEXT_VERSION}
entrypoint: "/bin/sh"
commands:
- "-c"
- mkdir -p /mnt/disks/persist/models/ && echo "Downloading model ${HF_MODEL_NAME}" && huggingface-cli download ${HF_MODEL_NAME} --local-dir /mnt/disks/persist/models/fp8 && cd /maxtext && echo "Converting checkpoint from fp8 to bf16" && python3 -m MaxText.deepseek_fp8_to_bf16 --input-fp8-hf-path /mnt/disks/persist/models/fp8 --output-bf16-hf-path /mnt/disks/persist/models/bf16 --cache-file-num 16 && echo "Converting checkpoint from bf16 to maxtext/unscanned format" && JAX_PLATFORMS='' python3 -m MaxText.convert_deepseek_unscanned_ckpt --base_model_path /mnt/disks/persist/models/bf16 --maxtext_model_path ${GCS_CKPT_PATH_UNSCANNED} --model_size $MODEL_NAME --use-zarr3 false --use-ocdbt false && echo "Completed checkpoint conversion. Unscanned checkpoint saved at ${GCS_CKPT_PATH_UNSCANNED}"
volumes:
- deviceName: persist
mountPath: /mnt/disks/persist
mountOptions: rw,async
computeResource:
cpuMilli: 160000
memoryMib: 3936256
# Define the allocation policy for provisioning VMs
allocationPolicy:
location:
allowedLocations: ["regions/${CLUSTER_CKPT_NODE_REGION}"]
instances:
- policy:
machineType: ${CLUSTER_CKPT_NODE_MACHINE_TYPE}
bootDisk:
type: pd-ssd
sizeGb: 500
disks:
newDisk:
sizeGb: 3000
type: pd-ssd
deviceName: persist
logsPolicy:
destination: CLOUD_LOGGING
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: v2
name: trillium-pathways-jetstream-maxtext-serve-model
description: trillium-pathways-jetstream-maxtext-serve-model
type: application
version: 0.1.0
appVersion: "1.16.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: v1
kind: ConfigMap
metadata:
name: "{{ .Release.Name }}"
data:
maxtext-configuration.yaml: |-
{{- range $key, $value := .Values.maxtext_config }}
{{ $key }}: {{ $value }}
{{- end }}
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

{{- $root := . }}

apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
name: {{ .Release.Name }}
annotations:
leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
replicas: 1
leaderWorkerTemplate:
leaderTemplate:
metadata:
labels:
role: leader
app: {{ .Release.Name }}
spec:
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 8x8
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
volumes:
- name: shared-memory
emptyDir:
medium: "Memory"
sizeLimit: 250Gi
- name: local-ssd
hostPath:
path: /mnt/stateful_partition/kube-ephemeral-ssd
- name: workload-configuration
configMap:
name: "{{.Release.Name}}"
containers:
- name: pathways-proxy
image: "{{ .Values.job.pathways_proxy_image.repository }}:{{ .Values.job.pathways_proxy_image.tag }}"
args:
- --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
- --server_port=38681
{{- with (index .Values.volumes.gcsMounts 0) }}
- --gcs_scratch_location=gs://{{ .bucketName }}/tmp
{{- end }}
imagePullPolicy: Always
ports:
- containerPort: 38681

- name: pathways-rm
env:
- name: HOST_ADDRESS
value: "$(LWS_LEADER_ADDRESS)"
- name: TPU_SKIP_MDS_QUERY
value: "true"
image: "{{ .Values.job.pathways_rm_image.repository }}:{{ .Values.job.pathways_rm_image.tag }}"
args:
- --server_port=38677
- --node_type=resource_manager
- --instance_count=1
- --instance_type=tpuv6e:8x8
{{- with (index .Values.volumes.gcsMounts 0) }}
- --gcs_scratch_location=gs://{{ .bucketName }}/tmp
{{- end }}
imagePullPolicy: Always
ports:
- containerPort: 38677

- name: jax-tpu
image: "{{ .Values.job.jax_tpu_image.repository }}:{{ .Values.job.jax_tpu_image.tag }}"
imagePullPolicy: Always
env:
- name: ENABLE_PATHWAYS_PERSISTENCE
value: "1"
- name: HF_TOKEN
valueFrom:
secretKeyRef:
name: "{{ .Values.huggingface.secretName }}"
key: "{{ .Values.huggingface.secretData.token }}"
workingDir: /workspace
command: ["/bin/bash", "-c"]
args:
- |
set -eux
# Parse server configurations from values file
echo "MaxText configuration file:"
sed 's/^/| /' /etc/workload-configuration/maxtext-configuration.yaml
echo ""

OPTIONS=()
while IFS= read -r line || [[ -n "$line" ]]; do
# Skip empty lines and comments
[[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && continue

key=$(echo "$line" | cut -d':' -f1 | tr -d '[:space:]')
value=$(echo "$line" | cut -d':' -f2- | sed 's/^[[:space:]]*//')

# Handle environment variable expansion
if [[ "$value" == \$* ]]; then
var_name=${value#\$}

if [[ -z "$var_name" ]]; then
expanded_value="$"
else
expanded_value="${!var_name:-$value}"
fi

OPTIONS+=("$key=$expanded_value")
else
OPTIONS+=("$key=$value")
fi
done < /etc/workload-configuration/maxtext-configuration.yaml

echo "===== MaxText Configuration ====="
echo "${OPTIONS[@]}"

cd /maxtext
python3 -m MaxText.maxengine_server \
/maxtext/MaxText/configs/base.yml \
"${OPTIONS[@]}"

ports:
- containerPort: {{ .Values.jetstream.service.ports.grpc }}
startupProbe:
httpGet:
path: /healthcheck
port: {{ .Values.jetstream.service.ports.http }}
scheme: HTTP
periodSeconds: 1
initialDelaySeconds: 600
failureThreshold: 10000
livenessProbe:
httpGet:
path: /healthcheck
port: {{ .Values.jetstream.service.ports.http }}
scheme: HTTP
periodSeconds: 60
failureThreshold: 10
readinessProbe:
httpGet:
path: /healthcheck
port: {{ .Values.jetstream.service.ports.http }}
scheme: HTTP
periodSeconds: 60
failureThreshold: 10
volumeMounts:
- name: shared-memory
mountPath: /dev/shm
- name: workload-configuration
mountPath: /etc/workload-configuration
- name: local-ssd
mountPath: {{ .Values.volumes.ssdMountPath }}

- name: jetstream-http
image: "{{ .Values.job.jetstream_http_image.repository }}:{{ .Values.job.jetstream_http_image.tag }}"
imagePullPolicy: Always
ports:
- containerPort: 8000
size: 17

workerTemplate:
spec:
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 8x8
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
containers:
- name: worker
args:
- --server_port=38679
- --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
{{- with (index .Values.volumes.gcsMounts 0) }}
- --gcs_scratch_location=gs://{{ .bucketName }}/tmp
{{- end }}
image: "{{ .Values.job.pathways_rm_image.repository }}:{{ .Values.job.pathways_rm_image.tag }}"
imagePullPolicy: Always
ports:
- containerPort: 38679
resources:
limits:
google.com/tpu: "4"
Loading