Skip to content

Commit 20161c0

Browse files
committed
ci: Allow selective triggering of pre-submit CI
By specifying a target-container as input to `TEST_SUBSET` the pre-submit CI will build all required containers and run tests only beginning from that node down to all connected leaves. It assumes that all nodes prior to the target-container are unchanged since their last successful test.
1 parent e44e507 commit 20161c0

File tree

2 files changed

+129
-43
lines changed

2 files changed

+129
-43
lines changed

.github/workflows/_ci.yaml

Lines changed: 85 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name: ~CI, single-arch
2-
run-name: CI-${{ inputs.ARCHITECTURE }}
2+
run-name: CI-${{ inputs.ARCHITECTURE }}-${{ inputs.TESTSUBSET }}
33
on:
44
workflow_call:
55
inputs:
@@ -16,26 +16,60 @@ on:
1616
description: Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch
1717
default: ''
1818
required: false
19+
TEST_SUBSET:
20+
type: string
21+
description: |
22+
Subset of tests to run. Allowed values are one of:
23+
- base
24+
- jax
25+
- levanter
26+
- equinox
27+
- triton
28+
- upstream-t5x
29+
- rosetta-t5x
30+
- upstream-pax
31+
- rosetta-pax
32+
- maxtext
33+
- grok
34+
35+
Will run all downstream-connected nodes and leaves.
36+
default: 'base'
37+
required: false
1938
outputs:
2039
DOCKER_TAGS:
2140
description: JSON object containing tags of all docker images built
2241
value: ${{ jobs.collect-docker-tags.outputs.TAGS }}
2342

2443
permissions:
25-
contents: read # to fetch code
26-
actions: write # to cancel previous workflows
44+
contents: read # to fetch code
45+
actions: write # to cancel previous workflows
2746
packages: write # to upload container
2847

2948
jobs:
49+
pre-flight:
50+
runs-on: ubuntu-22.04
51+
steps:
52+
- name: Validate input `TEST_SUBSET`
53+
shell: bash
54+
run: |
55+
valid_inputs=("base" "core" "levanter" "equinox" "triton" "upstream-t5x" "rosetta-t5x" "upstream-pax" "rosetta-pax" "maxtext" "grok")
56+
57+
if [[ " ${valid_inputs[*]} " != *" ${{ inputs.TEST_SUBSET }} "* ]]; then
58+
echo "Invalid value for \`TEST_SUBSET\` provided. Expected one of: ($valid_inputs), Actual: ${{ inputs.TEST_SUBSET }}"
59+
exit 1
60+
fi
3061
62+
# Always
3163
build-base:
3264
uses: ./.github/workflows/_build_base.yaml
65+
needs: pre-flight
3366
with:
3467
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
3568
BUILD_DATE: ${{ inputs.BUILD_DATE }}
3669
MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
3770
secrets: inherit
3871

72+
# Always
3973
build-jax:
4074
needs: build-base
4175
uses: ./.github/workflows/_build.yaml
@@ -50,9 +84,10 @@ jobs:
5084
RUNNER_SIZE: large
5185
secrets: inherit
5286

87+
# base, jax, triton
5388
build-triton:
5489
needs: build-jax
55-
if: inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
90+
if: contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
5691
uses: ./.github/workflows/_build.yaml
5792
with:
5893
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
@@ -64,9 +99,11 @@ jobs:
6499
DOCKERFILE: .github/container/Dockerfile.triton
65100
secrets: inherit
66101

102+
# base, jax, equinox
67103
build-equinox:
68104
needs: build-jax
69105
uses: ./.github/workflows/_build.yaml
106+
if: contains(fromJSON('["base", "jax", "equinox"]'), inputs.TEST_SUBSET)
70107
with:
71108
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
72109
ARTIFACT_NAME: artifact-equinox-build
@@ -77,9 +114,10 @@ jobs:
77114
DOCKERFILE: .github/container/Dockerfile.equinox
78115
secrets: inherit
79116

117+
# base, jax, maxtext
80118
build-maxtext:
81119
needs: build-jax
82-
if: inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
120+
if: contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
83121
uses: ./.github/workflows/_build.yaml
84122
with:
85123
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
@@ -91,35 +129,41 @@ jobs:
91129
DOCKERFILE: .github/container/Dockerfile.maxtext.amd64
92130
secrets: inherit
93131

132+
# base, jax, levanter
94133
build-levanter:
95134
needs: [build-jax]
96135
uses: ./.github/workflows/_build.yaml
136+
if: contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET)
97137
with:
98138
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
99-
ARTIFACT_NAME: "artifact-levanter-build"
100-
BADGE_FILENAME: "badge-levanter-build"
139+
ARTIFACT_NAME: 'artifact-levanter-build'
140+
BADGE_FILENAME: 'badge-levanter-build'
101141
BUILD_DATE: ${{ inputs.BUILD_DATE }}
102142
BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
103143
CONTAINER_NAME: levanter
104144
DOCKERFILE: .github/container/Dockerfile.levanter
105145
secrets: inherit
106146

147+
# base, jax, upstream-t5x
107148
build-upstream-t5x:
108149
needs: build-jax
109150
uses: ./.github/workflows/_build.yaml
151+
if: contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET)
110152
with:
111153
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
112-
ARTIFACT_NAME: "artifact-t5x-build"
113-
BADGE_FILENAME: "badge-t5x-build"
154+
ARTIFACT_NAME: 'artifact-t5x-build'
155+
BADGE_FILENAME: 'badge-t5x-build'
114156
BUILD_DATE: ${{ inputs.BUILD_DATE }}
115157
BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
116158
CONTAINER_NAME: upstream-t5x
117159
DOCKERFILE: .github/container/Dockerfile.t5x.${{ inputs.ARCHITECTURE }}
118160
secrets: inherit
119161

162+
# base, jax, upstream-pax
120163
build-upstream-pax:
121164
needs: build-jax
122165
uses: ./.github/workflows/_build.yaml
166+
if: contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET)
123167
with:
124168
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
125169
ARTIFACT_NAME: artifact-pax-build
@@ -130,42 +174,48 @@ jobs:
130174
DOCKERFILE: .github/container/Dockerfile.pax.${{ inputs.ARCHITECTURE }}
131175
secrets: inherit
132176

177+
# base, jax, upstream-t5x, rosetta-t5x
133178
build-rosetta-t5x:
134179
needs: build-upstream-t5x
135180
uses: ./.github/workflows/_build_rosetta.yaml
181+
if: contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET)
136182
with:
137183
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
138184
BUILD_DATE: ${{ inputs.BUILD_DATE }}
139185
BASE_IMAGE: ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_MEALKIT }}
140186
BASE_LIBRARY: t5x
141187
secrets: inherit
142188

189+
# base, jax, upstream-pax, rosetta-pax
143190
build-rosetta-pax:
144191
needs: build-upstream-pax
145192
uses: ./.github/workflows/_build_rosetta.yaml
193+
if: contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET)
146194
with:
147195
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
148196
BUILD_DATE: ${{ inputs.BUILD_DATE }}
149197
BASE_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_MEALKIT }}
150198
BASE_LIBRARY: pax
151199
secrets: inherit
152200

201+
# base, jax, grok
153202
build-grok:
154203
needs: [build-jax]
155204
uses: ./.github/workflows/_build.yaml
205+
if: contains(fromJSON('["base", "jax", "grok"]'), inputs.TEST_SUBSET)
156206
with:
157207
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
158-
ARTIFACT_NAME: "artifact-grok-build"
159-
BADGE_FILENAME: "badge-grok-build"
208+
ARTIFACT_NAME: 'artifact-grok-build'
209+
BADGE_FILENAME: 'badge-grok-build'
160210
BUILD_DATE: ${{ inputs.BUILD_DATE }}
161211
BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
162212
CONTAINER_NAME: grok
163213
DOCKERFILE: .github/container/Dockerfile.grok
164214
secrets: inherit
165-
215+
166216
collect-docker-tags:
167217
runs-on: ubuntu-22.04
168-
if: "!cancelled()"
218+
if: '!cancelled()'
169219
needs:
170220
- build-base
171221
- build-jax
@@ -236,9 +286,10 @@ jobs:
236286
- name: Run integration test ${{ matrix.TEST_SCRIPT }}
237287
run: bash rosetta/tests/${{ matrix.TEST_SCRIPT }}
238288

289+
# base, jax
239290
test-jax:
240291
needs: build-jax
241-
if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
292+
if: contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
242293
uses: ./.github/workflows/_test_unit.yaml
243294
with:
244295
TEST_NAME: jax
@@ -291,33 +342,37 @@ jobs:
291342
# test-equinox.log
292343
# secrets: inherit
293344

345+
# base, jax, upstream-pax
294346
test-te-multigpu:
295347
needs: build-upstream-pax
296-
if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
348+
if: contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
297349
uses: ./.github/workflows/_test_te.yaml
298350
with:
299351
TE_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
300352
secrets: inherit
301353

354+
# base, jax, upstream-t5x
302355
test-upstream-t5x:
303356
needs: build-upstream-t5x
304-
if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
357+
if: contains(fromJSON('["base", "jax", "upstream-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
305358
uses: ./.github/workflows/_test_upstream_t5x.yaml
306359
with:
307360
T5X_IMAGE: ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_FINAL }}
308361
secrets: inherit
309362

363+
# base, jax, upstream-t5x, rosetta-t5x
310364
test-rosetta-t5x:
311365
needs: build-rosetta-t5x
312-
if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
366+
if: contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
313367
uses: ./.github/workflows/_test_t5x_rosetta.yaml
314368
with:
315369
T5X_IMAGE: ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }}
316370
secrets: inherit
317371

372+
# base, jax
318373
test-pallas:
319374
needs: build-jax
320-
if: inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?)
375+
if: contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?)
321376
uses: ./.github/workflows/_test_unit.yaml
322377
with:
323378
TEST_NAME: pallas
@@ -341,9 +396,10 @@ jobs:
341396
test-pallas.log
342397
secrets: inherit
343398

399+
# base, jax, triton
344400
test-triton:
345401
needs: build-triton
346-
if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
402+
if: contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
347403
uses: ./.github/workflows/_test_unit.yaml
348404
with:
349405
TEST_NAME: triton
@@ -367,9 +423,10 @@ jobs:
367423
test-triton.log
368424
secrets: inherit
369425

426+
# base, jax, levanter
370427
test-levanter:
371428
needs: build-levanter
372-
if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
429+
if: contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
373430
uses: ./.github/workflows/_test_unit.yaml
374431
with:
375432
TEST_NAME: levanter
@@ -394,9 +451,10 @@ jobs:
394451
test-levanter.log
395452
secrets: inherit
396453

454+
# base, jax, upstream-pax
397455
test-te:
398456
needs: build-upstream-pax
399-
if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
457+
if: contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
400458
uses: ./.github/workflows/_test_unit.yaml
401459
with:
402460
TEST_NAME: te
@@ -422,25 +480,28 @@ jobs:
422480
pytest-report.jsonl
423481
secrets: inherit
424482

483+
# base, jax, upstream-pax
425484
test-upstream-pax:
426485
needs: build-upstream-pax
427-
if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
486+
if: contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
428487
uses: ./.github/workflows/_test_upstream_pax.yaml
429488
with:
430489
PAX_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
431490
secrets: inherit
432491

492+
# base, jax, upstream-pax, rosetta-pax
433493
test-rosetta-pax:
434494
needs: build-rosetta-pax
435-
if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
495+
if: contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
436496
uses: ./.github/workflows/_test_pax_rosetta.yaml
437497
with:
438498
PAX_IMAGE: ${{ needs.build-rosetta-pax.outputs.DOCKER_TAG_FINAL }}
439499
secrets: inherit
440500

501+
# base, jax, maxtext
441502
test-maxtext:
442503
needs: build-maxtext
443-
if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
504+
if: contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
444505
uses: ./.github/workflows/_test_maxtext.yaml
445506
with:
446507
MAXTEXT_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}

0 commit comments

Comments
 (0)