UBNEXT with optional add-rms fuse #2212
base: main
Conversation
#include "transformer_engine.h"

namespace transformer_engine {
These headers should be compatible with C, but namespace is C++. We should either make the functions C-compatible and rename them to nvte_allreduce_*, or we should commit to the C++ API and wrap everything within #ifdef __cplusplus.
This header is public-facing and should be documented. The options are very non-obvious.
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
skip_layernorm: bool = False,
What's the point of LayerNormLinear without LayerNorm?
params_dtype,
)
self.eps = eps
self.layer_norm_weight = ln_weight  # in general, expected to be filled later with a reference to layernorm_weight from the next LayerNormLinear
This will be impossible to reason about or maintain.
If I understand correctly, we are targeting the following compute pattern:
linear -> fork -> norm -> linear
            \-> residual branch
Your workflow (I presume using Mcore) implements this with a Linear and LayerNormLinear. You are trying to keep these modules, but sneakily move the norm computation in between the modules. This has a massive surface area for bugs (what if someone attempts to access the output of Linear under the assumption it's the same as the residual branch, what if they use LayerNorm instead of RMSNorm, what if they enable this sneaky interface but accidentally enable sequence parallelism). Even if it works for one narrow use-case, we also have to make it work for real users. Badly designed interfaces and uncaught edge cases will lead them down dead ends and waste their time.
A better approach would be to reorganize the modules. Instead of shoehorning this into linear -> (norm -> linear), it would be more natural to have (linear -> norm) -> linear. Instead of reimplementing another fused module like LayerNormLinear, this is a good use-case for the op fusion API:
proj = te.ops.Sequential(
    te.ops.Linear(...),
    te.ops.MakeExtraOutput(),
    te.ops.RMSNorm(...),
)
fc1 = te.ops.Sequential(
    te.ops.Linear(...),
)
x, residual = proj(x)
x = fc1(x)
For an example, see this Mcore module that fuses the MLP block: https://github.com/NVIDIA/Megatron-LM/blob/275854c92d734e2e373b154ce267fff1d2f1e232/megatron/core/extensions/transformer_engine.py#L1491
Description
Added UBnext fast all-reduce kernels to the linear layer.
They fall under symmetric_ar_type, with the new types being 'ubnext' and 'ubnext_add_rms'.
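For illustration, a minimal usage sketch, assuming the new values are selected through the symmetric_ar_type argument shown in the diff above (exact constructor arguments may differ from this PR):

import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes torch.distributed is already initialized with one GPU per rank.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# Hypothetical sketch: a row-parallel linear whose output all-reduce uses the
# new UBnext kernels; "ubnext_add_rms" would additionally fuse ADD+RMSNorm
# into the all-reduce. Argument names are assumed, not confirmed by this PR.
proj = te.Linear(
    4096,
    4096,
    parallel_mode="row",
    tp_group=tp_group,
    tp_size=dist.get_world_size(),
    symmetric_ar_type="ubnext",  # or "ubnext_add_rms"
)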
Details
Added NVLS simple and low-latency (Lamport) all-reduce kernels, which can optionally fuse ADD+RMS in the middle of the all-reduce.
Added a symmetric allocator that uses PyTorch symmetric memory to allocate a pool and sub-allocate from it.
Since PyTorch symmetric memory does not support MNNVL yet, there is a fallback to the legacy UB code that creates an 11th CommOverlap object. This is enabled with the env var NVTE_USE_UB_FOR_UBNEXT (requires the user to initialize UB by calling initialize_ub).
NVTE_UB_MAXBATCH (default 64) can be raised so the pool has enough memory for the fastest kernel at larger batch sizes. If memory cannot be allocated, there is a gradual fallback: first to the UBmain in-place kernel if the input could be allocated but the output could not, and then to PyTorch symmetric memory if the input could not be allocated.
The NVTE_UB_SYMM_POOL_SIZE env var overrides the pool size to the given number of megabytes.
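For illustration only, a sketch of how these knobs might be set from Python before Transformer Engine is initialized; the variable names come from the description above, and the values are arbitrary:

import os

# Hypothetical configuration sketch; values are illustrative, not defaults.
os.environ["NVTE_USE_UB_FOR_UBNEXT"] = "1"     # fall back to legacy UB code (requires calling initialize_ub)
os.environ["NVTE_UB_MAXBATCH"] = "128"         # size the pool for larger batches (default 64)
os.environ["NVTE_UB_SYMM_POOL_SIZE"] = "2048"  # override the symmetric pool size, in megabytes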