Commit

[JAX] Cherry-pick #785 and #780 (#800)
* [JAX] Fixing CI failure due to incorrect use of `static_argnums` in jax.jit (#785)

* fixed static_argnums for jax.jit in the single-GPU encoder test and changed warning filtering for pytest

Signed-off-by: Alp Dener <[email protected]>

* propagating the fix to the JAX mnist example

Signed-off-by: Alp Dener <[email protected]>

* fixed missing space between flags in QA scripts

Signed-off-by: Alp Dener <[email protected]>

* added TE warnings into the ignore list

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [JAX] Allow multi-dims for dgamma and dbeta in LN descriptor. (#780)

* Allow multi-dims for dgamma and dbeta in LN descriptor.

Signed-off-by: Ming Huang <[email protected]>

* Fix the jit error in examples/jax

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Ming Huang <[email protected]>
Co-authored-by: Alp Dener <[email protected]>
Co-authored-by: Ming-Xu Huang <[email protected]>
3 people authored and ksivaman committed Apr 24, 2024
1 parent 4f5723e commit 7835895
Showing 8 changed files with 91 additions and 53 deletions.
2 changes: 1 addition & 1 deletion examples/jax/encoder/test_single_gpu_encoder.py
@@ -55,7 +55,7 @@ def __call__(self, x, mask, disable_dropout=False):
return x


-@partial(jax.jit, static_argnums=6)
+@partial(jax.jit)
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""

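Note (editor's addition, not part of the diff): `train_step` takes six positional arguments (indices 0-5) and `update_model` in the next file takes two, so `static_argnums=6` and `static_argnums=2` both point past the function signatures; newer JAX releases reject this, which appears to be what broke CI. A minimal sketch of the before/after, using a hypothetical stand-in rather than the repository's `train_step`:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Broken form (what the example previously had): the function only has six
# parameters (indices 0-5), so static_argnums=6 refers to an argument that does
# not exist, and recent JAX versions raise an error instead of ignoring it:
#
#   @partial(jax.jit, static_argnums=6)
#   def train_step(state, inputs, masks, labels, var_collect, rngs): ...

# Fixed form: none of the arguments needs to be static, so plain jax.jit suffices.
@partial(jax.jit)
def train_step_sketch(a, b, c, d, e, f):
    return a + b + c + d + e + f

x = jnp.ones(4)
print(train_step_sketch(x, x, x, x, x, x))  # traces and runs with no static arguments
```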
2 changes: 1 addition & 1 deletion examples/jax/mnist/test_single_gpu_mnist.py
@@ -74,7 +74,7 @@ def loss_fn(var_collect, disable_dropout=False):
return grads, loss, accuracy


-@partial(jax.jit, static_argnums=2)
+@partial(jax.jit)
def update_model(state, grads):
"""Update model params and FP8 meta."""
state = state.apply_gradients(grads=grads[PARAMS_KEY])
9 changes: 5 additions & 4 deletions qa/L0_jax_unittest/test.sh
@@ -5,14 +5,15 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed'
+
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'

pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

-pytest -Wignore -v $TE_PATH/examples/jax/mnist
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
-pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
2 changes: 1 addition & 1 deletion qa/L1_jax_distributed_unittest/test.sh
@@ -5,5 +5,5 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*

28 changes: 28 additions & 0 deletions tests/jax/pytest.ini
@@ -0,0 +1,28 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

[pytest]
filterwarnings=
    ignore:sharding_type of.*:DeprecationWarning
    ignore:major_sharding_type of.*:DeprecationWarning
    ignore:Fused attention is not enabled.*:UserWarning
    ignore:The hookimpl.*:DeprecationWarning
    ignore:xmap is an experimental feature and probably has bugs!
    ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
    ignore:can't resolve package from __spec__ or __package__:ImportWarning
    ignore:Using or importing the ABCs.*:DeprecationWarning
    ignore:numpy.ufunc size changed
    ignore:.*experimental feature
    ignore:The distutils.* is deprecated.*:DeprecationWarning
    ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
    ignore:ml_dtypes.float8_e4m3b11 is deprecated.
    ignore:np.find_common_type is deprecated.*:DeprecationWarning
    ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning
    ignore:The numpy.array_api submodule is still experimental.*:UserWarning
    ignore:case not machine-readable.*:UserWarning
    ignore:not machine-readable.*:UserWarning
    ignore:Special cases found for .* but none were parsed.*:UserWarning
    ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning
    ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
    ignore:The host_callback APIs are deprecated .*:DeprecationWarning
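For context (editor's note, not part of the commit): each entry above uses pytest's `action:message-regex:category` filter syntax, so only the known-noisy warnings listed here are suppressed, instead of the blanket `-Wignore` the QA scripts used before. The same filters could also be registered from a `conftest.py`; the sketch below is a hypothetical alternative shown only to illustrate the format, using pytest's `pytest_configure` hook and `Config.addinivalue_line`.

```python
# conftest.py -- hypothetical equivalent of a few tests/jax/pytest.ini entries.
# The commit itself uses the ini file; this hook is shown only for illustration.
def pytest_configure(config):
    for line in (
        "ignore:sharding_type of.*:DeprecationWarning",
        "ignore:Fused attention is not enabled.*:UserWarning",
    ):
        # Same "action:message regex:category" format as the filterwarnings option.
        config.addinivalue_line("filterwarnings", line)
```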
25 changes: 12 additions & 13 deletions transformer_engine/jax/cpp_extensions.py
@@ -385,8 +385,8 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
-0, # no dgamma_part in FWD pass
-0, # no dbeta_part in BWD pass
+(0,), # no dgamma_part in FWD pass
+(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -464,7 +464,6 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
f"Enforcing no sharding of parameters hidden dim! " \
)


x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
@@ -589,8 +588,8 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
-dgamma_part_aval.size,
-dbeta_part_aval.size,
+dgamma_part_aval.shape,
+dbeta_part_aval.shape,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -791,8 +790,8 @@ def lowering(ctx, x, gamma, *, epsilon):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
-0, # no dgamma_part in FWD pass
-0, # no dbeta_part in BWD pass
+(0,), # no dgamma_part in FWD pass
+(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -968,8 +967,8 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
-dgamma_part_aval.size,
-0, # no dbeta_part for RMSnorm
+dgamma_part_aval.shape,
+(0,), # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -3588,8 +3587,8 @@ def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_cen
hidden_size,
wkspace_aval.size,
barrier_aval.size,
-0, # no dgamma_part in FWD pass
-0, # no dbeta_part in BWD pass
+(0,), # no dgamma_part in FWD pass
+(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -3840,8 +3839,8 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
-0, # no dgamma_part in FWD pass
-0, # no dbeta_part in BWD pass
+(0,), # no dgamma_part in FWD pass
+(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
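Aside (editor's note, not part of the diff): the Python changes above stop collapsing the dgamma/dbeta partial-reduction workspaces to a flat element count (`.size`) and instead forward their full shapes (`.shape`), which is what allows multi-dimensional dgamma_part and dbeta_part buffers in the descriptor. A tiny illustration with a made-up workspace shape:

```python
import jax.numpy as jnp

# Hypothetical 2-D partial-reduction buffer for dgamma; the shape is invented for
# illustration, but the real abstract value (dgamma_part_aval) exposes the same
# .size and .shape attributes used in the lowerings above.
dgamma_part = jnp.zeros((4, 1024), jnp.float32)

print(dgamma_part.size)   # 4096 -- the old descriptor field; the 2-D layout is lost
print(dgamma_part.shape)  # (4, 1024) -- what the descriptor now receives
```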
60 changes: 36 additions & 24 deletions transformer_engine/jax/csrc/modules.cpp
@@ -71,17 +71,28 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
return PackOpaque(desc);
}

-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-size_t wkspace_size, size_t barrier_size,
-size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-DType x_dtype, DType w_dtype, DType wkspace_dtype,
-DType barrier_dtype, DType dgamma_part_dtype,
-DType dbeta_part_dtype, bool zero_centered_gamma,
-float eps, int sm_margin) {
-return PackOpaque(desc);
-return PackOpaque(CustomCallNormDescriptor{
-batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
-x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
-zero_centered_gamma, eps, sm_margin});
+pybind11::bytes PackCustomCallNormDescriptor(
+size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
+CustomCallNormDescriptor desc;
+desc.batch_size = batch_size;
+desc.hidden_size = hidden_size;
+desc.wkspace_size = wkspace_size;
+desc.barrier_size = barrier_size;
+desc.dgamma_part_shape.from_vector(dgamma_part_shape);
+desc.dbeta_part_shape.from_vector(dbeta_part_shape);
+desc.x_dtype = x_dtype;
+desc.w_dtype = w_dtype;
+desc.wkspace_dtype = wkspace_dtype;
+desc.barrier_dtype = barrier_dtype;
+desc.dgamma_part_dtype = dgamma_part_dtype;
+desc.dbeta_part_dtype = dbeta_part_dtype;
+desc.zero_centered_gamma = zero_centered_gamma;
+desc.eps = eps;
+desc.sm_margin = sm_margin;
+return PackOpaque(desc);
}

pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
@@ -529,7 +540,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
}

void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
-size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
+size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *ograd, void *workspace,
DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
@@ -563,14 +574,14 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto barrier_shape = std::vector<size_t>{barrier_size};
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
-auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
-auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
+auto dgamma_part_tensor =
+TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);

if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
-auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
-auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
+auto dbeta_part_tensor =
+TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);

layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
@@ -664,8 +675,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
-auto *dgamma_part_sizes = desc.dgamma_part_sizes;
-auto *dbeta_part_sizes = desc.dbeta_part_sizes;
+auto dgamma_part_shape = desc.dgamma_part_shape;
+auto dbeta_part_shape = desc.dbeta_part_shape;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
@@ -689,8 +700,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];

-LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
@@ -786,8 +797,9 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
-auto dgamma_part_sizes = desc.dgamma_part_sizes;
-size_t dbeta_part_sizes[2] = {0, 0};
+auto dgamma_part_shape = desc.dgamma_part_shape;
+Shape dbeta_part_shape;
+dbeta_part_shape.from_vector({0, 0});
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
@@ -797,8 +809,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;

-LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
16 changes: 7 additions & 9 deletions transformer_engine/jax/csrc/modules.h
@@ -69,8 +69,8 @@ struct CustomCallNormDescriptor {
size_t hidden_size;
size_t wkspace_size;
size_t barrier_size;
-size_t *dgamma_part_sizes; // 2D tensor
-size_t *dbeta_part_sizes; // 2D tensor
+Shape dgamma_part_shape;
+Shape dbeta_part_shape;
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
@@ -82,13 +82,11 @@
int sm_margin;
};

-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-size_t wkspace_size, size_t barrier_size,
-size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-DType x_dtype, DType w_dtype, DType wkspace_dtype,
-DType barrier_dtype, DType dgamma_part_dtype,
-DType dbeta_part_dtype, bool zero_centered_gamma,
-float eps, int sm_margin);
+pybind11::bytes PackCustomCallNormDescriptor(
+size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);

struct SoftmaxDescriptor {
size_t batch_size;