Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/container/Dockerfile.torchax
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# syntax=docker/dockerfile:1-labs

# Image the first build stage starts FROM.
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
# Source location handed to git-clone.sh, written as <repository URL>#<ref>.
ARG URLREF_TORCHAX=https://github.com/google/torchax.git#main
# Path handed to git-clone.sh alongside the URL and later referenced as the
# editable (-e file://...) pip requirement.
ARG SRC_PATH_TORCHAX=/opt/torchax

###############################################################################
## Download source and add auxiliary scripts
###############################################################################

FROM ${BASE_IMAGE} AS mealkit
# ARGs declared before the first FROM are only visible in FROM lines;
# re-declare them here so the RUN step below can read them.
ARG URLREF_TORCHAX
ARG SRC_PATH_TORCHAX

# Specify installation targets: fetch the torchax sources and record the pip
# requirements for this image. Nothing is installed in this stage.
# The heredoc delimiter is quoted, so the variables are expanded by bash at
# build time (build args are exposed to RUN as environment variables).
RUN <<"EOF" bash -ex
# Quote the expansions so each value reaches git-clone.sh as exactly one argument.
git-clone.sh "${URLREF_TORCHAX}" "${SRC_PATH_TORCHAX}"
echo "-e file://${SRC_PATH_TORCHAX}" >> /opt/pip-tools.d/requirements-torchax.in
echo "torch" >> /opt/pip-tools.d/requirements-torchax.in
echo "torchvision" >> /opt/pip-tools.d/requirements-torchax.in
# tensorflow is only needed to suppress an absl warning when importing torchax
# that tensorflow.io.gfile will not support GCS paths such as gs://...
# comment it out if you want to save ~300 MB in image size
echo "tensorflow" >> /opt/pip-tools.d/requirements-torchax.in
EOF

###############################################################################
## Install accumulated packages from the base image and the previous stage
###############################################################################

FROM mealkit AS final

# Run pip-finalize.sh on top of the mealkit stage.
# NOTE(review): pip-finalize.sh is not defined in this file; presumably it is
# shipped by the base image and installs the requirements accumulated under
# /opt/pip-tools.d -- confirm against the base image.
RUN <<"EOF" bash -ex -o pipefail
# The index settings are scoped to this single command: the PyTorch CPU wheel
# index is the primary index and PyPI is the additional index.
PIP_INDEX_URL=https://download.pytorch.org/whl/cpu \
PIP_EXTRA_INDEX_URL=https://pypi.org/simple \
pip-finalize.sh
EOF
5 changes: 5 additions & 0 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,8 @@ tunix:
tracking_ref: main
latest_verified_commit: d799a45d48027e27b6a08aaf7cb15e6a4f495c01
mode: git-clone
# Manifest entry for the torchax source. The url and tracking_ref match the
# default URLREF_TORCHAX (https://github.com/google/torchax.git#main) declared
# in Dockerfile.torchax.
torchax:
url: https://github.com/google/torchax.git
tracking_ref: main
# Commit recorded as the latest verified one on the tracked ref.
latest_verified_commit: f41e3de8526f9d4e8410bfb84660faaaf0b3ba4a
mode: git-clone
28 changes: 28 additions & 0 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,34 @@ jobs:
EXTRA_BUILD_ARGS: |
URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }}

# Builds the TorchAX container from .github/container/Dockerfile.torchax.
# Runs after build-jax and uses that job's mealkit image tag as the base image.
build-torchax:
needs: build-jax
runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"]
# Re-export the image tags produced by the build step for downstream jobs.
outputs:
DOCKER_TAG_MEALKIT: ${{ steps.build-torchax.outputs.DOCKER_TAG_MEALKIT }}
DOCKER_TAG_FINAL: ${{ steps.build-torchax.outputs.DOCKER_TAG_FINAL }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Build TorchAX container
id: build-torchax
uses: ./.github/actions/build-container
with:
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
ARTIFACT_NAME: artifact-torchax-build
BADGE_FILENAME: badge-torchax-build
BUILD_DATE: ${{ inputs.BUILD_DATE }}
# Supplies the Dockerfile's BASE_IMAGE build arg from the JAX mealkit build.
BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
CONTAINER_NAME: torchax
DOCKERFILE: .github/container/Dockerfile.torchax
RUNNER_SIZE: small
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }}
github-token: ${{ secrets.GITHUB_TOKEN }}
bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }}
# Overrides the Dockerfile's URLREF_TORCHAX default with the TORCHAX entry
# of the SOURCE_URLREFS JSON input.
EXTRA_BUILD_ARGS: |
URLREF_TORCHAX=${{ fromJson(inputs.SOURCE_URLREFS).TORCHAX }}

build-upstream-t5x:
needs: build-jax
runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"]
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ on:
type: string
description: |
A comma-separated PACKAGE=URL#REF list to override sources used by build.
PACKAGE∊{JAX,XLA,Flax,transformer-engine,airio,axlearn,equinox,T5X,maxtext} (case-insensitive)
PACKAGE∊{JAX,XLA,Flax,transformer-engine,airio,axlearn,equinox,T5X,maxtext,TorchAX} (case-insensitive)
default: ''
required: false
MODE:
Expand Down Expand Up @@ -361,6 +361,7 @@ jobs:
upstream-t5x
t5x
axlearn
torchax
)
declare -a STAGES=(
mealkit
Expand Down
Loading