1
1
name : ~CI, single-arch
2
- run-name : CI-${{ inputs.ARCHITECTURE }}
2
+ run-name : CI-${{ inputs.ARCHITECTURE }}-${{ inputs.TESTSUBSET }}
3
3
on :
4
4
workflow_call :
5
5
inputs :
16
16
description : Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch
17
17
default : ' '
18
18
required : false
19
+ TEST_SUBSET :
20
+ type : string
21
+ description : |
22
+ Subset of tests to run. Allowed values are one of:
23
+ - base
24
+ - jax
25
+ - levanter
26
+ - equinox
27
+ - triton
28
+ - upstream-t5x
29
+ - rosetta-t5x
30
+ - upstream-pax
31
+ - rosetta-pax
32
+ - maxtext
33
+ - grok
34
+
35
+ Will run all downstream-connected nodes and leaves.
36
+ default : ' base'
37
+ required : false
19
38
outputs :
20
39
DOCKER_TAGS :
21
40
description : JSON object containing tags of all docker images built
22
41
value : ${{ jobs.collect-docker-tags.outputs.TAGS }}
23
42
24
43
permissions :
25
- contents : read # to fetch code
26
- actions : write # to cancel previous workflows
44
+ contents : read # to fetch code
45
+ actions : write # to cancel previous workflows
27
46
packages : write # to upload container
28
47
29
48
jobs :
49
+ pre-flight :
50
+ runs-on : ubuntu-22.04
51
+ steps :
52
+ - name : Validate input `TEST_SUBSET`
53
+ shell : bash
54
+ run : |
55
+ valid_inputs=("base" "core" "levanter" "equinox" "triton" "upstream-t5x" "rosetta-t5x" "upstream-pax" "rosetta-pax" "maxtext" "grok")
56
+
57
+ if [[ " ${valid_inputs[*]} " != *" ${{ inputs.TEST_SUBSET }} "* ]]; then
58
+ echo "Invalid value for \`TEST_SUBSET\` provided. Expected one of: ($valid_inputs), Actual: ${{ inputs.TEST_SUBSET }}"
59
+ exit 1
60
+ fi
30
61
62
+ # Always
31
63
build-base :
32
64
uses : ./.github/workflows/_build_base.yaml
65
+ needs : pre-flight
33
66
with :
34
67
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
35
68
BUILD_DATE : ${{ inputs.BUILD_DATE }}
36
69
MANIFEST_ARTIFACT_NAME : ${{ inputs.MANIFEST_ARTIFACT_NAME }}
37
70
secrets : inherit
38
71
72
+ # Always
39
73
build-jax :
40
74
needs : build-base
41
75
uses : ./.github/workflows/_build.yaml
50
84
RUNNER_SIZE : large
51
85
secrets : inherit
52
86
87
+ # base, jax, triton
53
88
build-triton :
54
89
needs : build-jax
55
- if : inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
90
+ if : contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
56
91
uses : ./.github/workflows/_build.yaml
57
92
with :
58
93
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
64
99
DOCKERFILE : .github/container/Dockerfile.triton
65
100
secrets : inherit
66
101
102
+ # base, jax, equinox
67
103
build-equinox :
68
104
needs : build-jax
69
105
uses : ./.github/workflows/_build.yaml
106
+ if : contains(fromJSON('["base", "jax", "equinox"]'), inputs.TEST_SUBSET)
70
107
with :
71
108
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
72
109
ARTIFACT_NAME : artifact-equinox-build
@@ -77,9 +114,10 @@ jobs:
77
114
DOCKERFILE : .github/container/Dockerfile.equinox
78
115
secrets : inherit
79
116
117
+ # base, jax, maxtext
80
118
build-maxtext :
81
119
needs : build-jax
82
- if : inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
120
+ if : contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
83
121
uses : ./.github/workflows/_build.yaml
84
122
with :
85
123
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
@@ -91,35 +129,41 @@ jobs:
91
129
DOCKERFILE : .github/container/Dockerfile.maxtext.amd64
92
130
secrets : inherit
93
131
132
+ # base, jax, levanter
94
133
build-levanter :
95
134
needs : [build-jax]
96
135
uses : ./.github/workflows/_build.yaml
136
+ if : contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET)
97
137
with :
98
138
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
99
- ARTIFACT_NAME : " artifact-levanter-build"
100
- BADGE_FILENAME : " badge-levanter-build"
139
+ ARTIFACT_NAME : ' artifact-levanter-build'
140
+ BADGE_FILENAME : ' badge-levanter-build'
101
141
BUILD_DATE : ${{ inputs.BUILD_DATE }}
102
142
BASE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
103
143
CONTAINER_NAME : levanter
104
144
DOCKERFILE : .github/container/Dockerfile.levanter
105
145
secrets : inherit
106
146
147
+ # base, jax, upstream-t5x
107
148
build-upstream-t5x :
108
149
needs : build-jax
109
150
uses : ./.github/workflows/_build.yaml
151
+ if : contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET)
110
152
with :
111
153
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
112
- ARTIFACT_NAME : " artifact-t5x-build"
113
- BADGE_FILENAME : " badge-t5x-build"
154
+ ARTIFACT_NAME : ' artifact-t5x-build'
155
+ BADGE_FILENAME : ' badge-t5x-build'
114
156
BUILD_DATE : ${{ inputs.BUILD_DATE }}
115
157
BASE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
116
158
CONTAINER_NAME : upstream-t5x
117
159
DOCKERFILE : .github/container/Dockerfile.t5x.${{ inputs.ARCHITECTURE }}
118
160
secrets : inherit
119
161
162
+ # base, jax, upstream-pax
120
163
build-upstream-pax :
121
164
needs : build-jax
122
165
uses : ./.github/workflows/_build.yaml
166
+ if : contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET)
123
167
with :
124
168
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
125
169
ARTIFACT_NAME : artifact-pax-build
@@ -130,42 +174,48 @@ jobs:
130
174
DOCKERFILE : .github/container/Dockerfile.pax.${{ inputs.ARCHITECTURE }}
131
175
secrets : inherit
132
176
177
+ # base, jax, upstream-t5x, rosetta-t5x
133
178
build-rosetta-t5x :
134
179
needs : build-upstream-t5x
135
180
uses : ./.github/workflows/_build_rosetta.yaml
181
+ if : contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET)
136
182
with :
137
183
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
138
184
BUILD_DATE : ${{ inputs.BUILD_DATE }}
139
185
BASE_IMAGE : ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_MEALKIT }}
140
186
BASE_LIBRARY : t5x
141
187
secrets : inherit
142
188
189
+ # base, jax, upstream-pax, rosetta-pax
143
190
build-rosetta-pax :
144
191
needs : build-upstream-pax
145
192
uses : ./.github/workflows/_build_rosetta.yaml
193
+ if : contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET)
146
194
with :
147
195
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
148
196
BUILD_DATE : ${{ inputs.BUILD_DATE }}
149
197
BASE_IMAGE : ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_MEALKIT }}
150
198
BASE_LIBRARY : pax
151
199
secrets : inherit
152
200
201
+ # base, jax, grok
153
202
build-grok :
154
203
needs : [build-jax]
155
204
uses : ./.github/workflows/_build.yaml
205
+ if : contains(fromJSON('["base", "jax", "grok"]'), inputs.TEST_SUBSET)
156
206
with :
157
207
ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
158
- ARTIFACT_NAME : " artifact-grok-build"
159
- BADGE_FILENAME : " badge-grok-build"
208
+ ARTIFACT_NAME : ' artifact-grok-build'
209
+ BADGE_FILENAME : ' badge-grok-build'
160
210
BUILD_DATE : ${{ inputs.BUILD_DATE }}
161
211
BASE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
162
212
CONTAINER_NAME : grok
163
213
DOCKERFILE : .github/container/Dockerfile.grok
164
214
secrets : inherit
165
-
215
+
166
216
collect-docker-tags :
167
217
runs-on : ubuntu-22.04
168
- if : " !cancelled()"
218
+ if : ' !cancelled()'
169
219
needs :
170
220
- build-base
171
221
- build-jax
@@ -236,9 +286,10 @@ jobs:
236
286
- name : Run integration test ${{ matrix.TEST_SCRIPT }}
237
287
run : bash rosetta/tests/${{ matrix.TEST_SCRIPT }}
238
288
289
+ # base, jax
239
290
test-jax :
240
291
needs : build-jax
241
- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
292
+ if : contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
242
293
uses : ./.github/workflows/_test_unit.yaml
243
294
with :
244
295
TEST_NAME : jax
@@ -291,33 +342,37 @@ jobs:
291
342
# test-equinox.log
292
343
# secrets: inherit
293
344
345
+ # base, jax, upstream-pax
294
346
test-te-multigpu :
295
347
needs : build-upstream-pax
296
- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
348
+ if : contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
297
349
uses : ./.github/workflows/_test_te.yaml
298
350
with :
299
351
TE_IMAGE : ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
300
352
secrets : inherit
301
353
354
+ # base, jax, upstream-t5x
302
355
test-upstream-t5x :
303
356
needs : build-upstream-t5x
304
- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
357
+ if : contains(fromJSON('["base", "jax", "upstream-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
305
358
uses : ./.github/workflows/_test_upstream_t5x.yaml
306
359
with :
307
360
T5X_IMAGE : ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_FINAL }}
308
361
secrets : inherit
309
362
363
+ # base, jax, upstream-t5x, rosetta-t5x
310
364
test-rosetta-t5x :
311
365
needs : build-rosetta-t5x
312
- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
366
+ if : contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
313
367
uses : ./.github/workflows/_test_t5x_rosetta.yaml
314
368
with :
315
369
T5X_IMAGE : ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }}
316
370
secrets : inherit
317
371
372
+ # base, jax
318
373
test-pallas :
319
374
needs : build-jax
320
- if : inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?)
375
+ if : contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?)
321
376
uses : ./.github/workflows/_test_unit.yaml
322
377
with :
323
378
TEST_NAME : pallas
@@ -341,9 +396,10 @@ jobs:
341
396
test-pallas.log
342
397
secrets : inherit
343
398
399
+ # base, jax, triton
344
400
test-triton :
345
401
needs : build-triton
346
- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
402
+ if : contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
347
403
uses : ./.github/workflows/_test_unit.yaml
348
404
with :
349
405
TEST_NAME : triton
@@ -367,9 +423,10 @@ jobs:
367
423
test-triton.log
368
424
secrets : inherit
369
425
426
+ # base, jax, levanter
370
427
test-levanter :
371
428
needs : build-levanter
372
- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
429
+ if : contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET) && inputs. ARCHITECTURE == 'amd64' # arm64 runners n/a
373
430
uses : ./.github/workflows/_test_unit.yaml
374
431
with :
375
432
TEST_NAME : levanter
@@ -394,9 +451,10 @@ jobs:
394
451
test-levanter.log
395
452
secrets : inherit
396
453
454
+ # base, jax, upstream-pax
397
455
test-te :
398
456
needs : build-upstream-pax
399
- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
457
+ if : contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs. ARCHITECTURE == 'amd64' # arm64 runners n/a
400
458
uses : ./.github/workflows/_test_unit.yaml
401
459
with :
402
460
TEST_NAME : te
@@ -422,25 +480,28 @@ jobs:
422
480
pytest-report.jsonl
423
481
secrets : inherit
424
482
483
+ # base, jax, upstream-pax
425
484
test-upstream-pax :
426
485
needs : build-upstream-pax
427
- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
486
+ if : contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
428
487
uses : ./.github/workflows/_test_upstream_pax.yaml
429
488
with :
430
489
PAX_IMAGE : ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
431
490
secrets : inherit
432
491
492
+ # base, jax, upstream-pax, rosetta-pax
433
493
test-rosetta-pax :
434
494
needs : build-rosetta-pax
435
- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
495
+ if : contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
436
496
uses : ./.github/workflows/_test_pax_rosetta.yaml
437
497
with :
438
498
PAX_IMAGE : ${{ needs.build-rosetta-pax.outputs.DOCKER_TAG_FINAL }}
439
499
secrets : inherit
440
500
501
+ # base, jax, maxtext
441
502
test-maxtext :
442
503
needs : build-maxtext
443
- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
504
+ if : contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
444
505
uses : ./.github/workflows/_test_maxtext.yaml
445
506
with :
446
507
MAXTEXT_IMAGE : ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}
0 commit comments