diff --git a/.github/container/Dockerfile.torchax b/.github/container/Dockerfile.torchax new file mode 100644 index 000000000..3aa4450d4 --- /dev/null +++ b/.github/container/Dockerfile.torchax @@ -0,0 +1,37 @@ +# syntax=docker/dockerfile:1-labs + +ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax +ARG URLREF_TORCHAX=https://github.com/google/torchax.git#main +ARG SRC_PATH_TORCHAX=/opt/torchax + +############################################################################### +## Download source and add auxiliary scripts +############################################################################### + +FROM ${BASE_IMAGE} as mealkit +ARG URLREF_TORCHAX +ARG SRC_PATH_TORCHAX + +# Specify installation targets +RUN <<"EOF" bash -ex +git-clone.sh ${URLREF_TORCHAX} ${SRC_PATH_TORCHAX} +echo "-e file://${SRC_PATH_TORCHAX}" >> /opt/pip-tools.d/requirements-torchax.in +echo "torch" >> /opt/pip-tools.d/requirements-torchax.in +echo "torchvision" >> /opt/pip-tools.d/requirements-torchax.in +# tensorflow is only needed to suppress an absl warning, emitted when importing +# torchax, that tensorflow.io.gfile will not support GCS paths such as gs://... 
+# Comment it out if you want to save ~300 MB in image size. +echo "tensorflow" >> /opt/pip-tools.d/requirements-torchax.in +EOF + +############################################################################### +## Install accumulated packages from the base image and the previous stage +############################################################################### + +FROM mealkit as final + +RUN <<"EOF" bash -ex -o pipefail +PIP_INDEX_URL=https://download.pytorch.org/whl/cpu \ +PIP_EXTRA_INDEX_URL=https://pypi.org/simple \ +pip-finalize.sh +EOF diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 0103df6d1..ccbb14d73 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -109,3 +109,8 @@ tunix: tracking_ref: main latest_verified_commit: d799a45d48027e27b6a08aaf7cb15e6a4f495c01 mode: git-clone +torchax: + url: https://github.com/google/torchax.git + tracking_ref: main + latest_verified_commit: f41e3de8526f9d4e8410bfb84660faaaf0b3ba4a + mode: git-clone diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 59911d0f2..61f5fd5af 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -148,6 +148,34 @@ jobs: EXTRA_BUILD_ARGS: | URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }} + build-torchax: + needs: build-jax + runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"] + outputs: + DOCKER_TAG_MEALKIT: ${{ steps.build-torchax.outputs.DOCKER_TAG_MEALKIT }} + DOCKER_TAG_FINAL: ${{ steps.build-torchax.outputs.DOCKER_TAG_FINAL }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Build TorchAX container + id: build-torchax + uses: ./.github/actions/build-container + with: + ARCHITECTURE: ${{ inputs.ARCHITECTURE }} + ARTIFACT_NAME: artifact-torchax-build + BADGE_FILENAME: badge-torchax-build + BUILD_DATE: ${{ inputs.BUILD_DATE }} + BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }} + CONTAINER_NAME: torchax + 
DOCKERFILE: .github/container/Dockerfile.torchax + RUNNER_SIZE: small + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} + github-token: ${{ secrets.GITHUB_TOKEN }} + bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }} + EXTRA_BUILD_ARGS: | + URLREF_TORCHAX=${{ fromJson(inputs.SOURCE_URLREFS).TORCHAX }} + build-upstream-t5x: needs: build-jax runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"] diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1031f1e25..f6c789e9f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -38,7 +38,7 @@ on: type: string description: | A comma-separated PACKAGE=URL#REF list to override sources used by build. - PACKAGE∊{JAX,XLA,Flax,transformer-engine,airio,axlearn,equinox,T5X,maxtext} (case-insensitive) + PACKAGE∊{JAX,XLA,Flax,transformer-engine,airio,axlearn,equinox,T5X,maxtext,TorchAX} (case-insensitive) default: '' required: false MODE: @@ -361,6 +361,7 @@ jobs: upstream-t5x t5x axlearn + torchax ) declare -a STAGES=( mealkit