
Commit 7835895

timmoon10, denera, and mingxu1067 authored and committed
[JAX] Cherry-pick #785 and #780 (#800)
* [JAX] Fixing CI failure due to incorrect use of `static_argnums` in jax.jit (#785)
  * fixed static argnums for jax.jit in the single-GPU encoder test; changed warning filtering for pytest
  * propagated the fix to the JAX MNIST example
  * fixed a missing space between flags in the QA scripts
  * added TE warnings to the ignore list
  Signed-off-by: Alp Dener <[email protected]>
  Signed-off-by: Tim Moon <[email protected]>

* [JAX] Allow multi-dims for dgamma and dbeta in LN descriptor. (#780)
  * Allow multi-dims for dgamma and dbeta in LN descriptor.
  * Fix the jit error in examples/jax
  Signed-off-by: Ming Huang <[email protected]>
  Signed-off-by: Tim Moon <[email protected]>

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: Alp Dener <[email protected]>
Co-authored-by: Ming-Xu Huang <[email protected]>
1 parent 4f5723e · commit 7835895

File tree

8 files changed: +91 -53 lines

examples/jax/encoder/test_single_gpu_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def __call__(self, x, mask, disable_dropout=False):
         return x
 
 
-@partial(jax.jit, static_argnums=6)
+@partial(jax.jit)
 def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""
 
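
For context, here is a minimal sketch (not from the repository; the stand-in function and dummy arguments are placeholders) of the failure mode this cherry-pick fixes: static_argnums=6 names a positional index that the six-argument train_step does not have, which recent JAX releases reject, while the plain jax.jit form works.

import jax
import jax.numpy as jnp


def step(state, inputs, masks, labels, var_collect, rngs):
    """Stand-in for train_step: six positional arguments, indices 0-5."""
    return inputs.sum()


x = jnp.ones((2, 2))

fixed_fn = jax.jit(step)  # the cherry-picked fix: no stale static_argnums
print(fixed_fn(None, x, None, None, None, None))

try:
    broken_fn = jax.jit(step, static_argnums=6)  # index 6 does not exist on `step`
    broken_fn(None, x, None, None, None, None)
except (ValueError, TypeError) as err:  # recent JAX rejects the out-of-range index
    print(f"rejected as expected: {err}")
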
examples/jax/mnist/test_single_gpu_mnist.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def loss_fn(var_collect, disable_dropout=False):
     return grads, loss, accuracy
 
 
-@partial(jax.jit, static_argnums=2)
+@partial(jax.jit)
 def update_model(state, grads):
     """Update model params and FP8 meta."""
     state = state.apply_gradients(grads=grads[PARAMS_KEY])

qa/L0_jax_unittest/test.sh

Lines changed: 5 additions & 4 deletions
@@ -5,14 +5,15 @@
 set -xe
 
 : ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed'
+
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
 
 pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
 pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
 
-pytest -Wignore -v $TE_PATH/examples/jax/mnist
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
 
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
-pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py

qa/L1_jax_distributed_unittest/test.sh

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@
 set -xe
 
 : ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
 

tests/jax/pytest.ini

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+[pytest]
+filterwarnings=
+    ignore:sharding_type of.*:DeprecationWarning
+    ignore:major_sharding_type of.*:DeprecationWarning
+    ignore:Fused attention is not enabled.*:UserWarning
+    ignore:The hookimpl.*:DeprecationWarning
+    ignore:xmap is an experimental feature and probably has bugs!
+    ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
+    ignore:can't resolve package from __spec__ or __package__:ImportWarning
+    ignore:Using or importing the ABCs.*:DeprecationWarning
+    ignore:numpy.ufunc size changed
+    ignore:.*experimental feature
+    ignore:The distutils.* is deprecated.*:DeprecationWarning
+    ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
+    ignore:ml_dtypes.float8_e4m3b11 is deprecated.
+    ignore:np.find_common_type is deprecated.*:DeprecationWarning
+    ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning
+    ignore:The numpy.array_api submodule is still experimental.*:UserWarning
+    ignore:case not machine-readable.*:UserWarning
+    ignore:not machine-readable.*:UserWarning
+    ignore:Special cases found for .* but none were parsed.*:UserWarning
+    ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning
+    ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
+    ignore:The host_callback APIs are deprecated .*:DeprecationWarning
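
For readers unfamiliar with the ini syntax: each filterwarnings entry above has the form ignore:<message pattern>:<warning category>, and the QA scripts now load this file via pytest -c instead of silencing everything with -Wignore. A rough Python equivalent of a single entry (illustrative only, not part of the commit) looks like this:

import warnings

# Roughly what the `ignore:sharding_type of.*:DeprecationWarning` entry installs:
warnings.filterwarnings(
    "ignore",
    message=r"sharding_type of.*",   # pattern matched against the warning message
    category=DeprecationWarning,
)

# With the filter installed, this warning is silently dropped:
warnings.warn("sharding_type of xyz is deprecated", DeprecationWarning)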

transformer_engine/jax/cpp_extensions.py

Lines changed: 12 additions & 13 deletions
@@ -385,8 +385,8 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0, # no dgamma_part in FWD pass
-            0, # no dbeta_part in BWD pass
+            (0,), # no dgamma_part in FWD pass
+            (0,), # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -464,7 +464,6 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
                 f"Enforcing no sharding of parameters hidden dim! " \
             )
 
-
         x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
         g_sharding = NamedSharding(mesh, PartitionSpec(None))
         b_sharding = NamedSharding(mesh, PartitionSpec(None))
@@ -589,8 +588,8 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            dgamma_part_aval.size,
-            dbeta_part_aval.size,
+            dgamma_part_aval.shape,
+            dbeta_part_aval.shape,
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -791,8 +790,8 @@ def lowering(ctx, x, gamma, *, epsilon):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0, # no dgamma_part in FWD pass
-            0, # no dbeta_part in BWD pass
+            (0,), # no dgamma_part in FWD pass
+            (0,), # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -968,8 +967,8 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            dgamma_part_aval.size,
-            0, # no dbeta_part for RMSnorm
+            dgamma_part_aval.shape,
+            (0,), # no dbeta_part for RMSnorm
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -3588,8 +3587,8 @@ def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_cen
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0, # no dgamma_part in FWD pass
-            0, # no dbeta_part in BWD pass
+            (0,), # no dgamma_part in FWD pass
+            (0,), # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -3840,8 +3839,8 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0, # no dgamma_part in FWD pass
-            0, # no dbeta_part in BWD pass
+            (0,), # no dgamma_part in FWD pass
+            (0,), # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
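
The switch from .size to .shape matters because a flat element count loses the buffer layout the C++ side needs to rebuild dgamma_part and dbeta_part as multi-dimensional tensors. A small sketch of the difference (the array below is a hypothetical partial-reduction buffer, not taken from the repository):

import jax.numpy as jnp

dgamma_part = jnp.zeros((4, 8, 128), dtype=jnp.float32)  # hypothetical buffer

print(dgamma_part.size)   # 4096 -- flat element count; the layout is lost
print(dgamma_part.shape)  # (4, 8, 128) -- what the descriptor now carries, so the
                          # C++ side can build a TensorWrapper with the right dims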

transformer_engine/jax/csrc/modules.cpp

Lines changed: 36 additions & 24 deletions
@@ -71,17 +71,28 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
     return PackOpaque(desc);
 }
 
-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-                                             size_t wkspace_size, size_t barrier_size,
-                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
-                                             DType barrier_dtype, DType dgamma_part_dtype,
-                                             DType dbeta_part_dtype, bool zero_centered_gamma,
-                                             float eps, int sm_margin) {
-    return PackOpaque(CustomCallNormDescriptor{
-        batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
-        x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
-        zero_centered_gamma, eps, sm_margin});
+pybind11::bytes PackCustomCallNormDescriptor(
+    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
+    CustomCallNormDescriptor desc;
+    desc.batch_size = batch_size;
+    desc.hidden_size = hidden_size;
+    desc.wkspace_size = wkspace_size;
+    desc.barrier_size = barrier_size;
+    desc.dgamma_part_shape.from_vector(dgamma_part_shape);
+    desc.dbeta_part_shape.from_vector(dbeta_part_shape);
+    desc.x_dtype = x_dtype;
+    desc.w_dtype = w_dtype;
+    desc.wkspace_dtype = wkspace_dtype;
+    desc.barrier_dtype = barrier_dtype;
+    desc.dgamma_part_dtype = dgamma_part_dtype;
+    desc.dbeta_part_dtype = dbeta_part_dtype;
+    desc.zero_centered_gamma = zero_centered_gamma;
+    desc.eps = eps;
+    desc.sm_margin = sm_margin;
+    return PackOpaque(desc);
 }
 
 pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
@@ -529,7 +540,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
 }
 
 void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
-                           size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
+                           size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
                            bool zero_centered_gamma, float eps, void *input, DType in_dtype,
                            void *weight, DType w_dtype, void *ograd, void *workspace,
                            DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
@@ -563,14 +574,14 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
     auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
     auto barrier_shape = std::vector<size_t>{barrier_size};
     auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
-    auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
-    auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
+    auto dgamma_part_tensor =
+        TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
 
     if (is_layer_norm) {
         auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
         auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
-        auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
-        auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
+        auto dbeta_part_tensor =
+            TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
 
         layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                            rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
@@ -664,8 +675,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto hidden_size = desc.hidden_size;
     auto wkspace_size = desc.wkspace_size;
     auto barrier_size = desc.barrier_size;
-    auto *dgamma_part_sizes = desc.dgamma_part_sizes;
-    auto *dbeta_part_sizes = desc.dbeta_part_sizes;
+    auto dgamma_part_shape = desc.dgamma_part_shape;
+    auto dbeta_part_shape = desc.dbeta_part_shape;
    auto in_dtype = desc.x_dtype;
     auto w_dtype = desc.w_dtype;
     auto wkspace_dtype = desc.wkspace_dtype;
@@ -689,8 +700,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto *dgamma_part = buffers[10];
     auto *dbeta_part = buffers[11];
 
-    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-                          dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                           w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                           rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                           dbeta_part_dtype, stream);
@@ -786,8 +797,9 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
     auto hidden_size = desc.hidden_size;
     auto wkspace_size = desc.wkspace_size;
     auto barrier_size = desc.barrier_size;
-    auto dgamma_part_sizes = desc.dgamma_part_sizes;
-    size_t dbeta_part_sizes[2] = {0, 0};
+    auto dgamma_part_shape = desc.dgamma_part_shape;
+    Shape dbeta_part_shape;
+    dbeta_part_shape.from_vector({0, 0});
     auto in_dtype = desc.x_dtype;
     auto w_dtype = desc.w_dtype;
     auto wkspace_dtype = desc.wkspace_dtype;
@@ -797,8 +809,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
     auto eps = desc.eps;
     auto zero_centered_gamma = desc.zero_centered_gamma;
 
-    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-                          dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                           w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                           rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                           dbeta_part_dtype, stream);

transformer_engine/jax/csrc/modules.h

Lines changed: 7 additions & 9 deletions
@@ -69,8 +69,8 @@ struct CustomCallNormDescriptor {
     size_t hidden_size;
     size_t wkspace_size;
     size_t barrier_size;
-    size_t *dgamma_part_sizes; // 2D tensor
-    size_t *dbeta_part_sizes;  // 2D tensor
+    Shape dgamma_part_shape;
+    Shape dbeta_part_shape;
     DType x_dtype;
     DType w_dtype;
     DType wkspace_dtype;
@@ -82,13 +82,11 @@ struct CustomCallNormDescriptor {
     int sm_margin;
 };
 
-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-                                             size_t wkspace_size, size_t barrier_size,
-                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
-                                             DType barrier_dtype, DType dgamma_part_dtype,
-                                             DType dbeta_part_dtype, bool zero_centered_gamma,
-                                             float eps, int sm_margin);
+pybind11::bytes PackCustomCallNormDescriptor(
+    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
 
 struct SoftmaxDescriptor {
     size_t batch_size;
