Conversation

nv-akorzh

Description

Added UBnext fast Allreduce kernels to the linear layer.
They fall under symmetric_ar_type, with the new types being 'ubnext' and 'ubnext_add_rms'.

# Details

Added NVLS allreduce kernels, a simple variant and a low-latency (Lamport) variant, which can optionally fuse ADD+RMS in the middle of the allreduce.

Added a symmetric allocator which uses PyTorch symmetric memory to allocate a pool and suballocate from it.

Since PyTorch symmetric memory doesn't support MNNVL yet, there is a fallback that uses the legacy UB code by creating an 11th CommOverlap object. It is enabled with the env var NVTE_USE_UB_FOR_UBNEXT (and requires the user to initialize UB by calling initialize_ub).
NVTE_UB_MAXBATCH (default 64) can increase the batch size for which enough memory is reserved to run the fastest kernel. If the memory can't be allocated there is a gradual fallback: first to the UB main in-place kernel if the input could be allocated but the output couldn't, and to PyTorch symmetric memory if the input couldn't be allocated.
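
A minimal sketch of the legacy-UB fallback path, assuming the existing te.pytorch.initialize_ub entry point and a job that has already initialized torch.distributed; the shapes, dtype, and batch value below are only illustrative:

```python
# Hedged sketch: opting into the legacy-UB fallback for UBnext (e.g. on MNNVL).
# NVTE_USE_UB_FOR_UBNEXT and NVTE_UB_MAXBATCH are the env vars described above;
# the shapes, dtype, and sizes are illustrative assumptions.
import os
import torch
import transformer_engine.pytorch as te

os.environ["NVTE_USE_UB_FOR_UBNEXT"] = "1"  # route UBnext through legacy UB buffers
os.environ["NVTE_UB_MAXBATCH"] = "128"      # reserve memory for batches up to 128 (default 64)

# Legacy UB requires explicit initialization before building the modules.
# Assumes torch.distributed has already been initialized with a NCCL backend.
tp_size = torch.distributed.get_world_size()
seq_len, batch_size, hidden_size = 4096, 8, 8192
te.initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
)
```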

The NVTE_UB_SYMM_POOL_SIZE env var overrides the pool size to the given number of megabytes.
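
For completeness, a hedged usage sketch: the symmetric_ar_type values 'ubnext' and 'ubnext_add_rms' come from this PR, while the layer sizes, process-group setup, and pool-size override value are assumptions for illustration:

```python
# Hedged sketch: enabling the new UBnext allreduce on a tensor-parallel Linear.
# Only symmetric_ar_type="ubnext"/"ubnext_add_rms" is specific to this PR;
# everything else (sizes, groups, pool size) is illustrative.
import os
import torch
import transformer_engine.pytorch as te

os.environ["NVTE_UB_SYMM_POOL_SIZE"] = "2048"  # override the symmetric pool size, in MB

torch.distributed.init_process_group(backend="nccl")
tp_group = torch.distributed.new_group()  # tensor-parallel group

proj = te.Linear(
    8192,
    8192,
    tp_group=tp_group,
    tp_size=torch.distributed.get_world_size(tp_group),
    parallel_mode="row",         # row-parallel GEMM ends in an allreduce of the output
    symmetric_ar_type="ubnext",  # or "ubnext_add_rms" to fuse ADD+RMS into the allreduce
)
```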


#include "transformer_engine.h"

namespace transformer_engine {
Collaborator

These headers should be compatible with C, but namespace is a C++ construct. We should either make the functions C-compatible and rename them to nvte_allreduce_*, or we should commit to the C++ API and wrap everything within #ifdef __cplusplus.

Collaborator

This header is public-facing and should be documented. The options are very non-obvious.

ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
skip_layernorm: bool = False,
Collaborator

What's the point of LayerNormLinear without LayerNorm?

params_dtype,
)
self.eps = eps
self.layer_norm_weight = ln_weight  # in general, expected to be filled later with a reference to layernorm_weight from the next LayerNormLinear
Collaborator

This will be impossible to reason about or maintain.

If I understand correctly, we are targeting the following compute pattern:

linear -> fork -> norm -> linear
             \> residual branch

Your workflow (I presume using Mcore) implements this with a Linear and a LayerNormLinear. You are trying to keep these modules but sneakily move the norm computation in between them. This has a massive surface area for bugs: what if someone accesses the output of Linear assuming it is the same as the residual branch, what if they use LayerNorm instead of RMSNorm, what if they enable this sneaky interface but accidentally enable sequence parallelism? Even if it works for one narrow use-case, we also have to make it work for real users. Badly designed interfaces and uncaught edge cases will lead them down dead ends and waste their time.

A better approach would be to reorganize the modules. Instead of shoehorning this into linear -> (norm -> linear), it would be more natural to have (linear -> norm) -> linear. Instead of reimplementing another fused module like LayerNormLinear, this is a good use-case for the op fusion API:

proj = te.ops.Sequential(
  te.ops.Linear(...),
  te.ops.MakeExtraOutput(),
  te.ops.RMSNorm(...),
)
fc1 = te.ops.Sequential(
  te.ops.Linear(...),
)

x, residual = proj(x)
x = fc1(x)

For an example, see this Mcore module that fuses the MLP block: https://github.com/NVIDIA/Megatron-LM/blob/275854c92d734e2e373b154ce267fff1d2f1e232/megatron/core/extensions/transformer_engine.py#L1491
