Make fused normalization functions backward-compatible #1760
Conversation
Hi @timmoon10, I just thought people might not be using the Function directly and forgot about Megatron. It might be best to submit another PR to Megatron-LM in tandem with this one, since I believe Megatron-DeepSpeed already has this feature (deepspeedai/Megatron-DeepSpeed#277) and it would be great if Megatron-LM had it as well.
@RuiWang1998 That's nifty, it'll be convenient to just reuse that existing work. These two approaches aren't mutually exclusive, so I don't think there is any harm in merging. This change won't break the newer code that uses `memory_efficient`.
Referenced in NeMo#7909:

* Add distopt support for FP8 params and BF16 optimizer state
* Removed unused import
* Update PyTorch container in Jenkins pipeline
* Use custom container with Apex bugfixes (see NVIDIA/apex#1760)
* Upgrade to PyTorch 23.11 container
* Update Apex commit

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
#1715 makes breaking API changes to some fused normalization functions, in particular adding `memory_efficient` as a positional argument. This PR makes `memory_efficient` a keyword argument to ensure backward compatibility.

This change is motivated by the fact that Megatron-LM uses the old API:
https://github.com/NVIDIA/Megatron-LM/blob/2bc6cd307a11423928c675f741e79e03df23e721/megatron/core/fusions/fused_layer_norm.py#L147
This prevents NeMo from upgrading from the 23.09 PyTorch container to the 23.11 container. See NVIDIA-NeMo/NeMo#7909 (comment).
Feedback would be appreciated. An alternative approach is to update Megatron-LM, but this seems simpler. Pinging @RuiWang1998.
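
For illustration, here is a minimal sketch of the backward-compatible pattern: give `memory_efficient` a default value so callers that still pass only the original positional arguments (as Megatron-LM does) keep working. The class below is a simplified stand-in, not the actual Apex source; the layer-norm math replaces the fused CUDA kernel, and the backward pass is omitted.

```python
import torch

class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Simplified stand-in for the fused kernel (forward only)."""

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False):
        # Default value keeps the pre-#1715 call signature (five positional args) working.
        mean = input.mean(dim=-1, keepdim=True)
        var = input.var(dim=-1, unbiased=False, keepdim=True)
        output = (input - mean) * torch.rsqrt(var + eps) * weight + bias
        # In the real implementation, memory_efficient controls whether the input is
        # saved for backward or recomputed from the output; this sketch only records it.
        ctx.memory_efficient = memory_efficient
        return output

x = torch.randn(2, 8)
w, b = torch.ones(8), torch.zeros(8)

# Old call (Megatron-LM style): five positional arguments, still valid.
y_old = FusedLayerNormAffineFunction.apply(x, w, b, (8,), 1e-5)

# New call: memory_efficient passed as the sixth positional argument
# (Function.apply generally takes positional arguments only).
y_new = FusedLayerNormAffineFunction.apply(x, w, b, (8,), 1e-5, True)
```

Because `apply` forwards its positional arguments straight to `forward`, adding the defaulted parameter at the end is enough to keep both call sites working without touching Megatron-LM.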