diff --git a/README.md b/README.md
index e8d598626..a0db1f404 100644
--- a/README.md
+++ b/README.md
@@ -270,15 +270,7 @@ We currently enable training and evaluation for the following models:
 We will update this table as new models become available, so stay tuned.
 
 ## Environment Variables
-
-The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning:
-
-| XLA Flags | Value | Explanation |
-| --------- | ----- | ----------- |
-| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
-| `--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL [AllGather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#allgather) kernels on a separate CUDA stream to allow overlap with compute kernels |
-| `--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL [ReduceScatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#reducescatter) kernels on a separate CUDA stream to allow overlap with compute kernels |
-| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Trition GeMM kernels |
+The [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) are embedded with the following environment variables and XLA flags for performance tuning:
 
 | Environment Variable | Value | Explanation |
 | -------------------- | ----- | ----------- |
@@ -286,6 +278,21 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 | `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
 | `CUDA_MODULE_LOADING` | `EAGER` | Disables lazy-loading ([1](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables)) which uses slightly more GPU memory. |
 
+XLA flags that tune performance are also set by default in the [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax). To view
+the flags currently set, you can inspect the container's environment variables:
+```sh
+# Update IMAGE to inspect a container of your choosing
+IMAGE=ghcr.io/nvidia/jax:jax
+
+docker run --rm quay.io/skopeo/stable inspect docker://$IMAGE | jq -r '.Env[]' | grep '^XLA_FLAGS='
+
+# which returns
+
+XLA_FLAGS= --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false
+```
+
+See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these and other XLA flags that enable high performance for LLMs on NVIDIA GPUs.
+
 ## Profiling JAX programs on GPU
 
 See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
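A minimal usage sketch to accompany the new section (not part of the diff itself): because the defaults live in the `XLA_FLAGS` environment variable of the image, passing `-e XLA_FLAGS=...` to `docker run` replaces the baked-in value wholesale rather than appending to it. The image tag, flag names, and `NCCL_NVLS_ENABLE` value below come from the diff above; the `--gpus all` option and the `python -c` probe are assumptions made only for illustration.

```sh
# Sketch: launch the container with a modified set of XLA flags.
# NOTE: -e replaces the image's default XLA_FLAGS entirely, so restate every
# flag you still want rather than trying to append to the default.
IMAGE=ghcr.io/nvidia/jax:jax   # same image referenced in the diff above

docker run --rm --gpus all \
  -e XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false" \
  -e NCCL_NVLS_ENABLE=0 \
  $IMAGE \
  python -c "import jax; print(jax.devices())"   # assumed probe; any JAX entry point works
```

Values supplied with `-e` take precedence over the ones baked into the image, which is standard Docker behavior, so the `skopeo`/`jq` inspection command shown in the diff reports the image defaults rather than what a particular `docker run` invocation will use.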