-
Notifications
You must be signed in to change notification settings - Fork 2.9k
83 lines (76 loc) · 3.4 KB
/
bazel_cuda_non_rbe.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# CI - Bazel CUDA tests (Non-RBE)
#
# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Bazel CUDA tests.
#
# It consists of the following job:
#   run-tests:
#     - Downloads the jaxlib, CUDA plugin, and CUDA PJRT wheel artifacts from a GCS bucket.
#     - Executes the `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions:
#       - Installs the downloaded wheel artifacts.
#       - Runs the CUDA tests with Bazel.
name: CI - Bazel CUDA tests (Non-RBE)
on:
  workflow_call:
    # NOTE(review): every input except `halt-for-connection` is `required: true`,
    # so callers must always pass a value and the `default:` entries below never
    # take effect; they only document typical values. Consider `required: false`
    # if the defaults are meant to apply.
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-n2-16"
      python:
        description: "Which python version to test?"
        type: string
        required: true
        # Quoted so that e.g. "3.10" is not read as the float 3.1.
        default: "3.12"
      enable-x64:
        description: "Should x64 mode be enabled?"
        type: string
        required: true
        default: "0"
      gcs_download_uri:
        description: "GCS location URI from where the artifacts should be downloaded"
        type: string
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false
jobs:
  run-tests:
    runs-on: ${{ inputs.runner }}
    # NOTE(review): `:latest` is a mutable tag, so the test environment can change
    # between runs without any change to this file; consider pinning by digest.
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
      JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
      # Enable writing to the Bazel remote cache bucket.
      JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1"
    name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      # Export OS, ARCH and PYTHON_MAJOR_MINOR to later steps via $GITHUB_ENV;
      # they are used as wildcard components when selecting wheels to download.
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          # Get the major and minor version of Python.
          # E.g. if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> "$GITHUB_ENV"
      - name: Download the wheel artifacts from GCS
        env:
          # Pass the URI through the environment instead of interpolating
          # `${{ }}` directly into the script: the default value embeds
          # `github.workflow`, and any `$`, backtick or quote in the expanded
          # text would otherwise be interpreted by the shell.
          GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
        # The wildcards are inside double quotes, so bash leaves them alone and
        # gsutil expands them against the bucket contents.
        run: >-
          mkdir -p "$(pwd)/dist" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
      # Halt for testing
      - name: Wait For Connection
        # NOTE(review): this third-party action is referenced by the mutable
        # `main` branch; consider pinning it to a commit SHA like actions/checkout.
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Bazel CUDA tests (Non-RBE)
        timeout-minutes: 60
        run: ./ci/run_bazel_test_cuda_non_rbe.sh