From beedbb5834426ed63bdf4de83d705ff08a43dfa3 Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Mon, 29 Sep 2025 14:27:14 -0700 Subject: [PATCH 1/3] ubnext allreduce and allreduce fused with add and rmsnorm for low latency inference --- .../distributed/test_fused_linear_comms.py | 244 +++ .../pytorch/distributed/test_linear_comms.py | 352 ++++ transformer_engine/common/CMakeLists.txt.orig | 298 +++ transformer_engine/common/CMakeLists.txt.rej | 12 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 11 +- .../comm_gemm_overlap.cpp.orig | 1210 +++++++++++ .../comm_gemm_overlap.h.orig | 327 +++ .../comm_gemm_overlap.h.rej | 11 + .../include/transformer_engine/ubnext.h | 31 + .../common/libtransformer_engine.version | 3 +- transformer_engine/common/ubnext.cu | 608 ++++++ .../common/util/pybind_helper.h | 33 +- .../common/util/pybind_helper.h.orig | 140 ++ .../pytorch/cpp_extensions/__init__.py | 1 + .../pytorch/cpp_extensions/symm_allocator.py | 389 ++++ .../csrc/extensions/comm_gemm_overlap.cpp | 8 +- .../extensions/comm_gemm_overlap.cpp.orig | 320 +++ transformer_engine/pytorch/module/base.py | 11 +- .../pytorch/module/base.py.orig | 1597 ++++++++++++++ .../pytorch/module/layernorm_linear.py | 76 +- .../pytorch/module/layernorm_linear.py.orig | 1827 +++++++++++++++++ transformer_engine/pytorch/module/linear.py | 48 +- .../pytorch/module/linear.py.orig | 1710 +++++++++++++++ 23 files changed, 9238 insertions(+), 29 deletions(-) create mode 100644 tests/pytorch/distributed/test_fused_linear_comms.py create mode 100644 tests/pytorch/distributed/test_linear_comms.py create mode 100644 transformer_engine/common/CMakeLists.txt.orig create mode 100644 transformer_engine/common/CMakeLists.txt.rej create mode 100644 transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig create mode 100644 transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig create mode 100644 transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej create mode 100644 transformer_engine/common/include/transformer_engine/ubnext.h create mode 100644 transformer_engine/common/ubnext.cu create mode 100644 transformer_engine/common/util/pybind_helper.h.orig create mode 100644 transformer_engine/pytorch/cpp_extensions/symm_allocator.py create mode 100644 transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp.orig create mode 100644 transformer_engine/pytorch/module/base.py.orig create mode 100644 transformer_engine/pytorch/module/layernorm_linear.py.orig create mode 100644 transformer_engine/pytorch/module/linear.py.orig diff --git a/tests/pytorch/distributed/test_fused_linear_comms.py b/tests/pytorch/distributed/test_fused_linear_comms.py new file mode 100644 index 0000000000..7b230ea4f9 --- /dev/null +++ b/tests/pytorch/distributed/test_fused_linear_comms.py @@ -0,0 +1,244 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common import recipe
+import torch.distributed._symmetric_memory as symm_mem
+import time
+import argparse
+import os
+import uuid
+import math
+
+
+def main():
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(
+        description=(
+            "Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism"
+        )
+    )
+    parser.add_argument("--hidden_size", type=int, default=8192, help="Hidden size")
+    parser.add_argument("--batch_size", type=int, default=2048, help="Batch size")
+    parser.add_argument("--fc_factor", type=int, default=4, help="MLP hidden layer factor")
+    parser.add_argument(
+        "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)"
+    )
+    parser.add_argument("--validate", action="store_true", help="Validate ubnext allreduce")
+    parser.add_argument("--comm_type", type=str, default="sym", help="Comm type: none,nccl,sym,ub,ubnext,ubnext_add,ubnext_add_rms")
+    parser.add_argument(
+        "--sym_type",
+        type=str,
+        default="multimem_all_reduce",
+        help="PyTorch sym type: one_shot, two_shot, multimem_all_reduce",
+    )
+    parser.add_argument("--iterations", type=int, default=1000, help="Number of iterations")
+    parser.add_argument(
+        "--tp_size",
+        type=int,
+        default=None,
+        help="Tensor parallelism size (defaults to number of GPUs)",
+    )
+    parser.add_argument("--eps", type=float, default=1e-5, help="Epsilon")
+    args = parser.parse_args()
+
+    # Check CUDA availability and get device count
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available. Test requires NVIDIA GPUs.")
+
+    num_devices = torch.cuda.device_count()
+    if num_devices == 0:
+        raise RuntimeError("No CUDA devices found.")
+
+    # Set tensor parallelism size
+    tp_size = (
+        args.tp_size if args.tp_size is not None else int(os.environ.get("WORLD_SIZE", num_devices))
+    )
+
+    # Initialize distributed environment for each GPU
+    myrank = int(os.environ.get("RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
+    num_nodes = world_size // local_size
+    if num_nodes > 1:
+        assert (
+            "MASTER_ADDR" in os.environ
+        ), "Multi-node run requires MASTER_ADDR to be set in the environment."
+ # Set device + device = torch.device(f"cuda:{local_rank}") + # Initialize torch.distributed for tensor parallelism + # Only set defaults if not already set by torchrun + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + torch.cuda.set_device(device) + + torch.distributed.init_process_group( + backend="nccl", world_size=tp_size, rank=myrank % tp_size, device_id=device + ) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + # Transformer Engine handles tensor parallelism internally when distributed is initialized + # Set environment variable for tensor parallelism size + os.environ["NVTE_TP_SIZE"] = str(tp_size) + + ub_cfgs = { + "proj_fprop": { + "method": "pipeline", + "num_splits": 4, + "is_reduce_scatter": True, + "num_sm": 32, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 4, + "set_sm_margin": True, + "fp8_buf": False, + "use_ce": False, + }, + "fc1_fprop": { + "method": "ring_exchange", + "num_splits": 1, + "is_reduce_scatter": False, + "num_sm": 1, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 1, + "set_sm_margin": False, + "fp8_buf": False, + "use_ce": True, + }, + } + + # Initialize model with BF16 precision + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT") or args.comm_type == "ub": + te.module.base.initialize_ub( + [args.batch_size, args.hidden_size], + tp_size, + use_fp8=False, + dtype=torch.bfloat16, + bootstrap_backend="nccl", + ub_cfgs=ub_cfgs, + ) + + proj = te.Linear( + in_features=args.hidden_size // tp_size if args.comm_type == "none" else args.hidden_size, + out_features=args.hidden_size, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size=tp_size if args.comm_type != "none" else 1, + parallel_mode="row" if args.comm_type != "none" else None, + tp_group=torch.distributed.group.WORLD if args.comm_type != "none" else None, + symmetric_ar_type=args.sym_type if args.comm_type == "sym" else args.comm_type, + sequence_parallel=args.comm_type == "ub", + ub_overlap_rs=args.comm_type == "ub", + ub_name="proj" if args.comm_type == "ub" else None, + eps=args.eps if args.comm_type == "ubnext_add_rms" else None, + ) + + fc1 = te.LayerNormLinear( + in_features=args.hidden_size, + out_features=args.hidden_size*args.fc_factor//tp_size if args.comm_type == "none" else args.hidden_size*args.fc_factor, + bias=False, + device=device, + params_dtype=torch.bfloat16, + eps=args.eps, + zero_centered_gamma=False, + normalization="RMSNorm", + tp_size=tp_size if args.comm_type != "none" else 1, + parallel_mode="column" if args.comm_type != "none" else None, + tp_group=torch.distributed.group.WORLD if args.comm_type != "none" else None, + skip_layernorm=args.comm_type == "ubnext_add_rms", + sequence_parallel=args.comm_type == "ub", + ub_overlap_ag=args.comm_type == "ub", + ub_name="fc1" if args.comm_type == "ub" else None, + ) + + if args.comm_type == "ubnext_add_rms": + proj.layer_norm_weight = fc1.layer_norm_weight.data + # Create CUDA stream + stream = torch.cuda.Stream() + # Check for environment variable to override pool size + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + for logbatch in range(int(math.log2(args.batch_size)) + 1): + batch = 2**logbatch + if args.comm_type == "ub":# and batch < tp_size: + batch = args.batch_size#tp_size + # Create input tensor + torch.cuda.synchronize() + torch.distributed.barrier(group=torch.distributed.group.WORLD) + residual = 
torch.randn(batch//tp_size if args.comm_type == "ub" else batch, args.hidden_size, device=device, dtype=torch.bfloat16) + inp = torch.randn(batch, args.hidden_size // tp_size, device=device, dtype=torch.bfloat16) + + # Warm-up run + if not args.comm_type.startswith("ubnext_add"): + out_proj=proj(inp) + out_proj.add_(residual) + out=fc1(out_proj) + else: + out=fc1(proj(inp,residual=residual)) # this also allocates distributed internal residual + + torch.cuda.synchronize() + if args.cuda_graph: + with torch.cuda.stream(stream): + # Create CUDA Graph + graph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(graph): + if not args.comm_type.startswith("ubnext_add"): + out_proj=proj(inp) + out_proj.add_(residual) + out=fc1(out_proj) + else: + out=fc1(proj(inp)) + + # Warm-up the graph + for _ in range(5): + graph.replay() + + torch.cuda.synchronize() + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + graph.replay() + else: + if not args.comm_type.startswith("ubnext_add"): + out_proj=proj(inp) + out_proj.add_(residual) + out=fc1(out_proj) + else: + out=fc1(proj(inp)) + + torch.cuda.synchronize() + end_time = time.time() + elapsed = end_time - start_time + + # Calculate and print elapsed time (only on rank 0) + if myrank == 0: + print( + f"Batch{batch},{(elapsed/ args.iterations) * 1e6:.4f}" + ) + if args.cuda_graph: + # needed or NCCL would hang + del graph + + # Cleanup + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + # Generate a unique run ID for distributed initialization + os.environ["RUN_ID"] = str(uuid.uuid4()) + main() diff --git a/tests/pytorch/distributed/test_linear_comms.py b/tests/pytorch/distributed/test_linear_comms.py new file mode 100644 index 0000000000..2e7e9c0bf6 --- /dev/null +++ b/tests/pytorch/distributed/test_linear_comms.py @@ -0,0 +1,352 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import torch +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +import torch.distributed._symmetric_memory as symm_mem +import time +import argparse +import os +import uuid +import math + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser( + description=( + "Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism" + ) + ) + parser.add_argument("--in_features", type=int, default=8192, help="Input feature size") + parser.add_argument("--out_features", type=int, default=8192, help="Output feature size") + parser.add_argument("--batch_size", type=int, default=2048, help="Batch size") + parser.add_argument( + "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)" + ) + parser.add_argument("--validate", action="store_true", help="Validate allreduce ubnext") + parser.add_argument("--comm_type", type=str, default="sym", help="Comm type: nccl,sym,ub") + parser.add_argument( + "--sym_type", + type=str, + default="multimem_all_reduce", + help="sym type: one_shot, two_shot, multimem_all_reduce, ubnext", + ) + parser.add_argument("--iterations", type=int, default=1000, help="Number of iterations") + parser.add_argument( + "--tp_size", + type=int, + default=None, + help="Tensor parallelism size (defaults to number of GPUs)", + ) + parser.add_argument("--eps", type=float, default=1e-5, help="Epsilon") + parser.add_argument("--rmsnorm", action="store_true", help="Use RMSNorm") + args = parser.parse_args() + + # Check CUDA availability and get device count + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Test requires NVIDIA GPUs.") + + num_devices = torch.cuda.device_count() + if num_devices == 0: + raise RuntimeError("No CUDA devices found.") + + # Set tensor parallelism size + tp_size = ( + args.tp_size if args.tp_size is not None else int(os.environ.get("WORLD_SIZE", num_devices)) + ) + + # Initialize distributed environment for each GPU + myrank = int(os.environ.get("RANK", 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + num_nodes = world_size // local_size + if num_nodes > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." 
+ # Set device + device = torch.device(f"cuda:{local_rank}") + # Initialize torch.distributed for tensor parallelism + # Only set defaults if not already set by torchrun + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + torch.cuda.set_device(device) + + torch.distributed.init_process_group( + backend="nccl", world_size=tp_size, rank=myrank % tp_size, device_id=device + ) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + # Transformer Engine handles tensor parallelism internally when distributed is initialized + # Set environment variable for tensor parallelism size + os.environ["NVTE_TP_SIZE"] = str(tp_size) + + ub_cfgs = { + "proj_fprop": { + "method": "pipeline", + "num_splits": 1, + "is_reduce_scatter": True, + "num_sm": 32, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 4, + "set_sm_margin": False, + "fp8_buf": False, + "use_ce": False, + } + } + + # Initialize model with BF16 precision + + modelseq = te.Linear( + in_features=int(args.in_features / tp_size), + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + ) + + modelnorm = te.RMSNorm( + normalized_shape=int(args.out_features), + eps=args.eps, + device=device, + dtype=torch.bfloat16, + zero_centered_gamma=False, + ) + residual = torch.randn(args.batch_size, args.out_features, device=device, dtype=torch.bfloat16) if args.rmsnorm else None + + ln_weight = modelnorm.weight.data if args.rmsnorm else None + if ( + args.comm_type == "sym" and os.environ.get("NVTE_USE_UB_FOR_UBNEXT") + ) or args.comm_type == "ub": + te.module.base.initialize_ub( + [args.batch_size, args.out_features], + tp_size, + use_fp8=False, + dtype=torch.bfloat16, + bootstrap_backend="nccl", + ub_cfgs=ub_cfgs, + ) + + modelpar = None + + if args.comm_type == "sym" or args.comm_type == "nccl": + modelpar = te.Linear( + in_features=args.in_features, + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size=tp_size, + parallel_mode="row", + tp_group=torch.distributed.group.WORLD, + symmetric_ar_type=args.sym_type if args.comm_type == "sym" else None, + eps=args.eps, + ln_weight=ln_weight, + ) + + if args.comm_type == "ub": + modelpar = te.Linear( + in_features=args.in_features, + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size=tp_size, + parallel_mode="row", + tp_group=torch.distributed.group.WORLD, + sequence_parallel=True, + ub_overlap_rs=True, + ub_name="proj", + eps=args.eps, + ln_weight=ln_weight, + ) + + # Create CUDA stream + stream = torch.cuda.Stream() + # Check for environment variable to override pool size + + allocator = None + if args.comm_type == "sym" and args.validate: + pool_size = int(os.environ.get("NVTE_UB_SYMM_POOL_SIZE", 64)) * 1024 * 1024 + allocator = te.cpp_extensions.symm_allocator.SymmAllocator( + pool_size, torch.device(device), torch.distributed.group.WORLD + ) + + # Run tensor comparison tests only for symmetric communication + if args.comm_type == "sym" and args.validate: + + if args.rmsnorm: + torch.manual_seed(57) + torch.cuda.manual_seed(57) + residual = torch.randn(1, args.out_features, dtype=torch.bfloat16, device=device) + t = allocator.create_tensor((1,args.out_features,), dtype=torch.bfloat16) + #te.cpp_extensions.symm_allocator.ubsymm_free_residual(t) + t.fill_(myrank) + t_in = t.clone() + torch.distributed.all_reduce(t_in) + t_in.add_(residual) + 
out1=modelnorm(t_in) + out2 = allocator.allreduce_simple(t,hidden_size=args.out_features,residual_in=residual,residual_out=residual,fuse_layernorm=True,eps=args.eps,gamma=modelnorm.weight.data) + abs_diff = torch.abs(out1 - out2) + max_delta = torch.max(abs_diff).item() + num_different = torch.sum(out1 != out2).item() + print(f"FUSED RMSNorm Max delta: {max_delta}, num different: {num_different}") + if(myrank== 0): + print(f"gamma: {modelnorm.weight.data}") + print(f"FUSED RMSNorm output: {out1}") + print(f"FUSED RMSNorm output: {out2}") + + # Test different tensor sizes from 64 to 1024*1024 elements + all_max_deltas = [] + all_num_different = [] + all_total_elements = [] + all_sizes = [] + + size = 64 + while size <= 1024 * 1024: + # Allocate tensors + t = allocator.create_tensor((size,), dtype=torch.bfloat16) + t.fill_(0) + t += torch.randn_like(t) # Add random noise to each element + tmain = t.clone() # Create a copy since allreduce operates in-place + torch.distributed.all_reduce(tmain) + tlamport = allocator.allreduce_lamport(t) + + # Compare the two tensors + abs_diff = torch.abs(tlamport - tmain) + max_delta = torch.max(abs_diff).item() + num_different = torch.sum(tlamport != tmain).item() + + # Store statistics + all_max_deltas.append(max_delta) + all_num_different.append(num_different) + all_total_elements.append(tlamport.numel()) + all_sizes.append(size) + + # Free tensor (memory returned to pool) + del t, tlamport, tmain, abs_diff + + # Double the size for next iteration + size *= 2 + + # Print summary statistics + if myrank == 0: + print("\n=== Tensor Comparison Summary ===") + total_elements_tested = sum(all_total_elements) + total_different_elements = sum(all_num_different) + overall_max_delta = max(all_max_deltas) + + print( + f"Tested sizes: {len(all_sizes)} different tensor sizes from {all_sizes[0]} to" + f" {all_sizes[-1]} elements" + ) + print(f"Total elements tested: {total_elements_tested}") + print(f"Total different elements: {total_different_elements}") + print( + "Overall error rate:" + f" {(total_different_elements / total_elements_tested) * 100:.6f}%" + ) + print(f"Maximum delta across all tests: {overall_max_delta}") + + if total_different_elements > 0 or overall_max_delta > 0: + print("\nPer-size breakdown:") + for i, size in enumerate(all_sizes): + error_rate = (all_num_different[i] / all_total_elements[i]) * 100 + print( + f" Size {size:7d}:" + f" {all_num_different[i]:6d}/{all_total_elements[i]:7d} different" + f" ({error_rate:6.3f}%), max_delta: {all_max_deltas[i]:.6f}" + ) + print("================================\n") + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + for logbatch in range(int(math.log2(args.batch_size)) + 1): + batch = 2**logbatch + if args.comm_type == "ub" and batch < tp_size: + batch = tp_size + # Create input tensor + inp = torch.randn( + batch, int(args.in_features / tp_size), device=device, dtype=torch.bfloat16 + ) + # Warm-up run + out=modelseq(inp) + modelnorm(out) + modelpar(inp,residual=residual) + torch.cuda.synchronize() + if args.cuda_graph: + with torch.cuda.stream(stream): + # Create CUDA Graph + gseq = torch.cuda.CUDAGraph() + gpar = torch.cuda.CUDAGraph() + with torch.cuda.graph(gseq): + output = modelseq(inp) + if args.rmsnorm: + output.add_(residual[:batch,:args.out_features]) + output=modelnorm(output) + with torch.cuda.graph(gpar): + output = modelpar(inp,residual=residual) + # Warm-up the graph + for _ in range(5): + gseq.replay() + gpar.replay() + 
torch.cuda.synchronize() + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + gseq.replay() + else: + modelseq(inp) + + torch.cuda.synchronize() + end_time = time.time() + seq_elapsed = end_time - start_time + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + gpar.replay() + else: + modelpar(inp) + + torch.cuda.synchronize() + end_time = time.time() + par_elapsed = end_time - start_time + nccl_elapsed = par_elapsed - seq_elapsed + # Calculate and print elapsed time (only on rank 0) + if myrank == 0: + print( + f"Batch{batch},{(seq_elapsed/ args.iterations) * 1e6:.4f}us,{(par_elapsed/ args.iterations) * 1e6:.4f} us,{(nccl_elapsed/ args.iterations) * 1e6:.4f}" + ) + if args.cuda_graph: + # needed or NCCL would hang + del gseq, gpar + + # Cleanup + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + # Generate a unique run ID for distributed initialization + os.environ["RUN_ID"] = str(uuid.uuid4()) + main() diff --git a/transformer_engine/common/CMakeLists.txt.orig b/transformer_engine/common/CMakeLists.txt.orig new file mode 100644 index 0000000000..a4915080e8 --- /dev/null +++ b/transformer_engine/common/CMakeLists.txt.orig @@ -0,0 +1,298 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +cmake_minimum_required(VERSION 3.21) + +# Language options +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + endif() +endif() +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") +endif() + +# Hide non-necessary symbols in shared object. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") + +# Transformer Engine library +project(transformer_engine LANGUAGES CUDA CXX) + +# CUDA Toolkit +find_package(CUDAToolkit REQUIRED) +if (CUDAToolkit_VERSION VERSION_LESS 12.0) + message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") +endif() + +# cuDNN frontend API +set(CUDNN_FRONTEND_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") +if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. 
" + "Try running 'git submodule update --init --recursive' " + "within the Transformer Engine source.") +endif() +include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) + +set(CUTLASS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") +set(CUTLASS_TOOLS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") + +# Python +find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) + +# NVIDIA MathDX include directory (from Python package install location) +if(NOT DEFINED MATHDX_INCLUDE_DIR) + execute_process( + COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx + OUTPUT_VARIABLE _PIP_SHOW_MATHDX + ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR + RESULT_VARIABLE _PIP_SHOW_MATHDX_RES + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) + message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") + endif() + string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") + if(NOT _MATHDX_LOC_MATCH) + message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") + endif() + set(MATHDX_LOCATION "${CMAKE_MATCH_1}") + set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") +endif() +if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") + message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") +endif() + +# Configure Transformer Engine library +include_directories(${PROJECT_SOURCE_DIR}/..) +set(transformer_engine_SOURCES) +list(APPEND transformer_engine_SOURCES + cudnn_utils.cpp + transformer_engine.cpp + common.cu + multi_tensor/adam.cu + multi_tensor/compute_scale.cu + multi_tensor/l2norm.cu + multi_tensor/scale.cu + multi_tensor/sgd.cu + transpose/cast_transpose.cu + transpose/transpose.cu + transpose/cast_transpose_fusion.cu + transpose/transpose_fusion.cu + transpose/multi_cast_transpose.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise.cu + transpose/swap_first_dims.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu + activation/gelu.cu + dropout/dropout.cu + fused_attn/flash_attn.cu + fused_attn/context_parallel.cu + fused_attn/kv_cache.cu + fused_attn/fused_attn_f16_max512_seqlen.cu + fused_attn/fused_attn_f16_arbitrary_seqlen.cu + activation/relu.cu + activation/swiglu.cu + fused_attn/fused_attn_fp8.cu + fused_attn/fused_attn.cpp + fused_attn/utils.cu + gemm/config.cpp + gemm/cublaslt_gemm.cu + gemm/cutlass_grouped_gemm.cu + normalization/common.cpp + normalization/layernorm/ln_api.cpp + normalization/layernorm/ln_bwd_semi_cuda_kernel.cu + normalization/layernorm/ln_fwd_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_api.cpp + normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu + permutation/permutation.cu + util/cast.cu + util/padding.cu + util/cuda_driver.cpp + util/cuda_nvml.cpp + util/cuda_runtime.cpp + util/multi_stream.cpp + util/rtc.cpp + swizzle/swizzle.cu + fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu + fused_rope/fused_rope.cu + fused_router/fused_moe_aux_loss.cu + fused_router/fused_score_for_moe_aux_loss.cu + fused_router/fused_topk_with_score_function.cu + recipe/current_scaling.cu + recipe/delayed_scaling.cu + 
recipe/fp8_block_scaling.cu + recipe/nvfp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) + +if (NVTE_WITH_CUBLASMP) +list(APPEND transformer_engine_SOURCES + comm_gemm/comm_gemm.cpp) +endif() + +add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) +target_include_directories(transformer_engine PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include") + +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "gemm/cutlass_grouped_gemm.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") +else() + message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") +endif() + +# Configure dependencies +target_link_libraries(transformer_engine PUBLIC + CUDA::cublas + CUDA::cudart + CUDNN::cudnn_all) + +target_include_directories(transformer_engine PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) +target_include_directories(transformer_engine SYSTEM PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) +target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_INCLUDE_DIR}) + +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) +endif() + +option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF) +if (NVTE_ENABLE_NVSHMEM) + add_subdirectory(nvshmem_api) + target_link_libraries(transformer_engine PUBLIC nvshmemapi) + target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) +endif() + +option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) +if (NVTE_WITH_CUBLASMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + find_library(CUBLASMP_LIB + NAMES cublasmp libcublasmp + PATHS ${CUBLASMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + find_library(NVSHMEM_HOST_LIB + NAMES nvshmem_host libnvshmem_host.so.3 + PATHS ${NVSHMEM_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") + message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") +endif() + +# Hack to enable dynamic loading in cuDNN frontend +target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) + +# Helper functions to make header files with C++ strings +function(make_string_header STRING STRING_NAME) + configure_file(util/string_header.h.in + "string_headers/${STRING_NAME}.h" + @ONLY) +endfunction() +function(make_string_header_from_file file_ STRING_NAME) + file(READ "${file_}" STRING) + configure_file(util/string_header.h.in + "string_headers/${STRING_NAME}.h" + @ONLY) +endfunction() + +# Header files with C++ strings +list(GET 
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path) +make_string_header("${cuda_include_path}" + string_path_cuda_include) +make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu + string_code_transpose_rtc_cast_transpose_fusion_cu) +make_string_header_from_file(transpose/rtc/cast_transpose.cu + string_code_transpose_rtc_cast_transpose_cu) +make_string_header_from_file(transpose/rtc/transpose.cu + string_code_transpose_rtc_transpose_cu) +make_string_header_from_file(transpose/rtc/swap_first_dims.cu + string_code_transpose_rtc_swap_first_dims_cu) +make_string_header_from_file(utils.cuh + string_code_utils_cuh) +make_string_header_from_file(util/math.h + string_code_util_math_h) +target_include_directories(transformer_engine PRIVATE + "${CMAKE_CURRENT_BINARY_DIR}/string_headers") + +# Compiler options +set_source_files_properties(fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu + multi_tensor/adam.cu + multi_tensor/compute_scale.cu + multi_tensor/l2norm.cu + multi_tensor/scale.cu + multi_tensor/sgd.cu + fused_attn/flash_attn.cu + fused_attn/context_parallel.cu + fused_attn/kv_cache.cu + PROPERTIES + COMPILE_OPTIONS "--use_fast_math") +option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) +if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) + set_source_files_properties(activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + util/cast.cu + PROPERTIES + COMPILE_OPTIONS "--use_fast_math") +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") + +# Number of parallel build jobs +if(ENV{MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") +elseif(ENV{NVTE_BUILD_MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") +else() + set(BUILD_JOBS_STR "max") +endif() +message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}") + +# Number of threads per parallel build job +set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) +if (NOT BUILD_THREADS_PER_JOB) + set(BUILD_THREADS_PER_JOB 1) +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") +message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") + +# Install library +install(TARGETS transformer_engine DESTINATION .) 
diff --git a/transformer_engine/common/CMakeLists.txt.rej b/transformer_engine/common/CMakeLists.txt.rej
new file mode 100644
index 0000000000..faade11dac
--- /dev/null
+++ b/transformer_engine/common/CMakeLists.txt.rej
@@ -0,0 +1,12 @@
+--- transformer_engine/common/CMakeLists.txt
++++ transformer_engine/common/CMakeLists.txt
+@@ -109,7 +109,8 @@ list(APPEND transformer_engine_SOURCES
+     comm_gemm_overlap/userbuffers/ipcsocket.cc
+     comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+     comm_gemm_overlap/userbuffers/userbuffers.cu
+-    comm_gemm_overlap/comm_gemm_overlap.cpp)
++    comm_gemm_overlap/comm_gemm_overlap.cpp
++    ubnext.cu)
+ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
+ target_include_directories(transformer_engine PUBLIC
+                            "${CMAKE_CURRENT_SOURCE_DIR}/include")
diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
index 56369db27f..5dbc0c54e9 100644
--- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -488,7 +488,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   _ub_comm->cga_size = _cga_size;
   size_t m = transa ? A.size(0) : A.size(1);
   size_t k = transa ? A.size(1) : A.size(0);
-  size_t n = _ubuf.size(0);
+  size_t n = B.size(0);
   size_t m_chunk = m / _num_splits;
   const std::vector<size_t> input_a_chunk_shape =
       (transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
@@ -640,6 +640,15 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
 }
 
+uintptr_t CommOverlapBase::init_ubnext() {
+  NVTE_CHECK_CUDA(cudaMemset(_ub_comm->mem_ptr[_ub_reg], 0, _ub_comm->mem_size[_ub_reg]));
+  NVTE_CHECK_CUDA(cudaMemcpy(_ub_comm->mem_ptr[_ub_reg],
+                             (reinterpret_cast<char *>(_ub_comm->mem_ptr[0])) +
+                                 (_ub_reg * _ub_comm->nvsize * sizeof(void *)),
+                             _ub_comm->nvsize * sizeof(void *), cudaMemcpyDeviceToDevice));
+  return (uintptr_t)(_ub_comm->mc_ptr[_ub_reg]);
+}
+
 /***************************************************************************************************
  * Comm+GEMM Overlap P2P Base (Ring-Exchange)
 **************************************************************************************************/
diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig
new file mode 100644
index 0000000000..56369db27f
--- /dev/null
+++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig
@@ -0,0 +1,1210 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace std::placeholders; + +namespace transformer_engine { + +namespace { + +std::vector shape_to_vector(const NVTEShape &shape) { + return std::vector(shape.data, shape.data + shape.ndim); +} + +} // namespace + +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + +bool ubuf_built_with_mpi() { +#ifdef NVTE_UB_WITH_MPI + return true; +#else + return false; +#endif +} + +CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { + // Initialize userbuf communicator + if (!_comm_created) { + if (myrank == 0) { + printf("!!! [UB] Create Userbuffers Communicator\n"); + } +#ifdef NVTE_UB_WITH_MPI + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); +#else + create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + allgather_handle, barrier_handle, 1, 1, tp_size, 1); +#endif + _comm_created = true; + } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { + _use_ce = static_cast(use_ce); + _num_comm_sm = num_comm_sm; + _cga_size = comm_cga_size; + + if (gemm_priority == 0 && comm_priority == 0) { + transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority); + } else { + _gemm_priority = gemm_priority; + _comm_priority = comm_priority; + } + for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); + _stream_compute.push_back(std::move(stream)); + } + + _num_splits = num_splits; + _rank = _ub_comm->myrank; + _tp_size = tp_size; + _tp_id = _rank % _tp_size; + + // Set the number of SMs for GEMM with margin + int sm_count = transformer_engine::cuda::sm_count(); + _math_sms = (set_sm_margin) ? 
sm_count - num_comm_sm : sm_count; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + void *counter_ptr; + size_t counter_bytes = _num_splits * 2 * sizeof(int32_t); + NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); + _counter = TensorWrapper(counter_ptr, std::vector{static_cast(_num_splits * 2)}, + DType::kInt32); + } + // CUDA event creation + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0)); + + /* + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. + This is needed only for Hopper, which uses persistent CTA execution. + */ + int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); + int runtime_version = 0; + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version)); + cudaDeviceProp deviceProp; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0)); + if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming)); + } else { + _comm_launch_event = 0; + } +} + +CommOverlapCore::~CommOverlapCore() { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } + + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } + + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + cudaStreamDestroy(_stream_compute[i]); + } + + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } + + if (_comm_created) { + try { +#ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); +#else + destroy_communicator(_ub_comm); +#endif + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } + _comm_created = false; + } +} + +TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, + const std::vector &chunk_shape) { + const auto scaling_mode = source.scaling_mode(); + + // Tensor dimensions + std::vector shape = shape_to_vector(source.shape()); + auto flatten_shape_to_2d = [](const std::vector &shape) -> std::pair { + if (shape.empty()) { + return {1, 1}; + } + size_t height = 1; + for (size_t i = 0; i < shape.size() - 1; ++i) { + height *= shape[i]; + } + return {height, shape.back()}; + }; + size_t height, width, chunk_height, chunk_width; + std::tie(height, width) = flatten_shape_to_2d(shape); + std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape); + + // Check tensor dimensions +#define NVTE_DIM_CHECK(cond, message) \ + NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \ + ", chunk offset=", chunk_offset, ")") + NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor"); + NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to 
get empty tensor chunk"); + NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width, + "Attempted to get out-of-bounds tensor chunk"); + if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { + // MXFP8 scale-inverses are padded to a 2D matrix with dims that + // are divisible by 128. UB doesn't handle this padding yet. + NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0, + "Userbuffers requires MXFP8 tensor dims that are divisible by 128"); + NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0, + "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128"); + } +#undef NVTE_DIM_CHECK + + // Construct tensor chunk + TensorWrapper chunk(scaling_mode); + for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { + auto param_type = static_cast(param_id); + auto param = source.get_parameter(param_type); + auto param_dptr = reinterpret_cast(param.data_ptr); + auto param_dtype = static_cast(param.dtype); + auto param_shape = shape_to_vector(param.shape); + + if (param_dptr != nullptr) { + if (param_type == NVTETensorParam::kNVTERowwiseData || + param_type == NVTETensorParam::kNVTEColumnwiseData) { + // Offset data pointer + param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype); + param_shape = chunk_shape; + + if (param_type == NVTETensorParam::kNVTEColumnwiseData && + source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { + // Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front + auto last_dim = param_shape.back(); + param_shape.pop_back(); + param_shape.insert(param_shape.begin(), last_dim); + } + } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING && + (param_type == NVTETensorParam::kNVTERowwiseScaleInv || + param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { + // Calculate offset and size for MXFP8 scale-invs + size_t chunk_scale_height = chunk_height; + size_t chunk_scale_width = chunk_width; + if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) { + chunk_scale_width /= 32; + } else { + chunk_scale_height /= 32; + } + param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype); + param_shape = {chunk_scale_height, chunk_scale_width}; + } + + // Set chunked source parameters into the chunked tensor output + chunk.set_parameter(param_type, reinterpret_cast(param_dptr), param_dtype, + param_shape); + } + } + return chunk; +} + +TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source, + size_t chunk_offset, + const std::vector &chunk_shape) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); + + // Update chunk with offset data pointers from the communication buffer + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + chunk_offset * _ubuf.element_size(); + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), + chunk.columnwise_shape()); + } + return chunk; +} + +/*************************************************************************************************** + * Comm+GEMM Overlap Base (Pipelined / Collective) + **************************************************************************************************/ + +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int 
mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, + atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { + _rs_overlap_first_gemm = rs_overlap_first_gemm; + _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, + "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", + "or 2 (multi-atomic)."); + + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Register UBuf %d\n", _ub_reg); + } + _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); + + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); +} + +CommOverlapBase::~CommOverlapBase() { + cudaEventDestroy(_start_d2dcopy); + cudaStreamSynchronize(_stream_comm); + cudaStreamDestroy(_stream_comm); +} + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); + + // Communication: AG and RS + int comm_elements = _ubuf.bytes() / 2; // UBUF uses 2Byte element size + if (comm_type == CommOverlapType::AG) { + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); + } else { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + comm_elements *= 2; + assert(rs_output.numel() == _ubuf.numel() / _tp_size); + assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); + assert(rs_output.element_size() == 2); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, + comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); + } else { + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); + } + } + + assert(pre_gelu_out.numel() == 0); + // When the kernel 
launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch + if (_comm_launch_event) + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0)); + nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, + grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, + _stream_compute[0]); + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + +} // CommOverlapBase::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); + size_t m_chunk = m / _num_splits; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _num_splits, false, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + auto output_d = get_buffer_chunk_like(D, 0, {n, m}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), + _stream_compute[0]); + + for (int i = 0; i < _num_splits; i++) { + if (_rs_kernel_type == 1) { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_atomic_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + &counter_ptr[i], _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + _stream_comm); + } + } else if (_rs_kernel_type == 2) { + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_multiatomic_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, _stream_comm);); + } else { + 
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, + _num_splits, counter_ptr, _ub_comm, + _stream_comm); + } + break; + } else { + consumer(counter_ptr, i, _stream_comm); + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), + _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + // Get GEMM dimensions + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); + size_t m_chunk = m / _num_splits; + const std::vector input_a_chunk_shape = + (transa ? std::vector{m_chunk, k} : std::vector{k, m_chunk}); + const std::vector output_chunk_shape = {n, m_chunk}; + size_t input_a_chunk_size = m_chunk * k; + size_t output_chunk_size = n * m_chunk; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Helper function to get bias chunk if needed + auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper { + if (bias.dptr() == nullptr) { + return TensorWrapper(); + } + return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk}); + }; + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_rs_overlap_first_gemm) { + auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape); + auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape); + auto bias_chunk = maybe_get_bias_chunk(0); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[0]); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape); + output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape); + bias_chunk = maybe_get_bias_chunk(i); + workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, 
{workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Last communication chunk with max SM + _ub_comm->sms = UB_MAX_SM; + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, + n, m, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } else { + for (int i = 0; i < _num_splits; i++) { + auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape); + auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape); + auto bias_chunk = maybe_get_bias_chunk(i); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk. 
Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::split_overlap_rs + +void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + int comm_bytes = _ubuf.bytes(); + int comm_bytes_per_rank = comm_bytes / _tp_size; + + // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, recv_stream); + + // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf + for (auto stream : {send_stream, recv_stream}) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); + } + + // Next we sync with the main stream + // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} + +/*************************************************************************************************** + * Comm+GEMM Overlap P2P Base (Ring-Exchange) + **************************************************************************************************/ + +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { + _is_p2p = true; + _is_reduce_scatter = comm_type == CommOverlapType::RS; + _aggregate = aggregate; + + // Create workspace tensor with userbuffer + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer 
shape must be 2-dimensional!"); + size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; + if (_is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; + } + + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper( + buffer_ptr, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, + buffer_dtype); + + // Create tensor chunks for easy management + char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); + for (int i = 0; i < _num_ubuf_chunks; i++) { + _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, + buffer_dtype)); + ubuf_byte_ptr += buffer_chunk_bytes; + } + + _rank_round_tp = (_rank / _tp_size) * _tp_size; + _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; + _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; + + _self_chunk_id = _tp_id; + if (_atomic_gemm && !_is_reduce_scatter) { + _use_multiatomic_ag = getenv("NVTE_AG_P2P_MULTI_ATOMIC"); + if (_use_multiatomic_ag) { + _use_ce = 0; + _ub_comm->push = 1; + if (_rank == 0) { + printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); + } + } + _self_chunk_id = 0; + NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); + } + + for (int i = 0; i < _stream_compute.size(); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + _stream_send.push_back(std::move(stream)); + } + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); +} + +CommOverlapP2PBase::~CommOverlapP2PBase() { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } +} + +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? 
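+      // Select either the tensor's row-wise data pointer or its column-wise (transposed) storage.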
source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + +TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, + size_t chunk_id) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, 0, shape_to_vector(_ubufs[chunk_id].shape())); + + // Update chunk with offset data pointers from the communication buffer + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); + } + return chunk; +} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t n_chunk = _ubufs[0].size(0); + assert(pre_gelu_out.numel() == 0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].bytes(); + + // Create an GEMM output buffer with N+1 chunks in a contiguous memory + void *D_buffer_ptr; + int D_chunk_bytes = n_chunk * m * D.element_size(); + NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), + D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, true, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + auto input_b = get_buffer_chunk_like(B, 0, shape_to_vector(B.shape())); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + for (int i = 0; i < _tp_size - 1; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = i; + int recv_chunk_id = i + 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + if (_use_multiatomic_ag) { + if (i == 0) { + _ub_comm->use_ce = 0; + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, + true, _stream_recv); + } + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, + _stream_recv); + producer(counter_ptr, recv_chunk_id, _stream_recv); + } + if (i == 0) { + nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false, + _counter.data(), stream_main); + } + } + + // Store the input activation for backprop + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); + assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + _ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } + + // Copy the first GEMM output chunk to the end chunk position of D_buffer + char *src_ptr = reinterpret_cast(D_buffer.dptr()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes, + cudaMemcpyDeviceToDevice, stream_main)); + + // Return the last N rows of D_buffer + NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(), + cudaMemcpyDeviceToDevice, stream_main)); + + // Clean up buffer allocation + NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main)); + + _ub_comm->sms = ori_sms; +} // CommOverlapP2PBase::atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? 
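+  // k is the contraction dimension shared by A and each gathered chunk of B.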
A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].bytes(); + const bool do_gelu = pre_gelu_out.numel() > 0; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + if (_aggregate) { + const int num_steps = _tp_size / 2; + + // Chunk dims + std::vector input_b_chunk_shape = + (transb ? std::vector{k, 2 * n_chunk} : std::vector{2 * n_chunk, k}); + std::vector output_chunk_shape = {2 * n_chunk, m}; + size_t input_b_chunk_size = 2 * n_chunk * k; + size_t output_chunk_size = 2 * n_chunk * m; + + // Initial 1X input chunk exchange between neighboring peers + int send_chunk_id = _tp_id; + int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, + _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, + _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); + + int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + auto input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); + auto output_chunk = + get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape); + auto aux_chunk = (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2, + {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } + } + } else { + // Chunk dims + std::vector input_b_chunk_shape = + (transb ? std::vector{k, n_chunk} : std::vector{n_chunk, k}); + std::vector output_chunk_shape = {n_chunk, m}; + size_t input_b_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + auto input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); + auto output_chunk = + get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape); + auto aux_chunk = + (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < _tp_size - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } + } + } + + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // CommOverlapP2PBase::split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + const int comm_bytes = _ubufs[0].bytes(); + + // Reset counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, false, stream_main); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
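+  // The per-chunk counters written by the atomic GEMM gate the P2P loop below: each chunk is
+  // consumed (and sent) only after the GEMM marks it as produced.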
+ auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape())); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, _counter.data(), stream_main); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + _ub_comm->sms = ori_sms; +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? 
A.size(1) : A.size(0); + size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].bytes(); + + // Get input and workspace data pointers + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (size_t i = 0; i < _stream_send.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int stream_id = i % _stream_compute.size(); + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); + + auto workspace_chunk = + get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[stream_id]); + + if (i > 0) { + // P2P communication chunk + int prev_stream_id = (i - 1) % _stream_compute.size(); + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_send[prev_stream_id]); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + } + + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + + _ub_comm->sms = ori_sms; +} + +} // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig 
b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig new file mode 100644 index 0000000000..cffc411a0d --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig @@ -0,0 +1,327 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ + +#include +#include +#include + +#include + +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" + +#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 + +namespace transformer_engine { + +/* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. + * This can turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option. + * + * \return True if Userbuffers is built with MPI + */ +bool ubuf_built_with_mpi(); + +enum class CommOverlapType { RS = 0, AG = 1 }; + +enum class CommOverlapAlgo { + BULK_OVERLAP_AG = 0, + BULK_OVERLAP_RS = 1, + SPLIT_PIPELINED_AG_P2P = 2, + SPLIT_PIPELINED_RS = 3, + SPLIT_PIPELINED_RS_P2P = 4, + ATOMIC_GEMM_RS = 5, + ATOMIC_GEMM_AG_P2P = 6, + ATOMIC_GEMM_RS_P2P = 7, + EXTERNAL_BULK_OVERLAP_AG = 8, +}; + +class CommOverlapCore { + protected: + static inline communicator *_ub_comm{nullptr}; + static inline bool _comm_created{false}; + + int _rank; + int _tp_id; + int _tp_size; + int _num_splits; + int _math_sms; + int _num_comm_sm; + int _cga_size; + int _use_ce; + int _ub_reg; + int _gemm_priority; + int _comm_priority; + bool _atomic_gemm{false}; + bool _is_p2p{false}; + + TensorWrapper _ubuf; + TensorWrapper _counter; + float *_ubuf_scale_inv; + bool _ubuf_scale_inv_initialized{false}; + + std::vector _stream_compute; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + + public: + CommOverlapCore() {} // dummy constructor for exposing type to Python + + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, + int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm); + + virtual ~CommOverlapCore(); + + void *get_ubuf_dptr() { return _ubuf.dptr(); } + + void set_ubuf_scale_inv(float *scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } + + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + int get_tp_size() { return _tp_size; } + + bool is_atomic_gemm() { return _atomic_gemm; } + + bool is_p2p_overlap() { return _is_p2p; } + + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const 
TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } +}; // CommOverlapCore + +class CommOverlapBase : public CommOverlapCore { + protected: + int _rs_kernel_type; + bool _rs_overlap_first_gemm; + cudaStream_t _stream_comm; + cudaEvent_t _start_d2dcopy; + + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + + public: + CommOverlapBase() {} // dummy constructor for exposing type to Python + + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + + virtual ~CommOverlapBase(); + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper 
&workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override; +}; // CommOverlapBase + +class CommOverlapP2PBase : public CommOverlapCore { + protected: + bool _is_reduce_scatter{false}; + bool _use_multiatomic_ag{false}; + bool _aggregate; + int _next_rank; + int _prev_rank; + int _rank_round_tp; + int _num_ubuf_chunks; + int _self_chunk_id; + std::vector _ubufs; + std::vector _stream_send; + cudaStream_t _stream_recv; + cudaEvent_t _stop_send, _stop_recv; + + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + + public: + CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); + + virtual ~CommOverlapP2PBase(); + + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); + + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
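+ ** B_copy, when non-empty, receives this rank's local input chunk so it can be reused in the
+ ** backward pass.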
+ */ + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. + */ + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + /* + ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. + ** The gemm for overlap_gemm is assumed to have been previously started. + */ + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } +}; // CommOverlapP2PBase + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej new file mode 100644 index 0000000000..f229f5eea6 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej @@ -0,0 +1,11 @@ +--- transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h ++++ transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +@@ -198,6 +198,8 @@ class CommOverlapBase : public CommOverlapCore { + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; ++ // initialize ubnext buffer and return multicast pointer for allreduce ++ uintptr_t init_ubnext(); + }; // CommOverlapBase + + class CommOverlapP2PBase : public CommOverlapCore { diff --git a/transformer_engine/common/include/transformer_engine/ubnext.h b/transformer_engine/common/include/transformer_engine/ubnext.h new file mode 100644 index 0000000000..aa82a27c33 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ubnext.h @@ -0,0 +1,31 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UBNEXT_H_ +#define TRANSFORMER_ENGINE_UBNEXT_H_ + +#include "transformer_engine.h" + +namespace transformer_engine { + +#ifdef __cplusplus +extern "C" { +#endif + +void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, + void* mcptr_out, size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream); +void allreduce_2shot_mc_lamport(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, + void* mcptr_in, void* mcptr_out, void* clear_ptr, size_t bytes, + bool poisoned, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream); +void allreduce_2shot_uc(int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, + size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 706c237ccc..68e5f8aef8 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -19,7 +19,8 @@ *transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapCore*; *nvshmem_wait_on_stream*; - *nvshmemi_init_thread* + *nvshmemi_init_thread*; + allreduce_*; }; local: *; }; diff --git a/transformer_engine/common/ubnext.cu b/transformer_engine/common/ubnext.cu new file mode 100644 index 0000000000..6358b07c85 --- /dev/null +++ b/transformer_engine/common/ubnext.cu @@ -0,0 +1,608 @@ +#include +#include +#include + +#include "./common.h" + +#define TIMEOUT 2000000000ull +//#define UB_TIMEOUT_ENABLED 1 + +#define NVTE_UB_MAXTHREADS 1024 +#define NVTE_UB_MAX_SMS 128 +#define NVTE_UB_LAMPORT_INT 0xFFFAFFFA + +//REG0 flags in use +#define NVTE_UB_FLAG_NVLS2_LAMPORT_ID 0 +#define NVTE_UB_FLAG_NVLS2_LAMPORT_SM_SYNC 1 +#define NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR 2 +#define NVTE_UB_FLAG_NVLS2_ID 3 +#define NVTE_UB_FLAG_NVLS2_SM_SYNC 4 +#define NVTE_UB_FLAG_NVLS2_RS_BAR 5 +#define NVTE_UB_FLAG_NVLS2_AG_BAR 6 + +#define xhalf __nv_bfloat16 + +#define ATOMIC_MCINC(ptr) \ + asm volatile("multimem.red.add.u32 [%0], %1;" ::"l"(ptr), "r"(1) \ + : "memor" \ + "y"); +#define ATOMIC_UCINC(ptr) \ + asm volatile("red.global.add.u32 [%0], %1;" ::"l"(ptr), "r"(1) \ + : "memor" \ + "y"); +#define MULTIMEM_ST(val, ptr) \ + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), \ + "r"(val.y), "r"(val.z), "r"(val.w) \ + : "memory"); + +#define MULTIMEM_LD(val, ptr) \ + asm("multimem.ld_reduce.global.add.v4.bf16x2.acc::f32 {%0,%1,%2,%3}, [%4];" \ + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) \ + : "l"(ptr) \ + : "memory"); + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// Return true if producer > consumer, otherwise false while preventing integer overflow +// If we expect that producer will be 2B+ messages behind consumer +#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & 
(~INT_MAX)) + +#define FINAL_MASK 0xffffffff +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_mc(const int RANKS, const int myrank, const int mylines, + int *uc_flagptr, int *mc_flagptr, uint4 *mc_ptr_in, + uint4 *mc_ptr_out, uint4 *residual_in, uint4 *residual_out, xhalf* gamma, float eps, const int hidden_size, bool fuse_layernorm) { + // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 + int reduce_id; + __shared__ float s_variance; + + if (threadIdx.x == 0) { + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR); + + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] + 1; + + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_RS_BAR]); + + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + + __syncthreads(); + + const int loop_step0 = blockDim.x; + const int loop_step = loop_step0 * UNROLL * gridDim.x; + const int start_elem = threadIdx.x + blockDim.x*blockIdx.x*UNROLL; + const int end_elem = max(start_elem, mylines); + //const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + //const int end_aligned = start_elem + aligned_elem; + + for (int line = start_elem; line < end_elem; line += loop_step) { + uint4 val[UNROLL]; + xhalf *x = reinterpret_cast(&val[0]); +#pragma unroll + for (int i = 0; i < UNROLL; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) + + if(residual_in!=nullptr) { + for (int i = 0; i < UNROLL; i++) { + uint4 resval = residual_in[line+i*loop_step0]; + xhalf *y = reinterpret_cast(&resval); + #pragma unroll + for (int j = 0; j < 8; j++) + x[i*8+j] += y[j]; + if(residual_out!=nullptr) + residual_out[line+i*loop_step0]=val[i]; + } + } + if(fuse_layernorm) { + float local_var_sum = 0.0f; + for (int j = 0; j < UNROLL*sizeof(int4) / sizeof(xhalf); j++) + local_var_sum += (float)(x[j])*(float)(x[j]); + + float packed[1] = {local_var_sum}; + blockReduceSumV2(packed); + float variance = packed[0]; + + if (threadIdx.x == 0) + { + variance = (variance / hidden_size); // Var[x] = E[x²] + s_variance = rsqrtf(variance + eps); + } + __syncthreads(); + } + + int i=0; +#pragma unroll + for (int g = 0; g < UNROLL; g++) { + if(fuse_layernorm) { + #pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(xhalf); j++) { + x[i] = (xhalf)((float)(x[i]) * s_variance * (float) gamma[(threadIdx.x+g*loop_step0)*sizeof(int4)/sizeof(xhalf)+j]); + i++; + } + } + 
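+      // Write the reduced (and, when fused, RMSNorm-scaled) line back through the multicast
+      // pointer so every rank observes the result.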
MULTIMEM_ST(val[g], mc_ptr_out + (line + g * loop_step0)) + } + } + /* + for (int line = end_aligned; line < end_elem; line += loop_step0) { + uint4 val; + xhalf *x = reinterpret_cast(&val); + MULTIMEM_LD(val, mc_ptr_in + (line)) + + if(residual_in!=nullptr) { + uint4 resval = residual_in[line]; + xhalf *y = reinterpret_cast(&resval); + #pragma unroll + for (int j = 0; j < 8; j++) + x[j] += y[j]; + if(residual_out!=nullptr) + residual_out[line]=val; + } + + MULTIMEM_ST(val, mc_ptr_out + (line)) + } + */ + __syncthreads(); + if (threadIdx.x != 0) return; + + __threadfence(); + const int value_to_add = blockIdx.x == 0 ? NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_SM_SYNC, value_to_add); + + const int lastSM = + (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + if (!lastSM) return; + __threadfence_system(); + ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_AG_BAR); + uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_AG_BAR]); + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } +} // fp16 inplace reduce kernel (Hopper) MC + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_uc(const int myrank, const int numlines, + const int lineoffset_in, const int lineoffset_out, + int *uc_flagptr, void **commbuff, uint4 *residual_in, uint4 *residual_out, xhalf* gamma, float eps, const int hidden_size, bool fuse_layernorm) { + // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 + //NB! 
uc_flagptr is shifted by ranks*8 for easier flag offsets + // while lineoffset is relative to start of reg0 + __shared__ uint4 *userptr[RANKS]; + __shared__ int lastSM; + int reduce_id; + + if (threadIdx.x < RANKS) { + int *rem_flagptr = (reinterpret_cast(commbuff[threadIdx.x])); + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_UCINC(rem_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR + RANKS * 2); + + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] + 1; + + userptr[threadIdx.x] = (uint4 *)rem_flagptr; + } + + if (threadIdx.x == 0) { + volatile int *flag = uc_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR; + lastSM = 0; + const int expected = reduce_id * RANKS; +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; + line += blockDim.x * gridDim.x * RANKS) { + uint4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + val[i] = userptr[dest[i]][lineoffset_in + line]; + } + + uint4 sum = val[0]; + xhalf *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + xhalf *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + if(residual_in!=nullptr) { + uint4 resval = residual_in[lineoffset_in + line]; + xhalf *y = reinterpret_cast(&resval); + #pragma unroll + for (int j = 0; j < 8; j++) + s[j] += y[j]; + if(residual_out!=nullptr) + residual_out[lineoffset_in + line]=sum; + } + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + userptr[dest[i]][lineoffset_out + line] = sum; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + __threadfence(); + const int value_to_add = blockIdx.x == 0 ? 
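+    // Block 0 adds the deficit so that each reduction adds exactly NVTE_UB_MAX_SMS to the SM-sync
+    // counter in total; the block that brings it to reduce_id * NVTE_UB_MAX_SMS is the last one.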
NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_SM_SYNC, value_to_add); + lastSM = (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + if (lastSM) uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + } + if (threadIdx.x >= RANKS) return; + __syncthreads(); + if (!lastSM) return; + if (threadIdx.x == 0) __threadfence_system(); + __syncthreads(); + ATOMIC_UCINC((int *)(userptr[threadIdx.x]) + NVTE_UB_FLAG_NVLS2_AG_BAR + RANKS * 2); + if (threadIdx.x != 0) return; + volatile int *flag = uc_flagptr + NVTE_UB_FLAG_NVLS2_AG_BAR; + const int expected = reduce_id * RANKS; +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } +} // UC 2shot kernel (non-lamport) + +__global__ void memset_int(uint32_t *data, int n, uint32_t val) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + data[idx] = val; + } +} + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inplace_gpu_mc_lamport( + const int RANKS, const int myrank, const int mylines, const int numlines, int *uc_flagptr, int *mc_flagptr, + uint4 *mc_ptr_in, uint4 *mc_ptr_out, uint4 *uc_ptr_out, uint4 *clear_ptr, uint4 *residual_in, uint4 *residual_out, xhalf* gamma, float eps, const int hidden_size, bool fuse_layernorm) { + // flags[0,1,2]: reduce_id, sm_sync-local, flag-barrier + // those go right after rank UC pointers, but its the CPU caller who should account for it + int reduce_id; + __shared__ float s_variance; + + if (threadIdx.x == 0) { + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR); + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_ID]; + const int value_to_add = blockIdx.x == 0 ? 
NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = + atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_LAMPORT_SM_SYNC, value_to_add); + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR]); + reduce_id++; + const int lastSM = + (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + + if (lastSM) uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + __syncthreads(); + + const int loop_step0 = blockDim.x; + const int loop_step = loop_step0 * UNROLL * gridDim.x; + const int start_elem = threadIdx.x + blockDim.x*blockIdx.x*UNROLL; + const int end_elem = max(start_elem, mylines); + + for (int line = start_elem; line < end_elem; line += loop_step) { + uint4 val[UNROLL]; + xhalf *x = reinterpret_cast(&val[0]); +#pragma unroll + for (int i = 0; i < UNROLL; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) + + if(residual_in!=nullptr) { + for (int i = 0; i < UNROLL; i++) { + uint4 resval = residual_in[line+i*loop_step0]; + xhalf *y = reinterpret_cast(&resval); + #pragma unroll + for (int j = 0; j < 8; j++) + x[i*8+j] += y[j]; + if(residual_out!=nullptr) + residual_out[line+i*loop_step0]=val[i]; + } + } + if(fuse_layernorm) { + float local_var_sum = 0.0f; + for (int j = 0; j < UNROLL*sizeof(int4) / sizeof(xhalf); j++) + local_var_sum += (float)(x[j])*(float)(x[j]); + + float packed[1] = {local_var_sum}; + blockReduceSumV2(packed); + float variance = packed[0]; + + if (threadIdx.x == 0) + { + variance = (variance / hidden_size); // Var[x] = E[x²] + s_variance = rsqrtf(variance + eps); + } + __syncthreads(); + } + + int i=0; +#pragma unroll + for (int g = 0; g < UNROLL; g++) { + if(fuse_layernorm) { + #pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(xhalf); j++) { + x[i] = (xhalf)((float)(x[i]) * s_variance * (float) gamma[(threadIdx.x+g*loop_step0)*sizeof(int4)/sizeof(xhalf)+j]); + i++; + } + } + MULTIMEM_ST(val[g], mc_ptr_out + (line + g * loop_step0)) + } + } + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < numlines; + line += blockDim.x * gridDim.x) { +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (true) { + uint4 result; + + asm volatile("ld.volatile.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.w) + : "l"(&uc_ptr_out[line]) + : "memory"); + if (result.w != NVTE_UB_LAMPORT_INT) { + if (clear_ptr) clear_ptr[line].w = NVTE_UB_LAMPORT_INT; + break; + } +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("Lamport POLL:SM %d [%d]:expecting %d got (%d,%d,%d) %d\n", blockIdx.x, threadIdx.x, + NVTE_UB_LAMPORT_INT, result.x, result.y, result.z, result.w); + break; + } +#endif + } + } + +} // two-shot NVLS + lamport sync instead of last membar + +#define SETUP_LAUNCH_CONFIG(sms, threads, stream, cga_size, pdl_launch) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[3]; \ + attribute_ub[2].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[2].val.clusterDim.x = sms % cga_size == 0 ? 
cga_size : 1; \ + attribute_ub[2].val.clusterDim.y = 1; \ + attribute_ub[2].val.clusterDim.z = 1; \ + attribute_ub[1].id = cudaLaunchAttributeCooperative; \ + attribute_ub[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attribute_ub[0].val.programmaticStreamSerializationAllowed = pdl_launch; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 3; + +namespace transformer_engine { + + #define split_tokens(x) \ + const int elements = bytes/sizeof(half); \ + const int elements_per_thread = sizeof(uint4)/sizeof(half); \ + int nthreads=1024, nlines=4; \ + size_t total_bytes = bytes/ranks, start_bytes = myrank*total_bytes; \ + int sms=x; \ + if(hidden_size) { \ + assert(hidden_size<=32768); \ + assert(elements % hidden_size==0); \ + assert(hidden_size%elements_per_thread==0); \ + int ntokens = elements/hidden_size; \ + int my_tokens = ntokens / ranks; \ + int extra_tokens = ntokens % ranks; \ + int first_token = myrank*my_tokens; \ + first_token+= myrank1024) { \ + nlines++; \ + assert(nlines<=4); \ + if((hidden_size/elements_per_thread)%nlines==0) \ + nthreads=((hidden_size/elements_per_thread))/nlines; \ + } \ + if(sms>my_tokens) sms=my_tokens; \ + if (sms==0) sms=1; \ + } \ + bool residual_in_global = residual_in!=nullptr && residual_in!=residual_out && residual_out!=nullptr; // out residual is always local + +extern "C" void allreduce_2shot_mc(int ranks, int myrank, void *uc0ptr, void *mc0ptr, + void *mcptr_in, void *mcptr_out, size_t bytes, + void *residual_in, void *residual_out, bool fuse_layernorm, + void* gamma, float eps, const int hidden_size, + cudaStream_t stream) { + split_tokens(32); + + SETUP_LAUNCH_CONFIG(sms, nthreads, stream, 4, 1); + + int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4); + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in+start_bytes, + *arg7 = mcptr_out+start_bytes, *arg8 = residual_in_global?residual_in+start_bytes:residual_in, *arg9 = residual_out, *arg10 = gamma; + float arg11 = eps; int arg12 = hidden_size; bool arg13 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, + (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12, (void *)&arg13}; + #define call_mc_kernel(x,cond) \ + if(x==nlines || cond) {CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); return;} + call_mc_kernel(1,false); + call_mc_kernel(2,false); + call_mc_kernel(3,false); + call_mc_kernel(4,true); +} + +extern "C" void allreduce_2shot_uc(int ranks, int myrank, void *uc0ptr, void *ucptr_in, + void *ucptr_out, size_t bytes, void *residual_in, void *residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream) { + SETUP_LAUNCH_CONFIG(64, 1024, stream, 4, 1); + + int arg1 = myrank, arg2 = bytes / 16, arg3 = (int4 *)ucptr_in - (int4 *)uc0ptr, + arg4 = (int4 *)ucptr_out - (int4 *)uc0ptr; + void *arg5 = uc0ptr + (ranks * 8), **arg6 = (void **)uc0ptr, *arg7 = residual_in, *arg8 = residual_out, *arg9 = gamma; + float arg10 = eps; int arg11 = hidden_size; bool arg12 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, + (void *)&arg4, (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12}; +#define call_uc_kernel(x) \ + if (x == ranks) \ + CUDACHECK( \ + cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_uc), 
kernelArgs)); + call_uc_kernel(2); + call_uc_kernel(4); + call_uc_kernel(8); +} + +extern "C" void allreduce_2shot_mc_lamport(int ranks, int myrank, void *uc0ptr, void *mc0ptr, + void *ucptr_out, void *mcptr_in, void *mcptr_out, + void *clear_ptr, size_t bytes, bool poisoned, + void *residual_in,void* residual_out, bool fuse_layernorm, + void* gamma, float eps, const int hidden_size, + cudaStream_t stream) { + if (!poisoned) { + //user tells us destination was not pre-poisoned, so we need to do it before calling allreduce + int threadsPerBlock = 512; + int blocks = (bytes / 4 + threadsPerBlock - 1) / threadsPerBlock; + memset_int<<>>((uint32_t *)ucptr_out, bytes / 4, + NVTE_UB_LAMPORT_INT); + } + split_tokens(64); + + SETUP_LAUNCH_CONFIG(64, nthreads, stream, 4, 1); + + int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4), arg3a = bytes / sizeof(uint4); + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in+start_bytes, + *arg7 = mcptr_out+start_bytes, *arg8 = ucptr_out, *arg9 = clear_ptr, *arg10 = residual_in_global?residual_in+start_bytes:residual_in, *arg11 = residual_out, *arg12 = gamma; + float arg13 = eps; int arg14 = hidden_size; bool arg15 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg3a, (void *)&arg4, (void *)&arg5, + (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12, (void *)&arg13, (void *)&arg14, (void *)&arg15}; + + #define call_mc_lamport_kernel(x,cond) \ + if(x==nlines || cond) {CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc_lamport), kernelArgs)); return;} + + call_mc_lamport_kernel(1,false); + call_mc_lamport_kernel(2,false); + call_mc_lamport_kernel(3,false); + call_mc_lamport_kernel(4,true); + } + +} // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index bce124e705..6639b391f7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "cuda_runtime.h" @@ -117,6 +118,8 @@ std::shared_ptr, \ transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()) \ + .def("init_ubnext", &transformer_engine::CommOverlapBase::init_ubnext, \ py::call_guard()); \ py::class_, \ @@ -135,6 +138,34 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); + py::call_guard()); \ + m.def( \ + "allreduce_2shot_mc", \ + [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, void* mcptr_out, \ + size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ + transformer_engine::allreduce_2shot_mc(ranks, myrank, uc0ptr, mc0ptr, mcptr_in, mcptr_out, \ + bytes, residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ + py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("bytes"), py::arg("residual_in"), py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); \ + m.def( \ + "allreduce_2shot_uc", \ + [](int ranks, int myrank, void* uc0ptr, 
void* ucptr_in, void* ucptr_out, size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ + transformer_engine::allreduce_2shot_uc(ranks, myrank, uc0ptr, ucptr_in, ucptr_out, bytes, \ + residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("ucptr_in"), \ + py::arg("ucptr_out"), py::arg("bytes"), py::arg("residual_in"), py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); \ + m.def( \ + "allreduce_2shot_mc_lamport", \ + [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, void* mcptr_in, \ + void* mcptr_out, void* clear_ptr, size_t bytes, bool poisoned, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ + transformer_engine::allreduce_2shot_mc_lamport( \ + ranks, myrank, uc0ptr, mc0ptr, ucptr_out, mcptr_in, mcptr_out, clear_ptr, bytes, \ + poisoned, residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ + py::arg("ucptr_out"), py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("clear_ptr"), \ + py::arg("bytes"), py::arg("poisoned"), py::arg("residual_in"), py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); #endif diff --git a/transformer_engine/common/util/pybind_helper.h.orig b/transformer_engine/common/util/pybind_helper.h.orig new file mode 100644 index 0000000000..bce124e705 --- /dev/null +++ b/transformer_engine/common/util/pybind_helper.h.orig @@ -0,0 +1,140 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ + +#include +#include +#include +#include + +#include "cuda_runtime.h" + +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ + pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) \ + .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \ + .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \ + .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \ + pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \ + .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ + .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ + .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ + .value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) 
\ + .value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \ + .value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \ + .value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_( \ + m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ + .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ + .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ + .value("EXTERNAL_BULK_OVERLAP_AG", \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + py::call_guard()) \ + .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def( \ + 
"get_stream_priority_range", \ + [](int device_id = -1) { \ + int low_pri, high_pri; \ + transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ + return std::make_pair(low_pri, high_pri); \ + }, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); + +#endif diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 944d1849bf..07d150b0f0 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -7,3 +7,4 @@ from .fused_attn import * from .gemm import * +from .symm_allocator import * diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py new file mode 100644 index 0000000000..103d3884e1 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -0,0 +1,389 @@ +import torch +import os +import gc +import weakref +from typing import List, Tuple, Optional, Dict +from threading import Lock +import torch.distributed._symmetric_memory as symm_mem +from ctypes import pythonapi, c_void_p, py_object + +def to_capsule(ptr): + # Set the return type to py_object to get a Python object (PyCapsule) + pythonapi.PyCapsule_New.restype = py_object + pythonapi.PyCapsule_New.argtypes = [c_void_p, c_void_p, c_void_p] + # Create capsule with a name (optional, can be None) and no destructor + capsule = pythonapi.PyCapsule_New(ptr, None, None) + return capsule + + +class SymmTensor(torch.Tensor): + """Custom tensor subclass that uses custom memory""" + + @staticmethod + def __new__( + cls, + pool: torch.Tensor, + offset: int, + shape: torch.Size, + dtype: torch.dtype, + allocator: "SymmAllocator", + ): + # Calculate number of elements and bytes + num_elements = torch.Size(shape).numel() + element_size = torch.tensor(0, dtype=dtype).element_size() + nbytes = element_size * num_elements + + # Validate pool + assert pool.dtype == torch.uint8, f"Expected uint8 pool, got {pool.dtype}" + assert ( + pool.numel() >= offset + nbytes + ), f"Pool too small: {pool.numel()} bytes, need {offset + nbytes}" + + # Slice the pool to get the required bytes + byte_slice = pool[offset : offset + nbytes] + + # Reinterpret the uint8 bytes as the target dtype + tensor = byte_slice.view(dtype=dtype) + tensor = tensor.view(*shape) + + # Initialize as a subclass of torch.Tensor + self = torch.Tensor._make_subclass(cls, tensor) + if not isinstance(allocator, SymmAllocator): + raise TypeError(f"Expected SymmAllocator, got {type(allocator)}") + self._allocator = allocator + self._ptr = tensor.data_ptr() + self._offset = offset + self._size = nbytes + return self + + def __del__(self): + """Custom deallocator to return memory to the pool.""" + if hasattr(self, "_allocator") and hasattr(self, "_ptr"): + self._allocator.free(self._ptr) + + +class SymmAllocator: + def __init__(self, size_bytes: int, device: torch.device, dist_group: torch.distributed.group): + """Initialize the allocator with a preallocated memory pool.""" + # Preallocate the memory pool using torch.empty + self.reg0_size = 1024 # NVL72*8 plus up to 112 flags + self.device = device + self.world_size = torch.distributed.get_world_size(dist_group) + self.myrank = torch.distributed.get_rank(dist_group) + self.dist_group = dist_group + + from ..module.base import get_ub + + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + 
self.ub_obj = get_ub("ubnext") + self.internal_pool = self.ub_obj.get_buffer(False).reshape(-1) + self.mc0_ptr = self.ub_obj.init_ubnext() + self.pool_size = self.internal_pool.numel() + else: + alignment = 2 * 1024 * 1024 # memory is allocated in 2MB pages anyways + self.pool_size = int((size_bytes + alignment - 1) / alignment) * alignment +# symm_mem.set_backend("NCCL") + self.internal_pool = symm_mem.empty(self.pool_size, dtype=torch.uint8, device=device) + self.hdl0 = symm_mem.rendezvous(self.internal_pool, dist_group) + self.mc0_ptr = self.hdl0.multicast_ptr + self.internal_pool.fill_(0) + self.internal_pool.view(torch.int64)[: self.world_size].copy_( + torch.tensor(self.hdl0.buffer_ptrs).view(torch.int64) + ) + # self.hdl0.barrier(channel=0) + # Synchronize all processes before proceeding + torch.distributed.barrier(group=dist_group) + + # Track the raw pointer to the pool + self.pool_ptr = self.internal_pool.data_ptr() + # Track allocated segments: (offset, size) + self.allocated: List[Tuple[int, int]] = [] + # Track free segments: (offset, size) + self.freelist: List[Tuple[int, int]] = [(self.reg0_size, self.pool_size - self.reg0_size)] + self.nextpoisoned = None + self.residual = None + self.residual_tokens = 0 + self.tensors = weakref.WeakSet() + self.lock = Lock() + + def allocate(self, nbytes: int) -> Tuple[Optional[int], Optional[torch.Tensor]]: + """Allocate nbytes from the pool, returning a pointer and pool reference.""" + with self.lock: + for i, (offset, size) in enumerate(self.freelist): + if size >= nbytes: + self.freelist.pop(i) + self.allocated.append((offset, nbytes)) + if size > nbytes: + self.freelist.append((offset + nbytes, size - nbytes)) + return self.pool_ptr + offset, self.internal_pool + return None, None + + # No suitable free segment found + raise MemoryError( + f"Preallocated pool exhausted: requested {nbytes} bytes, " + f"available segments: {self.freelist}" + ) + + def free(self, ptr: int): + """Free the memory at ptr, returning it to the pool.""" + with self.lock: + offset = ptr - self.pool_ptr + for i, (alloc_offset, size) in enumerate(self.allocated): + if alloc_offset == offset: + self.allocated.pop(i) + self.freelist.append((offset, size)) + self.freelist.sort(key=lambda x: x[0]) + self._merge_free_segments() + return + # Ignore invalid pointers silently + pass + + raise ValueError(f"Invalid pointer {ptr} not found in allocated segments") + + def _merge_free_segments(self): + """Merge adjacent free segments to reduce fragmentation.""" + if not self.freelist: + return + merged = [] + current_offset, current_size = self.freelist[0] + for offset, size in self.freelist[1:]: + if current_offset + current_size == offset: + # Adjacent segments, merge them + current_size += size + else: + # Non-adjacent, keep current and start new + merged.append((current_offset, current_size)) + current_offset, current_size = offset, size + merged.append((current_offset, current_size)) + self.freelist = merged + + def create_tensor( + self, shape: torch.Size, dtype: torch.dtype = torch.float32 + ) -> Optional[torch.Tensor]: + """Create a PooledTensor using memory from the pool.""" + nbytes = torch.tensor(0, dtype=dtype).element_size() * torch.Size(shape).numel() + ptr, pool = self.allocate(nbytes) + if ptr is None: + return None + offset = ptr - self.pool_ptr + tensor = SymmTensor(pool, offset, torch.Size(shape), dtype, self) + self.tensors.add(tensor) + return tensor + + def allreduce_uc(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_in: 
Optional[torch.Tensor] = None,residual_out: Optional[torch.Tensor] = None, fuse_layernorm: bool = False, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> torch.Tensor: + """Performs in-place allreduce on the given SymmTensor using best algo""" + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + + # tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + ucptr_in = tensor_in.data_ptr() + # mcptr_out = tensor_out.data_ptr() + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Import your pybind module if not imported + from transformer_engine_torch import allreduce_2shot_uc + + allreduce_2shot_uc( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(ucptr_in), + to_capsule(ucptr_in), # out + nbytes, + to_capsule(residual_in.data_ptr()) if residual_in is not None else None, + to_capsule(residual_out.data_ptr()) if residual_out is not None else None, + fuse_layernorm, + to_capsule(gamma.data_ptr()) if gamma is not None else None, + eps if eps is not None else 0.0, + hidden_size + ) + return tensor_in + + def allreduce_simple(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_in: Optional[torch.Tensor] = None,residual_out: Optional[torch.Tensor] = None, fuse_layernorm: bool = False, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> torch.Tensor: + """Performs in-place allreduce on the given SymmTensor using best algo""" + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + + # tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + mcptr_in = self.mc0_ptr + (tensor_in.data_ptr() - self.internal_pool.data_ptr()) + # mcptr_out = self.hdl.multicast_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Import your pybind module if not imported + from transformer_engine_torch import allreduce_2shot_mc + + allreduce_2shot_mc( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(self.mc0_ptr), + to_capsule(mcptr_in), + to_capsule(mcptr_in), # out + nbytes, + to_capsule(residual_in.data_ptr()) if residual_in is not None else None, + to_capsule(residual_out.data_ptr()) if residual_out is not None else None, + fuse_layernorm, + to_capsule(gamma.data_ptr()) if gamma is not None else None, + eps if eps is not None else 0.0, + hidden_size + ) + return tensor_in + + def allreduce_lamport(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_in: Optional[torch.Tensor] = None,residual_out: Optional[torch.Tensor] = None, fuse_layernorm: bool = False, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> torch.Tensor: + """ + Performs allreduce using 2-shot multicast Lamport variant: + - Takes `tensor_in` as input (SymmTensor). + - Allocates `tensor_out` of same shape and dtype. + - Runs `allreduce_2shot_mc_lamport` over them. + - Returns `tensor_out`. 
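+        - Requires the output buffer to be pre-filled ("poisoned") with the
+          NVTE_UB_LAMPORT_INT sentinel so the polling loop can detect completed writes;
+          a buffer for the next call is speculatively allocated and poisoned
+          (`self.nextpoisoned`) to hide that memset.
+        - Falls back to `allreduce_uc` when no multicast pointer is available, and to
+          `allreduce_simple` when the pool cannot supply an output buffer.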
+ """ + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + if self.mc0_ptr is None or self.mc0_ptr == 0: + return self.allreduce_uc(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + from transformer_engine_torch import allreduce_2shot_mc_lamport + + # Allocate output tensor of same shape/dtype + tensor_out = self.nextpoisoned + poisonedout = True + + if self.nextpoisoned is None or self.nextpoisoned.shape != tensor_in.shape: + if self.nextpoisoned is not None: + del self.nextpoisoned + self.nextpoisoned = None + tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + poisonedout = False + if tensor_out is None: + return self.allreduce_simple(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + + # alllcate potential output for next allreduce (speculative) and poison it now + self.nextpoisoned = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + # Calculate mcptr_in and mcptr_out with offset relative to internal_pool + offset = tensor_in.data_ptr() - self.internal_pool.data_ptr() + mcptr_in = self.mc0_ptr + offset + mcptr_out = self.mc0_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + + # Use clear_ptr to clear output memory before reduction; here we use tensor_out + # clear_ptr = self.nextpoisoned.data_ptr() if self.nextpoisoned is not None else 0 + + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Call your pybind lamport allreduce + allreduce_2shot_mc_lamport( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(self.mc0_ptr), + to_capsule(tensor_out.data_ptr()), + to_capsule(mcptr_in), + to_capsule(mcptr_out), + to_capsule(self.nextpoisoned.data_ptr()) if self.nextpoisoned is not None else None, + nbytes, + poisonedout, + to_capsule(residual_in.data_ptr()) if residual_in is not None else None, + to_capsule(residual_out.data_ptr()) if residual_out is not None else None, + fuse_layernorm, + to_capsule(gamma.data_ptr()) if gamma is not None else None, + eps if eps is not None else 0.0, + hidden_size + ) + + return tensor_out + + +_allocator_map: Dict[torch.distributed.group, Tuple[int, "SymmAllocator"]] = {} + + +def ubsymm_request_allocator( + dist_group: torch.distributed.group, + shape: Optional[torch.Size] = None, + dtype: torch.dtype = torch.bfloat16, +) -> None: + if shape is not None: + num_elements = torch.Size(shape).numel() + element_size = torch.tensor(0, dtype=dtype).element_size() + tensor_size = num_elements * element_size + else: + tensor_size = 0 + + if dist_group not in _allocator_map: + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + assert not _allocator_map, "Current UBNEXT-UB bypass supports only one process group." 
+        _allocator_map[dist_group] = (tensor_size, None)
+    else:
+        old_size, allocator = _allocator_map[dist_group]
+        assert allocator is None, "Second element of tuple must be None"
+        max_size = max(old_size, tensor_size)
+        _allocator_map[dist_group] = (max_size, None)
+
+
+def ubsymm_get_sym_tensor(
+    shape: torch.Size, dtype: torch.dtype, dist_group: torch.distributed.group
+) -> torch.Tensor:
+    if dtype != torch.bfloat16:
+        return None  # Unsupported dtype, fall back to NCCL
+    if dist_group not in _allocator_map:
+        return None  # No allocator requested earlier, fall back to NCCL
+    (max_size, allocator) = _allocator_map[dist_group]
+    if allocator is None:
+        new_max_size = int(
+            os.environ.get("NVTE_UB_SYMM_POOL_SIZE", ((6 * max_size + 1048575) / 1024 / 1024))
+        )
+        allocator = SymmAllocator(
+            new_max_size * 1024 * 1024,
+            torch.device(f"cuda:{torch.cuda.current_device()}"),
+            dist_group,
+        )
+        _allocator_map[dist_group] = (new_max_size, allocator)
+    return allocator.create_tensor(shape, dtype)
+
+
+def ubsymm_allreduce(tensor_in: SymmTensor,residual_global: Optional[torch.Tensor] = None, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> SymmTensor:
+    """
+    Performs an allreduce on the given SymmTensor using the best available algorithm.
+    Four modes:
+    standalone allreduce: no residual, no layernorm (residual_global passed by user, both eps and gamma=None)
+    first PROJ layer: layernorm fused, global residual in, internal residual out (residual_global passed by user, both eps and gamma not None)
+        this allocates the internal residual if it wasn't allocated previously or the token count differs from the previous allreduce
+    middle layers: layernorm fused, internal residual in, internal residual out (residual_global=None, both eps and gamma not None)
+    last FC2 layer: no layernorm, internal residual in, no residual out (the layer output is actually the global residual) (residual_global=None, both eps and gamma=None)
+        this is distinguished from the standalone case by whether an internal residual was allocated earlier
+    """
+    fuse_layernorm = gamma is not None and eps is not None
+    internal_residual = tensor_in._allocator.residual
+    num_ranks = tensor_in._allocator.world_size
+    hidden_size = tensor_in.shape[-1] if fuse_layernorm or internal_residual is not None or residual_global is not None else tensor_in.numel() // num_ranks
+    num_tokens = tensor_in.numel() // hidden_size
+    myrank = tensor_in._allocator.myrank
+    if residual_global is not None and (internal_residual is None or tensor_in._allocator.residual_tokens != num_tokens):
+        my_tokens = num_tokens // num_ranks
+        extra_tokens = num_tokens % num_ranks
+        first_token = myrank*my_tokens
+        if myrank < extra_tokens:
+            my_tokens += 1
+            first_token += myrank
+        else:
+            first_token += extra_tokens
+        if my_tokens == 0:
+            my_tokens = 1  # avoid an empty residual
+        if tensor_in._allocator.residual is not None:
+            del tensor_in._allocator.residual
+        tensor_in._allocator.residual = torch.empty(my_tokens*hidden_size, dtype=tensor_in.dtype, device=tensor_in.device)
+        tensor_in._allocator.residual_tokens = num_tokens
+        internal_residual = tensor_in._allocator.residual
+
+    residual_in = residual_global if residual_global is not None else internal_residual
+
+    residual_out = internal_residual if fuse_layernorm else None  # without layernorm, the new full residual is the output of the allreduce
+    if tensor_in.numel() > 1048576:
+        return tensor_in._allocator.allreduce_simple(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps)
+    else:
+        return 
tensor_in._allocator.allreduce_lamport(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + + +def ubsymm_free_residual(tensor_in: SymmTensor): + if tensor_in._allocator.residual is not None: + del tensor_in._allocator.residual + tensor_in._allocator.residual_tokens = 0 + tensor_in._allocator.residual = None + \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 38947c5a9d..d0eed0e716 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -260,12 +260,12 @@ void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) // Userbuffers data void *dst_ptr; if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, + NVTE_CHECK(_ubufs[_tp_id].numel() >= input_size, "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); dst_ptr = _ubufs[_tp_id].dptr(); } else { - NVTE_CHECK(_ubuf.numel() == input_size, + NVTE_CHECK(_ubuf.numel() >= input_size, "Tried to copy an invalid tensor into a Userbuffers buffer ", "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); dst_ptr = _ubuf.dptr(); @@ -282,11 +282,11 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional intra_domain_group) { +#ifndef NVTE_UB_WITH_MPI + pgs.insert({"world", world_group}); + myrank = pgs["world"]->getRank(); + numranks = pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); + backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); + + if (intra_domain_group.has_value()) { + // Get local rank on node and number of local ranks + NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"intra", intra_domain_group.value()}); + mylocal = pgs["intra"]->getRank(); + numlocal = pgs["intra"]->getSize(); + + if (numlocal == numranks) { + // Intra-node group is same as the world group so there can only be 1 node + NVTE_CHECK( + mylocal == myrank, + "Internal TE error: Local rank must be equal to global rank when intra-node group size ", + "is equal to the world group size!"); + mynode = 0; + numnodes = 1; + } else { + // Get node ID and number of nodes + mynode = myrank / numlocal; + numnodes = numranks / numlocal; + } + } else { + // Intra-node group is not set so we assume there is only 1 node + mylocal = myrank; + numlocal = numranks; + pgs.insert({"intra", world_group}); + + mynode = 0; + numnodes = 1; + } + + initialized = true; +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); +#endif +} + +CommOverlapHelper::~CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + for (auto &pg : pgs) pg.second = nullptr; + backend_is_nccl = false; + initialized = false; +#endif +} + +void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, + size_t localbytes, ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + + auto localtensor = + torch::from_blob(localdata, 
{static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; + auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; + + std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector localchunk = {localtmp}; + auto work = pgs[group]->allgather(globalchunks, localchunk); + work->wait(); + + if (backend_is_nccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); + } +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +void CommOverlapHelper::ub_barrier(ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + auto work = pgs[group]->barrier(); + work->wait(); +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +/*************************************************************************************************** + * CommOverlap + **************************************************************************************************/ + +CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), + helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, + helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} + +/* +** Helper function to copy input to _ubuf +*/ +void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) { + const auto &input_ = input.contiguous(); + + // Check element size + const size_t element_size = input.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t input_size = input_.numel(); + const void *src_ptr = input_.data_ptr(); + + // Userbuffers data + const size_t ubuf_size = _ubuf.numel(); + void *dst_ptr = _ubuf.dptr(); + if (local_chunk) { + NVTE_CHECK(input_size * _tp_size == ubuf_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(input_size=", input_size, ", tensor_parallel_size=", _tp_size, + ", ubuf_size=", ubuf_size, ")"); + dst_ptr = (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); + } else { + NVTE_CHECK(input_size == ubuf_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")"); + } + + // Copy data + auto stream_main = 
at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); +} + +at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional> shape) { + // Check buffer shape + const size_t ubuf_size = _ubuf.numel(); + if (shape) { + const size_t requested_size = transformer_engine::pytorch::product(*shape); + if (local_chunk) { + NVTE_CHECK(requested_size * _tp_size == ubuf_size, + "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, + ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")"); + } else { + NVTE_CHECK(requested_size == ubuf_size, + "Invalid shape for a Userbuffers buffer (requested shape=", *shape, + ", ubuf_size=", ubuf_size, ")"); + } + } else { + int64_t dim0 = _ubuf.size(0); + int64_t dim1 = _ubuf.size(1); + if (local_chunk) { + dim0 /= _tp_size; + } + shape = {dim0, dim1}; + } + + // Data pointer + void *ubuf_ptr = _ubuf.dptr(); + if (local_chunk) { + ubuf_ptr = (reinterpret_cast(ubuf_ptr) + + (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size()); + } + + // Construct PyTorch tensor + const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); + return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); +} + +std::pair CommOverlap::get_communication_stream() { + // Return the same stream for both send and recv + return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()), + at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())}; +} + +/*************************************************************************************************** + * CommOverlapP2P + **************************************************************************************************/ + +CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + te::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, + bool aggregate) + : te::CommOverlapP2PBase( + buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, + tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm, aggregate) {} + +/* +** Copy input to _ubufs[0] +*/ +void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) { + const auto &input_ = input.contiguous(); + + // Check element size + const size_t element_size = input.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t input_size = input_.numel(); + const void *src_ptr = input_.data_ptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, + "Tried to copy an invalid tensor into a 
local chunk of a Userbuffers buffer ", + "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == input_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, + cudaMemcpyDeviceToDevice, + (cudaStream_t)at::cuda::getCurrentCUDAStream())); +} + +at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional> shape) { + // Check buffer shape + if (shape) { + const size_t requested_size = transformer_engine::pytorch::product(*shape); + if (local_chunk) { + NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(), + "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, + ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + } else { + NVTE_CHECK(requested_size == _ubuf.numel(), + "Invalid shape for a Userbuffers buffer (requested shape=", *shape, + ", ubuf_size=", _ubuf.numel(), ")"); + } + } else { + int64_t dim0 = _ubuf.size(0); + int64_t dim1 = _ubuf.size(1); + if (local_chunk) { + dim0 /= _tp_size; + } + shape = {dim0, dim1}; + } + + // Data pointer + void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr(); + + // Construct PyTorch tensor + const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); + return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); +} + +std::pair CommOverlapP2P::get_communication_stream() { + return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()), + at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; +} + +void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( + CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { + auto main_stream = at::cuda::getCurrentCUDAStream(); + allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), + at::cuda::CUDAStream(recv_stream), main_stream); +} diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index bf4fb97d2d..fc4c28f8bc 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -311,6 +311,9 @@ def initialize_ub( "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "external": ["proj_wgrad", "fc2_wgrad"], } + # Add "ubnext" to bulk methods if environment variable is set + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + methods["bulk"].append("ubnext") # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} @@ -435,8 +438,12 @@ def add_ub( ) else: ub_obj = tex.CommOverlap( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type + ( + shape + if name != "ubnext" + else (int(os.environ.get("NVTE_UB_SYMM_POOL_SIZE", 64)), 1024 * 1024) + ), # Communication buffer shape + buffer_dtype if name != "ubnext" else torch.uint8, # Communication buffer data type helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) num_splits=num_splits, diff --git a/transformer_engine/pytorch/module/base.py.orig b/transformer_engine/pytorch/module/base.py.orig new file mode 100644 index 0000000000..bf4fb97d2d --- /dev/null +++ 
b/transformer_engine/pytorch/module/base.py.orig @@ -0,0 +1,1597 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Base modules and utilities for TransformerEngine PyTorch API""" +import io +import math +import os +import pickle +import warnings +from enum import Enum +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from contextlib import contextmanager +import logging +from types import MethodType + +import torch +import torch.nn.functional as F + +import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe + +from ._common import _ParameterInitMeta, noop_cat +from ..fp8 import ( + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, + Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, + FP8GlobalStateManager, + RecipeState, +) +from ..distributed import ( + gather_along_first_dim, + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, + _fsdp_gather_tensors, +) +from ..constants import dist_group_type +from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer +from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...common.recipe import DelayedScaling, Recipe +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor +from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled + +__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] + +_2X_ACC_FPROP = False +_2X_ACC_DGRAD = True +_2X_ACC_WGRAD = True +_multi_stream_cublas_workspace = [] +_dummy_wgrads = {} +_cublas_workspace = None +_ub_communicators = None +_NUM_MAX_UB_STREAMS = 3 +_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None +layers_atomic_ring_exchange = [] + + +class UserBufferQuantizationMode(Enum): + """ + UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. 
+ """ + + NONE = "none" + FP8 = "fp8" + + +def get_cublas_workspace_size_bytes() -> None: + """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" + if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: + # 32 MiB for NVFP4 GEMM, plus 256 B for misc scales + return 32 * 1024 * 1024 + 256 + return 4_194_304 + + +def get_workspace() -> torch.Tensor: + """Returns workspace for cublas.""" + global _cublas_workspace + if _cublas_workspace is None: + _cublas_workspace = torch.empty( + get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" + ) + return _cublas_workspace + + +def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: + """Returns workspace for multi-stream cublas.""" + global _multi_stream_cublas_workspace + if not _multi_stream_cublas_workspace: + for _ in range(tex.get_num_cublas_streams()): + _multi_stream_cublas_workspace.append( + torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") + ) + return _multi_stream_cublas_workspace + + +def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: + """Returns a dummy tensor of given shape.""" + assert len(shape) == 2 + global _dummy_wgrads + if (shape[0], shape[1], dtype) not in _dummy_wgrads: + _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + shape, + dtype=dtype, + device="cuda", + requires_grad=False, + ) + if zero: + _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) + return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + + +def initialize_ub( + shape: list, + tp_size: int, + use_fp8: bool = False, + quantization_modes: List[UserBufferQuantizationMode] = None, + dtype: torch.dtype = torch.bfloat16, + ub_cfgs: Optional[Union[dict, List[dict]]] = None, + bootstrap_backend: Union[str, torch.distributed.Backend] = None, +) -> None: + r""" + Initialize the Userbuffers communicator for overlapping tensor-parallel communications with + GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. + + Parameters + ---------- + shape : list + shape of the communication buffer, typically set to be the same as the global shape of + the input tensor to a te.TransformerLayer forward pass, with the sequence and batch + dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` + tp_size : int + number of GPUs in the tensor-parallel process group + use_fp8 : bool = False + allocate the communication buffer for FP8 GEMM inputs/outputs. + DEPRECATED: Please use `quantization_modes` instead. + quantization_modes : List[UserBufferQuantizationMode] = None + if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. + falls back to the legacy `use_fp8` parameter if `None` is provided. + dtype : torch.dtype = torch.bfloat16 + non-FP8 data type of the communication buffer when `use_fp8 = False` + ub_cfgs: dict = None + Configuration dictionary with the structure + ``` + { + : { + "method": <"ring_exchange" or "pipeline">, + "is_reduce_scatter": bool, + "num_sm": int, + "cga_size": int, + "set_sm_margin": bool, + "num_splits": int, + "aggregate": bool, + "atomic_gemm": bool, + "use_ce": bool, + "fp8_buf": bool, + } + } + ``` + for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", + "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", + "fc2_fprop", "fc2_wgrad"]`. 
+ a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` + bootstrap_backend : str = None + `torch.distributed` communication backend for the all-gather, broadcast and + barrier collectives during Userbuffers initialization. Not all backends are + valid for every cluster configuration and distributed launch method even if + they are available in PyTorch. When left unset, the initialization prefers + to use the MPI backend, falling back first on Gloo and then NCCL if MPI is + not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this + option and always initializes Userbuffers with direct MPI calls in C++, + which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. + """ + if not tex.device_supports_multicast(): + assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " + + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + ) + + if not quantization_modes: + warnings.warn( + "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes" + " instead.", + DeprecationWarning, + ) + quantization_modes = [ + UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE + ] + else: + assert isinstance(quantization_modes, list), "quantization_modes must be a list" + assert all( + isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes + ), "quantization_modes must be a list of UserBufferQuantizationMode" + + if isinstance(ub_cfgs, dict) or ub_cfgs is None: + ub_cfgs = [ub_cfgs] * len(quantization_modes) + else: + assert len(ub_cfgs) == len( + quantization_modes + ), "Number of ub_cfgs settings must match number of quantization configurations" + + global _ub_communicators + assert _ub_communicators is None, "UB communicators are already initialized." + _ub_communicators = {} + + if tex.ubuf_built_with_mpi(): + # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force + # an MPI_Init() here by creating a new MPI process group... + assert torch.distributed.is_mpi_available() + _ = torch.distributed.new_group(backend="mpi") + helper = tex.CommOverlapHelper() + else: + # Bootstrapping with torch.distributed API, so check backend and construct + # intra/inter-node process groups... + assert ( + torch.distributed.is_initialized() + ), "torch.distributed must be initialized before Userbuffers" + if bootstrap_backend is None: + bootstrap_backend = "nccl" + if torch.distributed.is_mpi_available(): + bootstrap_backend = "mpi" + elif torch.distributed.is_gloo_available(): + bootstrap_backend = "gloo" + else: + assert bootstrap_backend in [ + "gloo", + "mpi", + "nccl", + ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" + assert torch.distributed.is_backend_available(bootstrap_backend), ( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." 
+ ) + + world_group = torch.distributed.new_group(backend=bootstrap_backend) + world_rank = torch.distributed.get_rank(world_group) + world_size = torch.distributed.get_world_size(world_group) + + num_domains = world_size // tp_size + mydomain_idx = world_rank // tp_size + if num_domains > 1: + ranks_per_domain_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(num_domains) + ] + tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list, backend=bootstrap_backend + ) + local_rank = torch.distributed.get_rank(tp_domain_group) + tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group) + + helper = tex.CommOverlapHelper(world_group, tp_domain_group) + else: + # TP model on single NVLink domain, no replication, no data-parallelism + mydomain_idx = 0 + local_rank = world_rank + tp_domain_ranks = list(range(world_size)) + + helper = tex.CommOverlapHelper(world_group) + + if world_rank == 0: + print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True) + if local_rank == 0: + print( + f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n", + end="", + flush=True, + ) + + # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls + global _cublas_workspace + if _cublas_workspace is None: + _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) + elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS: + # This ensures we don't do `.repeat()` on an already expanded workspace + _cublas_workspace = torch.empty( + get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" + ).repeat(_NUM_MAX_UB_STREAMS) + + # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe + layers_all_gather_overlap = [ + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "proj_wgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + "fc2_wgrad", + ] + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] + # Default overlap methods for layers + methods = { + "ring_exchange": [ + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + ], + "pipeline": ["proj_fprop", "fc2_fprop"], + "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + "external": ["proj_wgrad", "fc2_wgrad"], + } + + # AG-RS overlap pairs of layers forming a tensor-parallel block + ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} + rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} + external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} + global layers_atomic_ring_exchange + layers_atomic_ring_exchange = [] + + def get_method(name): + for method, names in methods.items(): + if name in names: + return method + raise KeyError(f"Given layer name {name} does not exist.") + + def get_default_config(name): + global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY + method = get_method(name) + is_reduce_scatter = name in layers_reduce_scatter_overlap + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() + default_cfg = { + "method": method, + "is_reduce_scatter": is_reduce_scatter, + "num_sm": 1 if method == "ring_exchange" else 16, + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": not method == "ring_exchange", + "num_splits": tp_size if method == "ring_exchange" else 4, + "aggregate": False, + "atomic_gemm": 
False,
+            "use_ce": True,
+            "fp8_buf": name in layers_all_gather_overlap,
+            "comm_priority": _MAX_STREAM_PRIORITY,
+            "gemm_priority": _MIN_STREAM_PRIORITY,
+            "pipeline_rs_overlap_first_gemm": False,
+        }
+        return default_cfg
+
+    def add_ub(
+        name: str,
+        quantization_mode: UserBufferQuantizationMode,
+        method: str,
+        is_reduce_scatter: bool,
+        num_sm: int = 16,
+        cga_size: int = 2,
+        set_sm_margin: bool = False,
+        num_splits: int = 0,
+        aggregate: bool = False,
+        atomic_gemm: bool = False,
+        use_ce: bool = True,
+        fp8_buf: bool = False,
+        comm_priority: int = 0,
+        gemm_priority: int = 0,
+        pipeline_rs_overlap_first_gemm: bool = False,
+    ) -> None:
+        if atomic_gemm:
+            warnings.warn(
+                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
+            )
+            assert (
+                quantization_mode == UserBufferQuantizationMode.FP8
+            ), "Atomic GEMM overlap supported only for FP8 GEMM."
+            if method in ("bulk", "external"):
+                warnings.warn(
+                    f"At {name}, atomic GEMM is not supported for bulk or external overlaps. "
+                    "Defaulting to `atomic_gemm=False`."
+                )
+                atomic_gemm = 0
+        if not is_reduce_scatter and method == "pipeline":
+            raise ValueError(
+                f"At {name}, `pipeline` overlap method is not supported for AllGather."
+            )
+        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
+        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
+        global layers_atomic_ring_exchange
+        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
+            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
+        if name in rs_ag_pairs:
+            assert_message = (
+                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
+                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
+                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
+                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
+                "for functionality."
+ ) + if name in layers_atomic_ring_exchange: + assert atomic_gemm and method == "ring_exchange", assert_message + else: + if atomic_gemm and method == "ring_exchange": + assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + + if name in external_gemm_to_overlap: + assert method == "external", ( + f"At {name}, `external` overlap method is specified, but the selected method is" + f" {method}" + ) + assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( + f"At {name}, `external` overlap method is specified, but the external gemm" + f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" + ) + + buffer_dtype = ( + torch.uint8 + if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) + else dtype + ) + if method == "ring_exchange": + ub_obj = tex.CommOverlapP2P( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping + tp_size, # Tensor-parallel group size (may be different than local_size) + tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + use_ce=use_ce, + aggregate=aggregate, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + ) + else: + ub_obj = tex.CommOverlap( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping + tp_size, # Tensor-parallel group size (may be different than local_size) + num_splits=num_splits, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, + ) + _ub_communicators[(name, quantization_mode)] = ub_obj + + for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): + if user_ub_cfg is not None: + for name in dgrad_reduce_scatter_overlap: + if ( + name in user_ub_cfg + and "method" in user_ub_cfg[name] + and user_ub_cfg[name]["method"] != "bulk" + ): + wgrad_name = name.replace("dgrad", "wgrad") + assert wgrad_name not in user_ub_cfg + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) + layers_reduce_scatter_overlap.append(name) + methods["bulk"].remove(name) + new_method = user_ub_cfg[name]["method"] + methods[new_method].append(name) + + for name in ( + methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] + ): + ub_cfg = get_default_config(name) + if user_ub_cfg is not None and name in user_ub_cfg: + fp8_buf = (name in layers_all_gather_overlap) or ( + user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] + ) + ub_cfg.update(user_ub_cfg[name]) + ub_cfg["fp8_buf"] = fp8_buf + add_ub(name, quantization_mode, **ub_cfg) + + +def get_ub(name: str, use_fp8: bool): + """Get userbuffer communicator corresponding to give key.""" + # For now use `use_fp8` boolean input as it matches the current design in the modules + # So favour simplicity until the correct design becomes clear. 
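The communicators registered above are keyed by `(layer name, quantization mode)`, so a typical setup requests one set of buffers per quantization mode and later fetches them per layer. A minimal sketch follows; the leading `shape`/`tp_size`/`dtype` parameter names are assumed from their use in this function, and the override values are illustrative only.

    import torch

    seq_len, batch_size, hidden_size, tp_size = 2048, 2, 8192, 8
    initialize_ub(
        shape=[seq_len * batch_size, hidden_size],
        tp_size=tp_size,
        quantization_modes=[
            UserBufferQuantizationMode.NONE,
            UserBufferQuantizationMode.FP8,
        ],
        dtype=torch.bfloat16,
        bootstrap_backend="nccl",
        # One entry per quantization mode; layers not listed keep the defaults above.
        ub_cfgs=[
            {"proj_fprop": {"method": "pipeline", "num_sm": 16, "num_splits": 4}},
            None,
        ],
    )
    # Fetch the communicator for a given layer and quantization setting.
    ub_obj = get_ub("proj_fprop", use_fp8=True)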
+ # This is mainly an internal API so we don't need to worry about future changes + key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) + assert _ub_communicators is not None, "UB manager is not initialized." + assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." + return _ub_communicators[key] + + +def destroy_ub(): + """Destroy all allocated userbuffer communicators.""" + global _ub_communicators + _ub_communicators = None + global layers_atomic_ring_exchange + layers_atomic_ring_exchange = [] + + +def fill_userbuffers_buffer_for_all_gather( + comm, + local_tensor: torch.Tensor, + quantizer: Optional[Quantizer], + process_group, +) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]: + """Fill local shard of Userbuffers buffer with data for all-gather + + Returns the full tensor and the local shard, both using the + Userbuffers buffer as their underlying data. These tensors should + be used carefully (e.g. only immediately before and after a + Userbuffers operation) since the underlying data may be + overwritten by other Userbuffers operations. + + May perform blocking communication if needed for the gathered + tensor's metadata, e.g. scaling factors. + + """ + + # Tensor dimensions + local_shape = local_tensor.size() + if not local_shape: + raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})") + process_group_size = torch.distributed.get_world_size(process_group) + global_shape = list(local_shape) + global_shape[0] *= process_group_size + + # Unquantized data + if quantizer is None: + if isinstance(local_tensor, QuantizedTensorBase): + local_tensor = local_tensor.dequantize() + if comm.is_fp8_ubuf(): + raise RuntimeError( + "Attempting to all-gather unquantized tensor, " + "but Userbuffers is initialized with FP8 buffers" + ) + comm.copy_into_buffer(local_tensor, local_chunk=True) + global_tensor = comm.get_buffer(shape=global_shape) + return global_tensor, local_tensor + + # FP8 data + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if not isinstance(local_tensor, Float8TensorBase): + if isinstance(local_tensor, QuantizedTensorBase): + local_tensor.dequantize() + quantizer.set_usage(rowwise=True, columnwise=False) + local_tensor = quantizer(local_tensor) + if not comm.is_fp8_ubuf(): + raise RuntimeError( + "Attempting to all-gather FP8 tensor, " + "but Userbuffers is not initialized with FP8 buffers" + ) + comm.copy_into_buffer(local_tensor._data, local_chunk=True) + global_tensor_data = comm.get_buffer(shape=global_shape) + global_tensor = Float8TensorBase( + data=global_tensor_data, + fp8_scale_inv=local_tensor._scale_inv, + fp8_dtype=local_tensor._fp8_dtype, + quantizer=quantizer, + ) + return global_tensor, local_tensor + + # MXFP8 data + if isinstance(quantizer, MXFP8Quantizer): + + # Cast to MXFP8 if needed + if not isinstance(local_tensor, MXFP8TensorBase): + if isinstance(local_tensor, QuantizedTensorBase): + local_tensor.dequantize() + local_tensor = quantizer(local_tensor) + if not comm.is_fp8_ubuf(): + raise RuntimeError( + "Attempting to all-gather MXFP8 tensor, " + "but Userbuffers is not initialized with FP8 buffers" + ) + + # Check which MXFP8 buffer to communicate + if quantizer.rowwise_usage == quantizer.columnwise_usage: + raise ValueError( + "Userbuffers can only communicate one MXFP8 buffer at a time, " + f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, " + 
f"columnwise_usage={quantizer.columnwise_usage}" + ) + with_rowwise_data = quantizer.rowwise_usage + + # Copy MXFP8 data to local chunk of Userbuffers buffer + local_data = ( + local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data + ) + comm.copy_into_buffer(local_data, local_chunk=True) + + # Gather scaling-inverses + if math.prod(local_shape[:-1]) % 128 != 0: + raise ValueError( + "Userbuffers requires MXFP8 tensor dims that are divisible by 128, " + f"but got MXFP8 tensor with shape={tuple(local_shape)}" + ) + local_scale_inv = ( + local_tensor._rowwise_scale_inv + if with_rowwise_data + else local_tensor._columnwise_scale_inv + ) + local_scale_inv_size = list(local_scale_inv.size()) + global_scale_inv = torch.empty( + [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:], + dtype=local_scale_inv.dtype, + device=local_scale_inv.device, + ) + torch.distributed.all_gather_into_tensor( + global_scale_inv, + local_scale_inv, + group=process_group, + ) + + # Construct MXFP8 tensor with Userbuffers buffer + rowwise_data, rowwise_scale_inv = None, None + columnwise_data, columnwise_scale_inv = None, None + global_data = comm.get_buffer(shape=global_shape) + if with_rowwise_data: + rowwise_data, rowwise_scale_inv = global_data, global_scale_inv + else: + columnwise_data, columnwise_scale_inv = global_data, global_scale_inv + global_tensor = MXFP8TensorBase( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=local_tensor._fp8_dtype, + quantizer=quantizer, + ) + return global_tensor, local_tensor + + # Unsupported data format + raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})") + + +class TransformerEngineBaseModule(torch.nn.Module, ABC): + """Base TE module.""" + + def __init__(self) -> None: + super().__init__() + assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + self.name = None + self.next_iter_when_debug_should_be_run = 0 + self.fp8_initialized = False + self.fp8 = False + self.fp8_calibration = False + self.fp8_meta = {} + self.fp8_meta["fp8_checkpoint"] = False + self.fp8_meta["fp8_group"] = None + self.fp8_meta_tensors_initialized = False + self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} + self.tp_group = None + self.tp_size = 1 + self.sequence_parallel = False + self.param_init_meta = {} + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() + self.fsdp_wrapped = False + self.fsdp_group = None + self._fp8_workspaces: Dict[str, QuantizedTensor] = {} + self.activation_dtype: Optional[torch.dtype] = None + self.wgrad_accumulation_and_reduce_hooks = [] + + if not TEDebugState.debug_enabled: + TEDebugState.initialize() + + # Names of attributes that can be set quickly (see __setattr__ + # method) + _fast_setattr_names: Set[str] = { + "activation_dtype", + "fp8", + "fp8_initialized", + "fp8_calibration", + "fp8_parameters", + } + + def __setattr__(self, name: str, value: Any) -> None: + if name in TransformerEngineBaseModule._fast_setattr_names: + # torch.nn.Module has a custom __setattr__ that handles + # modules, parameters, and buffers. This is unnecessary + # overhead when setting plain attrs. 
+ self.__dict__[name] = value + else: + # Default case + super().__setattr__(name, value) + + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: + """ + Delayed scaling only. + + Increase or decrease size of amax history based on given `length`. + + .. warning:: + This changes the underlying amax memory location. + """ + if fwd is None: + fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") + else: + fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) + + for meta_key in fp8_meta_tensor_keys: + if meta_key not in self.fp8_meta: + # Handles non-parameter FP8 modules, e.g. DPA. + continue + curr_len = self.fp8_meta[meta_key].amax_history.shape[0] + if length == curr_len: + continue + if length < curr_len: + self.fp8_meta[meta_key].amax_history = ( + self.fp8_meta[meta_key].amax_history[:length].clone() + ) + elif length > curr_len: + extra_rows = length - curr_len + self.fp8_meta[meta_key].amax_history = F.pad( + self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) + ) + + # Update quantizers with new amax pointers. + self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() + # Make sure weight tensors has correct quantizers + self._update_weight_quantizers() + + # Update the global buffers with new amax and history pointers. + if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[ + FP8GlobalStateManager.get_buffer_info() + ] + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + if buffer_key in FP8GlobalStateManager.global_amax_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ + meta_key + ].amax_history[0] + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history + ) + + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + # Return early if recipe state matches recipe + if self.fp8_meta_tensors_initialized: + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) + return + if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return + if recipe.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return + if recipe.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return + if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): + return + + # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and + # 2 (grad_output and grad_input) for bwd + num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 + + # Initialize recipe state and quantizers + recipe_state = RecipeState.create( + recipe, + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors, + ) + + self.fp8_meta[fp8_meta_tensor_key] = recipe_state + self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + + def _update_weight_quantizers(self) -> None: + """Update the quantizers for the weight tensors.""" + weight_tensors = self._get_weight_tensors() + weight_quantizers = self._get_weight_quantizers() + assert len(weight_tensors) == len(weight_quantizers), ( + f"Number of weight tensors ({len(weight_tensors)}) and quantizers " + f"({len(weight_quantizers)}) must match" + ) + for weight, quantizer in zip(weight_tensors, weight_quantizers): + if quantizer is not None and isinstance(weight, QuantizedTensorBase): + weight.update_quantizer(quantizer) + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_weight_tensors function" + ) + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_weight_quantizers function" + ) + + def init_fp8_meta_tensors(self, recipe: Recipe) -> None: + """Init scales and amaxes.""" + self.set_meta_tensor(True, recipe) + self.set_meta_tensor(False, recipe) + + self.fp8_meta_tensors_initialized = True + + def get_fp8_meta_tensors(self) -> None: + """Get scales and amaxes.""" + fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" + if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta: + return None + + fp8_meta_tensors = {fwd_key: [], bwd_key: []} + with torch.no_grad(): + for key in (fwd_key, bwd_key): + fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) + return fp8_meta_tensors + + def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: + """Reset scales and amaxes.""" + + def reset(key): + if key in self.fp8_meta: + if fp8_meta_tensors is None: + self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) + self.fp8_meta[key].amax_history.copy_( + torch.zeros_like(self.fp8_meta[key].amax_history) + ) + else: + assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) + + with torch.no_grad(): + reset("scaling_fwd") + reset("scaling_bwd") + + def get_extra_state(self) -> torch.Tensor: + """Save before checkpointing.""" + + # This implementation is working around a few issues: + # + # (1) PyTorch's "extra state" infrastructure might be able to + # support any picklable type, but they make no guarantees. + # We have experienced problems (e.g. in ONNX export) with + # non-tensor extra state. + # (2) PyTorch's checkpointing infrastructure does not remap + # devices for "extra state" like it does for "state dict". + # Thus, we want to avoid putting extra state on the GPU + # since it may be loaded on the wrong device. + # (3) The extra state consists of many small tensors. 
If we + # want to copy them all to CPU, then we need to avoid the + # overhead of many GPU-CPU memory transfers. + # + # See: https://github.com/NVIDIA/TransformerEngine/pull/351 + # See: https://github.com/NVIDIA/TransformerEngine/pull/363 + + def to_cpu(src: torch.Tensor) -> torch.Tensor: + """Helper function to make CPU copy of tensor + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. + + """ + dst = torch.empty_like(src, device="cpu") + dst.copy_(src, non_blocking=True) + return dst + + # Store FP8 state if needed + state = None + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + if not fp8_checkpoint: + return torch.empty(0, dtype=torch.uint8) + + # Copy tensors to CPU and store + state = {} + state["recipe"] = self.fp8_meta["recipe"] + if state["recipe"].delayed(): + state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) + state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) + state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) + state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) + + # Store other pickelable values + extra = {} + for k, v in self.fp8_meta.items(): + if k != "buffer_index_and_autocast_key" and isinstance( + v, (bool, int, float, str, tuple, list) + ): + extra[k] = v + state["extra_fp8_variables"] = extra + + # Serialize state into byte tensor + torch.cuda.synchronize() + state_serialized = bytearray(pickle.dumps(state)) + state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) + return state_serialized + + def set_extra_state(self, state: torch.Tensor) -> None: + """Load previous state.""" + + # Maintain backwards compatibility with older checkpoints. + if state is None: + return + + # Load state + if isinstance(state, torch.Tensor): + # No FP8 is indicated by an empty tensor we don't need to unpickle. + if state.numel() == 0: + return + # Default format: byte tensor with pickled data + state = pickle.loads(state.detach().cpu().numpy().tobytes()) + elif isinstance(state, io.BytesIO): + # Deprecated format with io.BytesIO + state.seek(0) + state = torch.load(state, map_location="cuda") + else: + raise RuntimeError("Unsupported checkpoint format.") + + if state is None: + return + + # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing + if "recipe" not in state: + # TE 1.x only supported delayed scaling, which was the default recipe + state["recipe"] = DelayedScaling() + # TE 1.x also saved scale_inv, which is not needed with Recipe object + state.pop("scale_inv_fwd", None) + state.pop("scale_inv_bwd", None) + + # Load extra items + self.fp8_meta.update(state["extra_fp8_variables"]) + self.fp8_meta["recipe"] = state["recipe"] + if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: + del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] + + # Initialize before loading + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + + def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: + """Helper function to copy tensor from CPU + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. 
+ + """ + dst.copy_(src, non_blocking=True) + + # Load tensors + if self.fp8_meta["recipe"].delayed(): + copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) + copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) + copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) + copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) + torch.cuda.synchronize() + + def set_activation_dtype(self, inp: torch.Tensor) -> None: + """Get activation data type for AMP.""" + # Native AMP (`torch.autocast`) gets highest priority + if torch.is_autocast_enabled(): + self.activation_dtype = torch_get_autocast_gpu_dtype() + return + + # All checks after this have already been performed once, thus skip + if self.activation_dtype == inp.dtype: + return + + dtype = inp.dtype + if not self.allow_different_data_and_param_types: + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. " + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) + self.activation_dtype = dtype + + def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: + """ + Set the tensor parallel group for the given + module before executing the forward pass. + + Parameters + ---------- + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + """ + self.tp_group = tp_group + self.tp_group_initialized = True + + def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: + """returns the FP8 weights.""" + fp8_params = [] + for param in self.parameters(recurse=False): + if isinstance(param, QuantizedTensor) and param.requires_grad: + fp8_params.append(param) + if len(fp8_params) == 0: + return None + return fp8_params + + # This routine is shared across FP8 and FP8_calibration paths so should not actually + # assume FP8 execution. + def init_fp8_metadata(self, num_gemms: int = 1) -> None: + """Initialize fp8 related metadata and tensors during fprop.""" + _original_recipe = self.fp8_meta.get("recipe", None) + + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self.fp8 = FP8GlobalStateManager.is_fp8_enabled() + self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration + self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + + if self.fp8_parameters or fp8_enabled: + if ( + self.fp8_initialized + and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] + ): + # FP8 init has already been run and recipe is the same, don't do anything. + return + self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + else: + # If fp8 isn't enabled, turn off and return. 
+ self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + + if fp8_enabled: + # Set FP8 and other FP8 metadata + self.fp8_meta["num_gemms"] = num_gemms + self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + + # Set FP8_MAX per tensor according to recipe + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + + # Allocate scales and amaxes + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + self.fp8_initialized = True + + self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + + @contextmanager + def prepare_forward( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + """Checks and prep for FWD. + The context manager is needed because there isn't a way for a module to know + if it's the last FP8 module in the forward autocast. It is useful + to setup the forward aggregated amax reduction for every module + just in case. The autocast exit will pick up the most recent one. + """ + self.allow_different_data_and_param_types = allow_different_data_and_param_types + self.forwarded_at_least_once = True + # Activation recomputation is used and this is the second forward phase. + if self.fp8 and in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) + else: + assert inp.is_cuda, "TransformerEngine needs CUDA." + + if self.tp_size > 1: + assert self.tp_group_initialized, "TP group not initialized." + + self.set_activation_dtype(inp) + self.init_fp8_metadata(num_gemms=num_gemms) + self._check_weight_tensor_recipe_correspondence() + + if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): + assert self.fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) + + if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) + + # Activation recomputation is used and this is the first forward phase. + if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): + FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) + + with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + yield inp + + if self.fp8 and in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + + def set_nccl_overlap_warning_if_tp(self) -> None: + """When using TP, the NCCL communication needs to be scheduled + before the GEMM for there to be a guaranteed overlap. 
From the + host side in TE, the comm calls are always launched first, but + to ensure that the GEMM isn't scheduled first, the environment + variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to + force a single channel. + """ + if self.tp_size == 1: + return + num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) + if num_cuda_work_queues != 1: + warnings.warn( + "To guarantee overlapping TP and SP collectives with the backward" + "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" + ) + + @staticmethod + def grad_output_preprocess( + ctx, + grad_output: torch.Tensor, + row_parallel_mode: bool, + quantizer: Optional[Quantizer], + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Utility function for backward. + Returns tuple in order (all optional/None based on training precion/recipe): + R1: gathered `grad_output`. + R2: bias gradient on R1. + + """ + grad_output = grad_output.reshape((-1, grad_output.shape[-1])) + grad_output = grad_output.contiguous() + gather_grad_output = row_parallel_mode and ctx.sequence_parallel + + # Non-FP8 case: bgrad is fused with wgrad for this case. + if not ctx.fp8 and not ctx.debug: + if gather_grad_output: + if not ctx.ub_overlap_ag: # Perform NCCL all-gather + grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) + else: # Initialize Userbuffers all-gather + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ctx.ub_obj_gradout, + grad_output, + None, + ctx.tp_group, + ) + return grad_output, None + + # FP8 with all-gather: unfused bgrad, fused cast + transpose + # Also supports debug quantization, which is handled inside gather_along_first_dim. + if gather_grad_output: + grad_bias = None + if ctx.use_bias: + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + if ctx.ub_overlap_ag: + # Quantize the gradient if needed + if not isinstance( + grad_output, + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), + ): + grad_output = quantizer(grad_output) + + # Copy into communication buffer, and replace original gradient with it + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ctx.ub_obj_gradout, + grad_output, + quantizer, + ctx.tp_group, + ) + else: + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=quantizer, + ) + return grad_output, grad_bias + + # Debug without all-gather: unfused cast and bgrad + # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None + if ctx.debug: + grad_output_ = quantizer(grad_output) + if ( + isinstance( + grad_output_.get_tensor(True), + ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, + ), + ) + and ctx.use_bias + ): + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias = None + grad_output = grad_output_ + return grad_output, grad_bias + + # FP8 without all-gather: fused bgrad + cast + transpose + grad_bias = None + if ctx.use_bias: + if isinstance( + grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ): + grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) + else: + # TODO(ksivaman): Re-add fusion once kernel is available. + if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): + # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. 
+ grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance(grad_output, QuantizedTensorBase): + grad_output = quantizer(grad_output) + return grad_output, grad_bias + + def register_parameter(self, name, param, **kwargs): + """ + Thin wrapper around PyTorch parameter registration to stash additional parameter + metedata used in deferred initialization. + """ + super().register_parameter(name, param) + self.param_init_meta[name] = _ParameterInitMeta(**kwargs) + + def reset_parameters(self, defer_init: Optional[bool] = False) -> None: + """ + Reset all module parameters to initial values. Unless deferred initialization + is specified, all parameters on a 'meta' device are also materialized on a real cuda + device before the values are reset to initial. + """ + if defer_init: + return + + for name, param in self.named_parameters(recurse=False): + # Ensure parameter is on a real device + if param.device == torch.device("meta"): + param = torch.empty_like(param, device="cuda") + + # Initialize the parameter values on device + init_fn = self.param_init_meta[name].init_fn + get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker + if get_rng_state_tracker is None: + init_fn(param) + else: + if hasattr(self, "rng_tracker_name") and self.rng_tracker_name: + with get_rng_state_tracker().fork(self.rng_tracker_name): + init_fn(param) + else: + with get_rng_state_tracker().fork(): + init_fn(param) + + # Wrap parameters in QuantizedTensor if needed + fp8_meta_index = self.param_init_meta[name].fp8_meta_index + high_precision_init_val = None + if self.primary_weights_in_fp8 and fp8_meta_index is not None: + + # Keep high-precision values on CPU if needed + if self.preserve_high_precision_init_val: + high_precision_init_val = param.detach().cpu() + + # Configure quantizer + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + if quantizer is None: + raise RuntimeError("Weight quantizer has not been initialized") + quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + quantizer.internal = False + + # Quantize parameter + param = quantizer(param) + + # Redo parameter wrap in case we broke it above + # NOTE: Currently this can only be broken when primary weights are in Fp8 but + # re-applying the nn.Parameter() wrap is a no-op when the input is already + # a parameter so we always re-apply it just for extra safety. + param = torch.nn.Parameter(param) + + # Keep high-precision values on CPU if needed + if high_precision_init_val is not None: + + # - Master weights are initialized from model weights, if we use fp8 primary + # weights to initialize master weights, the numerical values of master weights + # are not consistent with the numerical values when we initialize them from + # bf16/fp16 weights. + # - So we add a `_high_precision_init_val` attribute to each model weight to store + # the original bf16/fp16 weight on cpu before casting it to fp8. And users can + # use `get_high_precision_init_val` to get this cpu tensor. + # - This cpu tensor is not needed once the master weight is initialized, so users + # should call `clear_high_precision_init_val` to remove it after master weight + # is initialized. 
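The `get_high_precision_init_val` / `clear_high_precision_init_val` accessors described above are attached to each parameter just below. A sketch of how an optimizer-side master-weight initialization might consume them; the helper name and the fp32-master convention are assumptions, not part of this module.

    import torch

    def build_master_weights(model: torch.nn.Module) -> dict:
        """Create fp32 master weights, preferring the preserved bf16/fp16 init values."""
        master = {}
        for name, param in model.named_parameters():
            hp_init = None
            if hasattr(param, "get_high_precision_init_val"):
                hp_init = param.get_high_precision_init_val()
            if hp_init is not None:
                # Use the original high-precision init instead of dequantizing FP8.
                master[name] = hp_init.to(device=param.device, dtype=torch.float32)
                param.clear_high_precision_init_val()  # drop the CPU copy once consumed
            else:
                # Fall back to dequantizing the (possibly FP8) parameter itself.
                master[name] = param.detach().clone().float()
        return master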
+ + def get(self): + if hasattr(self, "_high_precision_init_val"): + return self._high_precision_init_val + return None + + def clear(self): + if hasattr(self, "_high_precision_init_val"): + del self._high_precision_init_val + + param._high_precision_init_val = high_precision_init_val + param.get_high_precision_init_val = MethodType(get, param) + param.clear_high_precision_init_val = MethodType(clear, param) + + setattr(self, name, param) + + @abstractmethod + def forward(self): + """Needs override.""" + + def get_weight_workspace( + self, + *, + tensor: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + cache_name: Optional[str] = None, + update_workspace: bool = True, + skip_update_flag: Optional[torch.Tensor] = None, + fsdp_group: Optional[dist_group_type] = None, + workspace_dtype: Optional[torch.dtype] = None, + ) -> QuantizedTensor: + """Get workspace buffer for weights and maybe update its values + + The workspace buffer may be cached for future function calls. + + Parameters + ---------- + tensor : torch.Tensor, optional + Values to copy into workspace. Required if the workspace + is being constructed or updated. + quantizer: Quantizer, optional + Quantizer used to cast the weights. Required if the + workspace is being constructed or updated. + cache_name: str, optional + Key for caching. + update_workspace: bool, default = `True` + Update workspace with values from `tensor`. + skip_update_flag: torch.Tensor, optional + GPU flag to skip updating the workspace. Take precedence + over `update_workspace` if provided. + fsdp_group: bool, default = None + FSDP process group that the weights are distributed over. + workspace_dtype: torch.dtype, default = None + If weight workspace contains high-precision tensor - for example + for debug quantization, this is dtype of the tensor. + """ + + # Handle case where weights are already quantized + # Note: Make sure weights have required usages, but do not + # destroy unnecessary usages since they may be used later. + if isinstance(tensor, QuantizedTensor): + update_rowwise_usage = True if quantizer.rowwise_usage else None + update_columnwise_usage = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise_usage, + columnwise_usage=update_columnwise_usage, + ) + return tensor + + # Try getting workspace from cache + out = None + if cache_name is not None: + out = self._fp8_workspaces.get(cache_name, None) + + # Reset cache if workspace is invalid + if out is not None and quantizer is not None: + reset_cache = False + if isinstance(out, Float8TensorBase): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and out._transpose is None + ): + reset_cache = True + elif isinstance(out, MXFP8TensorBase): + if quantizer.rowwise_usage and out._rowwise_data is None: + reset_cache = True + elif quantizer.columnwise_usage and out._columnwise_data is None: + reset_cache = True + if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): + reset_cache = True + if reset_cache: + out = None + del self._fp8_workspaces[cache_name] + + # Gather cached Fp8 workspace if it's distributed + # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work + # for models initialized with Fp8 primary weights. 
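For context, this is the caller pattern used by the linear modules (mirrored from the fprop path in layernorm_linear.py later in this patch); the surrounding names (`module`, `weight`, `weight_quantizer`, `is_first_microbatch`, `skip_fp8_weight_update`, `fsdp_group`, `activation_dtype`) come from that calling scope.

    # Quantize the weight once per step; later microbatches reuse the cached workspace.
    update_workspace = is_first_microbatch is None or is_first_microbatch
    weightmat = module.get_weight_workspace(
        tensor=weight,
        quantizer=weight_quantizer,
        cache_name=None if is_first_microbatch is None else "weight",
        update_workspace=update_workspace,
        skip_update_flag=skip_fp8_weight_update,
        fsdp_group=fsdp_group,
        workspace_dtype=activation_dtype,
    )
    weightmat.update_usage(rowwise_usage=True)  # row-wise data is enough for fprop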
+ if ( + out is not None + and tensor is not None + and fsdp_group is not None + and out.data.shape != tensor.data.shape + ): + _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + + # Construct workspace if needed + if out is None: + if tensor is None or quantizer is None: + raise ValueError( + "tensor and quantizer kwargs must be provided to construct FP8 workspace" + ) + + if cache_name is not None: + # Ensure the tensor in the cache is an instance of torch.Tensor, + # as it persists beyond a single forward pass. + # Setting internal=True would cause the data to be removed in prepare_for_saving(...). + quantizer_internal = quantizer.internal + quantizer.internal = False + out = quantizer.quantize(tensor, dtype=workspace_dtype) + if cache_name is not None: + quantizer.internal = quantizer_internal + + # Update cache + if cache_name is not None: + self._fp8_workspaces[cache_name] = out + return out + + # Update workspace if needed + if skip_update_flag is not None: + update_workspace = True + if update_workspace: + if tensor is None: + raise ValueError("tensor kwarg must be provided to update FP8 workspace") + if hasattr(out, "quantize_"): + out.quantize_(tensor, noop_flag=skip_update_flag) + else: + tex.quantize(tensor, quantizer, out, skip_update_flag) + return out + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + """ + This function loads tensors and extra state including fp8 metadata. + This metadata is essential for copying fp8 tensors, as the copy_ function + uses the scale_inv parameter from fp8_meta to set the correct scaling factor + for the new tensor. + Hence, this extra state must be loaded before the tensor copying process, + not after, as is typically done in _load_from_state_dict. + Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True, + otherwise, this behavior is not required. + """ + if self.primary_weights_in_fp8: + extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): + """ + This method is used to manually control the weight gradient accumulation and reduce. + This method should be called before the backward() method. + Set the skip_wgrad_accumulation_and_reduce to True to skip the weight gradient accumulation + and reduce in backward(); + And register the wgrad_accumulation_and_reduce_func to be called in backward_dw() method. + """ + self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) + + def backward_dw(self): + """ + Execute the delayed weight gradient computation. + This method is called after the main backward pass to compute weight gradients. 
+ """ + if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + return + with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): + (wgrad, bgrad), _ = self.wgrad_store.pop() + if not self.fuse_wgrad_accumulation: + weight_tensor = noop_cat(self._get_weight_tensors()) + weight_tensor.grad = wgrad.to(weight_tensor.dtype) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + if bias_tensor.grad is None: + bias_tensor.grad = bgrad.to(bias_tensor.dtype) + del wgrad + del bgrad + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() + + def is_debug_iter(self) -> bool: + """ + This function checks if the debug should be enabled for this layer. + """ + debug = TEDebugState.debug_enabled + if not debug: + return False + self._validate_name() + + # If layer is run first time in new iteration, + # we need to check if the debug should be enabled for this layer - + # maybe in previous iterations debug features returned information + # that no feature will be active for this layer for multiple next iterations. + started_new_iteration = TEDebugState.get_iteration() != getattr( + self, "debug_last_iteration", None + ) + if started_new_iteration: + if self.next_iter_when_debug_should_be_run is None: + debug = False + else: + debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run + self.debug_last_iteration = TEDebugState.get_iteration() + return debug + + def no_debug_features_active(self, quantizers): + """ + Checks if any debug feature is active for this layer. + """ + run_current = any_feature_enabled(quantizers) + + # Sometimes features inform that they will not be enabled for particular layer + # for multiple next iterations. + self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + + if not run_current: + return True + + if self.primary_weights_in_fp8: + raise RuntimeError("FP8 weights are not supported in debug mode.") + return False + + def _validate_name(self): + """ + Validate name passed to the module. + This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. + If no name is assigned, it creates a default name with layer count as the variable. + """ + if self.name is not None: + return + assert TEDebugState.debug_enabled + import nvdlfw_inspect.api as debug_api + + if self.name is None: + debug_api.log_message( + "Names are not provided to debug modules. ", + "Creating and using generic names. Pass names to debug modules for better" + " insight. ", + level=logging.WARNING, + ) + self.name = f"Layer_{TEDebugState.get_layer_count()}" + + def _check_weight_tensor_recipe_correspondence(self) -> None: + """ + Verify that the weight tensor types match their corresponding recipe type. + This is invoked in the forward(). + + This establishes a 1:1 correspondence between recipe types and tensor types: + - DelayedScaling → Float8Tensor + - Float8CurrentScaling → Float8Tensor + - MXFP8BlockScaling → MXFP8Tensor + - Float8BlockScaling → Float8BlockTensor + + Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). 
+ """ + if not self.fp8 and not self.fp8_calibration: + return + if not hasattr(self, "weight_names") or not self.weight_names: + return + + recipe = self.fp8_meta["recipe"] + weight_tensors = [getattr(self, name) for name in self.weight_names] + for i, tensor in enumerate(weight_tensors): + if isinstance(tensor, QuantizedTensorBase): + quantizer = tensor._get_quantizer() + if quantizer is None: + continue + compatible_recipe_class = quantizer._get_compatible_recipe() + if compatible_recipe_class is None: + continue + if not isinstance(recipe, compatible_recipe_class): + raise RuntimeError( + f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" + f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." + " Please check the recipes assigned during fp8_model_init() and" + " fp8_autocast() calls." + ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6dbbd335eb..b77cd0bcda 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -8,6 +8,7 @@ from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op +import os import torch from torch.nn import init @@ -73,6 +74,9 @@ from ..cpp_extensions import ( general_gemm, + ubsymm_request_allocator, + ubsymm_get_sym_tensor, + ubsymm_allreduce, ) __all__ = ["LayerNormLinear"] @@ -128,6 +132,7 @@ def forward( module: torch.nn.Module, skip_fp8_weight_update: bool, symmetric_ar_type: str, + skip_layernorm: bool = False, debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -204,20 +209,25 @@ def forward( ) # Apply normalization - nvtx_range_push(f"{nvtx_label}.norm") - ln_out, mu, rsigma = apply_normalization( - inputmat, - None, # ln_out - ln_weight, - ln_bias, - eps, - input_quantizer if with_quantized_norm else None, - inputmat.dtype, - normalization, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - nvtx_range_pop(f"{nvtx_label}.norm") + if skip_layernorm: + ln_out = inputmat + mu = None + rsigma = None + else: + nvtx_range_push(f"{nvtx_label}.norm") + ln_out, mu, rsigma = apply_normalization( + inputmat, + None, # ln_out + ln_weight, + ln_bias, + eps, + input_quantizer if with_quantized_norm else None, + inputmat.dtype, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + ) + nvtx_range_pop(f"{nvtx_label}.norm") # Store unquantized layer norm output if we need to return it ln_out_return = None @@ -335,7 +345,15 @@ def forward( out_shape[0] //= tp_world_size out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) - + symm_out = None + if symmetric_ar_type is not None and symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + out_shape_list = list(tuple(inp.shape)) + out_shape_list[-1] = out_features + symm_out = ubsymm_get_sym_tensor( + torch.Size(out_shape_list), + activation_dtype, + tp_group, + ) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -352,6 +370,7 @@ def forward( ub=ub_obj, ub_type=ub_type, extra_output=reduce_scatter_out, + out=symm_out, ) nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ @@ -380,7 +399,17 @@ def forward( out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: if symmetric_ar_type 
is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + if symm_out is not None: + out = ubsymm_allreduce(symm_out) + else: + fallback_symmetric = ( + "multimem_all_reduce" + if symmetric_ar_type.startswith("ubnext") + else symmetric_ar_type + ) + out, _ = symmetric_all_reduce( + out, tp_group, all_reduce_type=fallback_symmetric + ) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1046,6 +1075,7 @@ def wgrad_gemm( None, # module None, # skip_fp8_weight_update None, # symmetric_ar_type + None, # skip_layernorm ) @@ -1175,6 +1205,7 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, + skip_layernorm: bool = False, name: str = None, ) -> None: super().__init__() @@ -1268,7 +1299,15 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - + if self.symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + ubsymm_request_allocator( + self.tp_group, + ( + int(os.environ.get("NVTE_UB_MAXBATCH", 64)), + self.out_features, + ), + params_dtype, + ) self.eps = eps layer_norm_weight = torch.nn.Parameter( torch.empty(self.in_features, device=device, dtype=params_dtype) @@ -1287,7 +1326,7 @@ def __init__( ) else: self.layer_norm_bias = None - + self.skip_layernorm = skip_layernorm # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1599,6 +1638,7 @@ def forward( self, skip_fp8_weight_update, self.symmetric_ar_type, + self.skip_layernorm, debug, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py.orig b/transformer_engine/pytorch/module/layernorm_linear.py.orig new file mode 100644 index 0000000000..6dbbd335eb --- /dev/null +++ b/transformer_engine/pytorch/module/layernorm_linear.py.orig @@ -0,0 +1,1827 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
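Relative to the layernorm_linear.py changes above, a usage sketch for the new ubnext path on a row-parallel projection. `tp_group`, `tp_size`, and `hidden_states` are assumed to be set up by the caller, the feature sizes are illustrative, and the exact set of accepted `symmetric_ar_type` strings is defined by the allocator (the module only checks the "ubnext" prefix).

    import os
    import torch
    import transformer_engine.pytorch as te

    # Upper bound (in rows) for the symmetric output buffer requested at init time.
    os.environ.setdefault("NVTE_UB_MAXBATCH", "64")

    proj = te.LayerNormLinear(
        in_features=8192,
        out_features=8192,
        params_dtype=torch.bfloat16,
        parallel_mode="row",
        tp_group=tp_group,
        tp_size=tp_size,
        symmetric_ar_type="ubnext",  # falls back to multimem_all_reduce if no symmetric buffer is available
        skip_layernorm=True,  # new flag: input is already normalized
    )
    out = proj(hidden_states)  # GEMM writes into symmetric memory, then ubsymm_allreduce reduces it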
+ +"""LayerNormLinear API""" +import os +import warnings +from typing import Callable, Dict, Optional, Tuple, Union, List +from functools import reduce +from operator import mul as multiply_op + +import torch +from torch.nn import init + +import transformer_engine_torch as tex + +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental +from .base import ( + fill_userbuffers_buffer_for_all_gather, + get_workspace, + get_ub, + TransformerEngineBaseModule, + get_dummy_wgrad, + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, +) +from ..fp8 import FP8GlobalStateManager +from ..utils import ( + assert_dim_for_fp8_exec, + assert_dim_for_all_gather, + cast_if_needed, + clear_tensor_data, + divide, + get_default_init_method, + init_method_constant, + nvtx_range_pop, + nvtx_range_push, + requires_grad, + needs_quantized_gemm, +) +from ..distributed import ( + set_tensor_model_parallel_attributes, + get_distributed_world_size, + allreduce, + symmetric_all_reduce, + reduce_scatter_along_first_dim, + gather_along_first_dim, + in_fp8_activation_recompute_phase, + _fsdp_scatter_tensors, + _fsdp_gather_tensors, +) +from ..constants import GemmParallelModes, dist_group_type +from ..jit import no_torch_dynamo +from ..graph import is_graph_capturing +from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers +from ..tensor.quantized_tensor import ( + QuantizedTensor, + QuantizedTensorBase, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ...debug.pytorch.debug_state import TEDebugState +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..export import is_in_onnx_export_mode, assert_warmed_up +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload + +from ..cpp_extensions import ( + general_gemm, +) + +__all__ = ["LayerNormLinear"] + + +class _LayerNormLinear(torch.autograd.Function): + """LayerNormLinear semi-top level module + Calls custom cuda extensions. 
+ """ + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + ln_weight: torch.Tensor, + ln_bias: Union[torch.Tensor, None], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + is_first_microbatch: Union[bool, None], + fp8: bool, + fp8_calibration: bool, + wgrad_store: WeightGradStore, + fuse_wgrad_accumulation: bool, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + cpu_offloading: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, + sequence_parallel: bool, + tensor_parallel: bool, + activation_dtype: torch.dtype, + parallel_mode: Union[str, None], + return_layernorm_output: bool, + return_layernorm_output_gathered: bool, + is_grad_enabled: bool, + fwd_ln_sm_margin: int, + bwd_ln_sm_margin: int, + zero_centered_gamma: bool, + normalization: str, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_wgrad: bool, + ub_bulk_dgrad: bool, + ub_name: str, + fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, + symmetric_ar_type: str, + debug: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + + with_input_all_gather = parallel_mode == "column" and sequence_parallel + + # Make sure input dimensions are compatible + out_features, in_features = weight.shape + inp_shape = inp.shape + inp_requires_grad = inp.requires_grad + assert inp_shape[-1] == in_features, "GEMM not possible" + inp = inp.view((-1, in_features)) + inputmat = inp + if fp8: + assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) + + # Cast for native AMP + nvtx_range_push(f"{nvtx_label}.norm_input_cast") + inputmat = cast_if_needed(inputmat, activation_dtype) + ln_weight = cast_if_needed(ln_weight, activation_dtype) + if ln_bias is not None: + ln_bias = cast_if_needed(ln_bias, activation_dtype) + nvtx_range_pop(f"{nvtx_label}.norm_input_cast") + + tp_world_size = get_distributed_world_size(tp_group) + + weight_requires_grad = weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad + + # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_ag_fprop = False + ub_overlap_rs_fprop = False + ub_overlap_ag_dgrad = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False + ub_obj = None + ub_type = None + ub_overlap_ag_fprop = ( + ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output + ) + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop", fp8) + ub_type = tex.CommOverlapType.RS + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop", fp8) + ub_type = tex.CommOverlapType.AG + + # Configure quantizer for norm output + if fp8: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): + # All-gather is not supported with FP8 column-wise data + 
input_quantizer.set_usage(columnwise=False) + + # Avoid quantized norm kernel if norm output will be returned + # or if a gather of ln_out must be in high precision. + experimental = is_experimental(input_quantizer) + with_quantized_norm = ( + fp8 + and not debug + and not return_layernorm_output + and not return_layernorm_output_gathered + and not experimental + ) + + # Apply normalization + nvtx_range_push(f"{nvtx_label}.norm") + ln_out, mu, rsigma = apply_normalization( + inputmat, + None, # ln_out + ln_weight, + ln_bias, + eps, + input_quantizer if with_quantized_norm else None, + inputmat.dtype, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + ) + nvtx_range_pop(f"{nvtx_label}.norm") + + # Store unquantized layer norm output if we need to return it + ln_out_return = None + if return_layernorm_output or return_layernorm_output_gathered: + ln_out_return = ln_out + + # ------------------------------------------------------ + # Prepare GEMM input tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + # ------------------------------------------------------ + nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") + ln_out_total = None + if with_input_all_gather: + if return_layernorm_output_gathered: + # Perform all-gather in high precision if gathered + # norm output will be returned + ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) + ln_out_return = ln_out_total + if fp8 or debug: + ln_out = input_quantizer(ln_out) + input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(input_quantizer, Float8BlockQuantizer): + input_quantizer.all_gather_usage = False + ln_out_total = input_quantizer(ln_out_total) + else: + quantizer = None + if fp8 or debug: + quantizer = input_quantizer + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: + ln_out = quantizer(ln_out) + quantizer.set_usage(rowwise=True, columnwise=False) + if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather + ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj, + ln_out, + quantizer, + tp_group, + ) + else: # Perform NCCL all-gather + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=quantizer, + ) + else: + if (fp8 or debug) and not with_quantized_norm: + ln_out = input_quantizer(ln_out) + ln_out_total = ln_out + nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") + # ------------------------------------------------------ + # GEMM input tensor is ready... 
+ # ------------------------------------------------------ + + # ------------------------------------------------------ + # Prepare weight tensor + # ------------------------------------------------------ + weightmat = weight + quantized_weight = False + if fp8 or debug: + quantized_weight = not isinstance(weight, QuantizedTensorBase) + + # Configure quantizer + if weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + + # Get quantized weight + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) + weightmat.update_usage(rowwise_usage=True) + + else: + weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP + # ------------------------------------------------------ + # Weight tensor is ready for GEMM... + # ------------------------------------------------------ + + # Cast bias to expected dtype + bias_dtype = activation_dtype + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: + # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16 + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(ln_out_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + # Choose whether to use GEMM kernel with split accumulator + use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Output buffer for Userbuffers reduce-scatter + reduce_scatter_out = None + if ub_overlap_rs_fprop: + out_shape = list(inp_shape) + out_shape[0] //= tp_world_size + out_shape[-1] = out_features + reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) + + # ------------------------------------------------------ + # Forward GEMM + # Note: y = x * w^T + # ------------------------------------------------------ + nvtx_range_push(f"{nvtx_label}.gemm") + gemm_out, *_, reduce_scatter_out = general_gemm( + weightmat, + ln_out_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=activation_dtype, + bias=bias, + use_split_accumulator=use_split_accumulator, + ub=ub_obj, + ub_type=ub_type, + extra_output=reduce_scatter_out, + ) + nvtx_range_pop(f"{nvtx_label}.gemm") + # ------------------------------------------------------ + # Finished forward GEMM... 
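Numerically, the forward GEMM above computes y = x · wᵀ (+ bias). The sketch below is a plain-PyTorch reference of that contract, for orientation only; `general_gemm` additionally handles FP8 quantization, bias fusion, split accumulation, and Userbuffers overlap.

import torch

def fprop_reference(x: torch.Tensor, w: torch.Tensor, bias=None) -> torch.Tensor:
    """x: [tokens, in_features], w: [out_features, in_features] -> [tokens, out_features]."""
    y = x @ w.t()
    return y if bias is None else y + bias
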
+ # ------------------------------------------------------ + + # Deallocate GEMM input tensor if no longer needed + if not weight.requires_grad and not return_layernorm_output: + clear_tensor_data(ln_out, ln_out_total) + ln_out = ln_out_total = None + elif with_input_all_gather and not return_layernorm_output_gathered: + clear_tensor_data(ln_out_total) + ln_out_total = None + + # ------------------------------------------------------ + # Prepare output tensor + # Note: Perform tensor-parallel communication + # ------------------------------------------------------ + out = None + if ub_overlap_rs_fprop: + out = reduce_scatter_out + elif parallel_mode == "row" and tp_size > 1: + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") + out = gemm_out + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") + else: + out = gemm_out + out = out.view(-1, *inp_shape[1:-1], out_features) + # ------------------------------------------------------ + # Output tensor is ready to return... + # ------------------------------------------------------ + + # ------------------------------------------------------ + # Cache state for backward pass + # ------------------------------------------------------ + + if is_grad_enabled: + ctx.weight_quantizer = weight_quantizer + ctx.ln_out_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) + + # Input with column-wise usage is needed for wgrad GEMM. + if backward_needs_input: + if isinstance(ln_out, QuantizedTensorBase): + # For sequence parallel in vanilla FP8, rowwise data is + # to gather the input. For MXFP8, columnwise only data + # can be allgathered. + if ( + isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) + or not ctx.ln_out_needs_gather + ): + ln_out.update_usage(rowwise_usage=False) + + # Weight with column-wise usage is needed for dgrad GEMM. + if isinstance(weightmat, QuantizedTensorBase): + weightmat.update_usage(columnwise_usage=True) + + if cpu_offloading: + mark_activation_offload(inputmat, mu, rsigma, ln_out) + + # Scatter intermediate/activation tensors saved for the backward pass + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") + ctx.fsdp_group = fsdp_group + ctx.fsdp_shapes = _fsdp_scatter_tensors( + fsdp_group, + mu, + rsigma, + weightmat if quantized_weight else None, + ln_out if weight.requires_grad else None, + ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. 
Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_object = weight + + tensors_to_save, tensor_objects = prepare_for_saving( + inputmat, + weightmat, + weight, + bias, + ln_weight, + ln_out, + mu, + rsigma, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.requires_dgrad = inp_requires_grad + ctx.requires_wgrad = weight.requires_grad + ctx.quantized_weight = quantized_weight + if fuse_wgrad_accumulation and weight.requires_grad: + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(weight, "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.main_grad_func = weight.get_main_grad + else: + ctx.main_grad_func = lambda: weight.main_grad + ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.input_quantizer = input_quantizer + ctx.owns_input = inputmat is not inp + ctx.weight = weight + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = cpu_offloading + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = bias is not None + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.inp_shape = inp_shape + ctx.parallel_mode = parallel_mode + ctx.tp_group = tp_group + ctx.tp_size = tp_size + ctx.return_layernorm_output = return_layernorm_output + ctx.return_layernorm_output_gathered = return_layernorm_output_gathered + ctx.bwd_ln_sm_margin = bwd_ln_sm_margin + ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_name = ub_name + ctx.requires_dgrad = inp_requires_grad + ctx.normalization = normalization + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store + ctx.debug = debug + + # ------------------------------------------------------ + # Cached state for backward pass is ready... + # ------------------------------------------------------ + + if return_layernorm_output: + if return_layernorm_output_gathered: + shape = list(inp_shape) + shape[0] *= tp_size if with_input_all_gather else 1 + return out, ln_out_return.view(shape) + return out, ln_out_return.view(inp_shape) + return out + + @staticmethod + def backward( + ctx, *grad_outputs: Tuple[torch.Tensor, ...] 
+ ) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking + inputmat, + weight, + origin_weight, + bias, + ln_weight, + ln_out, + mu, + rsigma, + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad_func() + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) + + # Gather intermediate/activation tensors if needed + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") + _fsdp_gather_tensors( + ctx.fsdp_group, + ctx.fsdp_shapes, + mu, + rsigma, + weight if ctx.fp8 and ctx.quantized_weight else None, + ln_out, + ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + origin_weight = ctx.weight_object + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + origin_weight.main_grad = main_grad + + # Configure Userbuffers communication (comm+GEMM overlap) + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_type_wgrad = tex.CommOverlapType.RS + + # -------------------------------------------------- + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + # -------------------------------------------------- + + # Configure quantizer for grad output tensor + # Note: dgrad GEMM requires row-wise usage, wgrad GEMM + # requires column-wise usage + if ctx.grad_output_quantizer is not None: + quantizer = ctx.grad_output_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.ub_overlap_ag: + # Userbuffers only supports communication for one + # tensor usage at a time. Configure quantizer with + # usage for only dgrad GEMM. 
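The reason both usages are requested above: the dgrad GEMM consumes grad_output row-wise, while the wgrad GEMM consumes it column-wise (transposed). A plain-PyTorch reference of the two backward GEMMs, matching the `dx = dy * w` and `dw = dy^T * x` notes further below, for orientation only:

import torch

def backward_reference(x: torch.Tensor, w: torch.Tensor, dy: torch.Tensor):
    dx = dy @ w          # dgrad:  dx = dy * w    (row-wise use of dy)
    dw = dy.t() @ x      # wgrad:  dw = dy^T * x  (column-wise use of dy)
    db = dy.sum(dim=0)   # bgrad
    return dx, dw, db
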
+ quantizer.set_usage(columnwise=False) + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") + ( + grad_output, + grad_bias, + ) = TransformerEngineBaseModule.grad_output_preprocess( + ctx, + grad_outputs[0], + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, + ) + nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") + + # -------------------------------------------------- + # Grad output tensor is ready for computing grad input... + # -------------------------------------------------- + + # -------------------------------------------------- + # Prepare GEMM input tensor + # Note: Input tensor is needed for wgrad GEMM. + # Tensor-parallel communication is overlapped with dgrad + # GEMM. + # -------------------------------------------------- + ln_out_total = None + ln_out_total_work = None + if ctx.ln_out_needs_gather: + quantizer = None + if ctx.input_quantizer is not None: + quantizer = ctx.input_quantizer + if quantizer.supports_only_rowwise_all_gather(): + # If data is in FP8, we compute FP8 transposes manually + quantizer.set_usage(rowwise=True, columnwise=False) + else: + # wgrad GEMM requires input with column-wise usage + quantizer.set_usage(rowwise=False, columnwise=True) + if ctx.ub_bulk_dgrad: + ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_dgrad, + ln_out, + quantizer, + ctx.tp_group, + ) + else: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") + else: + ln_out_total = ln_out + # -------------------------------------------------- + # Input tensor is ready for computing grad weight... + # -------------------------------------------------- + + # -------------------------------------------------- + # Compute grad input tensor + # Note: Gradient w.r.t. GEMM input (i.e. norm output). 
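The input gather above is issued with `async_op=True` so the communication overlaps with the dgrad GEMM and is only waited on right before the wgrad GEMM consumes the gathered tensor. A minimal sketch of that pattern, assuming an initialized process group:

import torch
import torch.distributed as dist

def overlap_gather_with_dgrad(local_inp, run_dgrad, run_wgrad, group=None):
    tp = dist.get_world_size(group=group)
    gathered = torch.empty((tp * local_inp.size(0), *local_inp.shape[1:]),
                           dtype=local_inp.dtype, device=local_inp.device)
    work = dist.all_gather_into_tensor(gathered, local_inp, group=group, async_op=True)
    dgrad = run_dgrad()   # NCCL gather proceeds concurrently with this GEMM
    work.wait()           # gathered input must be complete before wgrad
    return dgrad, run_wgrad(gathered)
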
+ # -------------------------------------------------- + + # Make sure required data is available + if isinstance(grad_output, QuantizedTensorBase): + grad_output.update_usage(rowwise_usage=True) + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): + weight.update_usage(columnwise_usage=True) + + # Choose whether to use GEMM kernel with split accumulator + use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + + # Update grad input quantizer + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + # Output buffers for Userbuffers reduce-scatter + gemm_out = None + reduce_scatter_out = None + if ctx.ub_overlap_rs_dgrad: + reduce_scatter_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device + ) + elif ctx.ub_bulk_wgrad: + gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) + + # dgrad GEMM + # Note: dx = dy * w + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + gemm_out, *_, reduce_scatter_out = general_gemm( + weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=gemm_out, + out_dtype=ctx.activation_dtype, + use_split_accumulator=use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=reduce_scatter_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # Prepare grad input tensor + # Note: Perform tensor-parallel communication + dgrad = None + dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + dgrad = reduce_scatter_out + elif ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) + elif ctx.parallel_mode == "column" and ctx.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + dgrad = gemm_out + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, + ) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + else: + dgrad = gemm_out + + # -------------------------------------------------- + # Grad input tensor has been computed... + # -------------------------------------------------- + + # -------------------------------------------------- + # Compute grad weight + # -------------------------------------------------- + + wgrad = None + if ctx.requires_wgrad: + # Prepare grad output tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + # UB does not support pipelined overlapping grad output + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() + + # This object is separate from the ub_obj_wgrad object which is passed to the GEMM + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + + # We use the send stream to copy into the userbuffers. 
+ # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. + with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, + grad_outputs[0], + ctx.grad_output_quantizer, + ctx.tp_group, + ) + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) + + # Prepare input tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if ctx.fp8 or ctx.debug: + if isinstance(ln_out_total, QuantizedTensorBase): + ln_out_total.update_usage(columnwise_usage=True) + else: + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out_total = ctx.input_quantizer(ln_out_total) + + if ctx.fp8 or ctx.debug: + if isinstance(grad_output, QuantizedTensorBase): + grad_output.update_usage(columnwise_usage=True) + else: + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + grad_output = ctx.grad_output_quantizer(grad_output) + + # Figure out whether to use split accumulator + use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + + # Figure out whether to output wgrad GEMM directly into main grad + if ctx.is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + # Output buffer for overlapping FP8 grad input + # reduce-scatter with wgrad GEMM + reduce_scatter_out = None + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + reduce_scatter_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device + ) + + # Arguments to include in wgrad GEMM closure + wgrad_gemm_kwargs = { + "workspace": get_workspace(), + "out_dtype": ( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + "quantization_params": ctx.grad_weight_quantizer, + "accumulate": accumulate_wgrad_into_param_main_grad, + "layout": "NT", + "out": main_grad if ctx.fuse_wgrad_accumulation else None, + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "use_split_accumulator": use_split_accumulator, + "grad": True, + "ub": ub_obj_wgrad, + "ub_type": ub_type_wgrad, + "extra_output": reduce_scatter_out, + "bulk_overlap": ctx.ub_bulk_wgrad, + } + + def wgrad_gemm( + x: torch.Tensor, + dy: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform wgrad GEMM: dw = dy^T * x + + May be fused with bgrad computation. + + May be called outside of this function to enable + some advanced communication/compute overlapping. 
+ + """ + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") + return dw, db + + # Choose whether to call wgrad GEMM now or delay + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + if ( + wgrad_gemm_kwargs["ub"] is not None + or wgrad_gemm_kwargs["ub_type"] is not None + or wgrad_gemm_kwargs["extra_output"] is not None + or wgrad_gemm_kwargs["bulk_overlap"] + ): + raise NotImplementedError( + "Delayed weight grad computation is not supported " + "with Userbuffers (tensor-parallel communication overlapping)" + ) + ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm) + else: + + # Call wgrad GEMM now + wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) + + # Update grad bias if needed + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensors if permitted + if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + # Input tensors have not been exposed externally + clear_tensor_data(ln_out) + elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: + # Non-gathered input has not been exposed externally + clear_tensor_data(ln_out) + if ctx.ln_out_needs_gather: + # Gathered input is internal + clear_tensor_data(ln_out_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) + + # Update grad input if overlapping reduce-scatter with wgrad GEMM + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = reduce_scatter_out + else: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() + + # -------------------------------------------------- + # Grad weight has been computed... + # -------------------------------------------------- + + # Don't return grad bias if not needed + if not ctx.use_bias: + grad_bias = None + + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None + + # Residual gradient + dgrad = dgrad.view(inputmat.shape) + if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + dgrad = dgrad + grad_outputs[1].view_as(dgrad) + + # Norm gradient + dgamma = None + dbeta = None + nvtx_range_push(f"{nvtx_label}.norm") + if ctx.normalization == "LayerNorm": + dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, + inputmat, + mu, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, + ) + dgrad = dgrad.reshape(inputmat.size()) + elif ctx.normalization == "RMSNorm": + dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, + inputmat, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, + ) + dgrad = dgrad.reshape(inputmat.size()) + dbeta = None + nvtx_range_pop(f"{nvtx_label}.norm") + clear_tensor_data(mu) + clear_tensor_data(rsigma) + + if ctx.requires_wgrad: + # Handle custom DDP from mcore. 
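With Megatron-core style DDP, the wgrad GEMM accumulates directly into a pre-allocated `weight.main_grad` buffer, so autograd only needs a placeholder gradient and `grad_added_to_main_grad` signals that the real value already lives in `main_grad`. A minimal sketch of that contract (illustrative, not the Megatron-core implementation):

import torch

def accumulate_into_main_grad(weight: torch.nn.Parameter, wgrad: torch.Tensor) -> torch.Tensor:
    if not hasattr(weight, "main_grad"):
        weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)
    weight.main_grad += wgrad                 # fused accumulation target
    weight.grad_added_to_main_grad = True     # tell the DDP wrapper not to expect .grad
    return torch.empty_like(weight)           # dummy tensor handed back to autograd
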
+ if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + origin_weight.grad_added_to_main_grad = True + if getattr(origin_weight, "zero_out_wgrad", False): + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, + zero=True, + ) + else: + wgrad = get_dummy_wgrad( + list(origin_weight.main_grad.shape), + origin_weight.dtype, + ) + elif ctx.fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") + + # Scatter fp8 weight buffers + # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + + return ( + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgamma, + dbeta, + wgrad, + grad_bias, + None, # eps + None, # is_first_microbatch + None, # fp8 + None, # fp8_calibration + None, # wgrad_store + None, # fuse_wgrad_accumulation + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer + None, # cpu_offloading + None, # tp_group + None, # tp_size + None, # sequence_parallel + None, # tensor_parallel + None, # activation_dtype + None, # parallel_mode + None, # return_layernorm_output + None, # return_layernorm_output_gathered + None, # is_grad_enabled + None, # fwd_ln_sm_margin + None, # bwd_ln_sm_margin + None, # zero_centered_gamma + None, # normalization + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad + None, # ub_name + None, # fsdp_group + None, # debug + None, # module + None, # skip_fp8_weight_update + None, # symmetric_ar_type + ) + + +class LayerNormLinear(TransformerEngineBaseModule): + r""" + Applies layer normalization followed by linear transformation to the incoming data. + + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + bias : bool, default = `True` + if set to `False`, the layer will not learn an additive bias. + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. + init_method : Callable, default = `None` + used for initializing weights in the following way: `init_method(weight)`. + When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module is + taken post layernorm. + return_layernorm_output_gathered : bool, default = `False` + if set to `True`, output of layernorm is returned after the all + gather operation. Ignored if return_layernorm_output is False. + Example use case: with sequence parallel, input to residual connection + for transformer module (e.g. LoRA) will need to be gathered. + Returning layernorm output gathered will prevent a redundant gather. 
+ parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None + Configuration for splitting the weight and bias tensors along dim 0 into + multiple PyTorch parameters. If a list or tuple of strings is provided, + they are used to make the names of equally-sized parameters. If a dict + (preferably an OrderedDict) is provided, the keys are used as names and + values as split sizes along dim 0. The resulting parameters will have + names that end in `_weight` or `_bias`, so trailing underscores are + stripped from any provided names. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + device : Union[torch.device, str], default = "cuda" + The device on which the parameters of the model will be allocated. It is the user's + responsibility to ensure all parameters are moved to the GPU before running the + forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. + + Parallelism parameters + ---------------------- + sequence_parallel : bool, default = `False` + if set to `True`, uses sequence parallelism. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + tp_size : int, default = 1 + used as TP (tensor parallel) world size when TP groups are not formed during + initialization. In this case, users must call the + `set_tensor_parallel_group(tp_group)` method on the initialized module before the + forward pass to supply the tensor parallel group needed for tensor and sequence + parallel collectives. + parallel_mode : {None, 'column', 'row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here `_. + When set to `None`, no communication is performed. + + Optimization parameters + ----------------------- + fuse_wgrad_accumulation : bool, default = 'False' + if set to `True`, enables fusing of creation and accumulation of + the weight gradient. When enabled, it is assumed that the weights + have an additional `main_grad` attribute (used instead of the + regular `grad`) which is a pre-allocated buffer of the correct + size to accumulate gradients in. + return_bias : bool, default = `False` + when set to `True`, this module will not apply the additive bias itself, but + instead return the bias value during the forward pass together with the + output of the linear transformation :math:`y = xA^T`. This is useful when + the bias addition can be fused to subsequent operations. + params_dtype : torch.dtype, default = `torch.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + eps: float = 1e-5, + sequence_parallel: bool = False, + fuse_wgrad_accumulation: bool = False, + tp_group: Optional[dist_group_type] = None, + tp_size: int = 1, + get_rng_state_tracker: Optional[Callable] = None, + init_method: Optional[Callable] = None, + bias: bool = True, + normalization: str = "LayerNorm", + return_bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + parallel_mode: Optional[str] = None, + return_layernorm_output: bool = False, + return_layernorm_output_gathered: bool = False, + parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, + zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", + ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_name: Optional[str] = None, + delay_wgrad_compute: bool = False, + symmetric_ar_type: Optional[str] = None, + name: str = None, + ) -> None: + super().__init__() + + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.in_features = in_features + self.out_features = out_features + self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + self.normalization = normalization + assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" + self.use_bias = bias + self.return_bias = return_bias + self.apply_bias = self.use_bias and not return_bias + self.return_layernorm_output = return_layernorm_output + self.return_layernorm_output_gathered = ( + return_layernorm_output_gathered if return_layernorm_output else False + ) + self.zero_centered_gamma = zero_centered_gamma + self.symmetric_ar_type = symmetric_ar_type + + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) + self.name = name + + if tp_group is None: + self.tp_size = tp_size + if tp_size == 1: + self.set_tensor_parallel_group(tp_group) + else: + self.tp_size = get_distributed_world_size(tp_group) + self.set_tensor_parallel_group(tp_group) + self.set_nccl_overlap_warning_if_tp() + + self.parallel_mode = parallel_mode + assert ( + self.parallel_mode in GemmParallelModes + ), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + + if init_method is None: + init_method = get_default_init_method() + + self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + + # Column-parallel overlaps + self.ub_overlap_ag_fprop = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_overlap_rs_dgrad = ( + ub_overlap_rs_dgrad and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_bulk_wgrad = ( + ub_bulk_wgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + + # Row-parallel overlaps + self.ub_overlap_rs_fprop = ( + ub_overlap_rs and self.sequence_parallel and self.parallel_mode == "row" + ) + self.ub_overlap_ag_dgrad = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "row" + ) + if any( + [ + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + 
self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + ] + ): + assert ub_name is not None, "Userbuffer name [string] is not set." + self.ub_name = ub_name + + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + + self.eps = eps + layer_norm_weight = torch.nn.Parameter( + torch.empty(self.in_features, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma)), + ) + if self.normalization != "RMSNorm": + layer_norm_bias = torch.nn.Parameter( + torch.empty(self.in_features, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) + ) + else: + self.layer_norm_bias = None + + # Initialize params in FP8 + with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() + + # Contiguous buffers for params + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None + if self.use_bias: + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=params_dtype, + ) + + # Configure parameter splits + self.weight_names = [] + self.bias_names = [] + self.parameter_split_sizes = [] + if parameters_split is None: + # Split into a single parameter by default + self.weight_names = ["weight"] + self.bias_names = ["bias"] + self.parameter_split_sizes = [out_features] + elif not parameters_split: + raise ValueError("Cannot split weight buffer into 0 parameters") + elif isinstance(parameters_split, dict): + # Split parameters with provided sizes + for name, split_size in parameters_split.items(): + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + elif all(isinstance(name, str) for name in parameters_split): + # Split parameters evenly + split_size = out_features // len(parameters_split) + for name in parameters_split: + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + else: + raise TypeError("Invalid configuration for parameters split") + + # Make sure parameter splits are valid + if sum(self.parameter_split_sizes) != out_features: + raise ValueError( + f"Trying to split weight buffer ({out_features=}) " + f"with split sizes {self.parameter_split_sizes}" + ) + + # Adjust parameter splits for tensor-parallel distribution + if self.parallel_mode == "column": + for i, size in enumerate(self.parameter_split_sizes): + if size % self.tp_size != 0: + raise RuntimeError( + f"Attempting to distribute a parameter with out_features={size} " + f"between {self.tp_size} tensor-parallel processes" + ) + self.parameter_split_sizes[i] = size // self.tp_size + + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in LayerNormLinear.parameters(). This makes it + # more likely that they will stay contiguous if the weights + # are manipulated externally, e.g. by FSDP. 
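The registration loop that follows creates each split parameter as a view into the contiguous `weight_tensor` buffer, so fused projections (e.g. a hypothetical q/k/v split) share one storage and can still be fed to a single GEMM. A small sketch of that layout:

import torch

buffer = torch.empty(3 * 128, 64)                             # fused [3*out, in] weight buffer
splits = {"q_weight": 128, "k_weight": 128, "v_weight": 128}  # hypothetical names and sizes
params, offset = {}, 0
for name, size in splits.items():
    params[name] = torch.nn.Parameter(buffer[offset:offset + size])
    offset += size
assert params["q_weight"].data_ptr() == buffer.data_ptr()     # views into one buffer, not copies
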
+ offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + + # Check if parameters are subviews of buffers + is_subview = (split_start, split_end) != (0, self.out_features) + if is_subview and with_fp8_params: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + + # Construct weight parameter + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) + + if with_fp8_params: + self.init_fp8_metadata() + + self.reset_parameters(defer_init=device == "meta") + + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.apply_bias: + self.gemm_bias_unfused_add = True + else: + self.gemm_bias_unfused_add = False + + # These many SMs are subtracted from the total SM count when calling forward + # and backward LayerNorm C APIs. These envvars can be used to prevent the LN + # kernels from using all SMs in the device. This is useful for cases such as + # communication overlap with LN. + self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in self.weight_names or name in self.bias_names: + param.skip_backward_post_hook = True + + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) + # elif other recipes (mxfp8, etc) + + def reset_layer_norm_parameters(self) -> None: + """Init LN params""" + warnings.warn( + "This method will be deprecated in an upcoming release. 
" + "Update your code to use LayerNormLinear.reset_parameters() instead.", + DeprecationWarning, + stacklevel=2, + ) + if not self.zero_centered_gamma: + init.ones_(self.layer_norm_weight) + else: + init.zeros_(self.layer_norm_weight) + if self.layer_norm_bias is not None: + init.zeros_(self.layer_norm_bias) + + def reset_parameters(self, defer_init=False): + super().reset_parameters(defer_init=defer_init) + + if not defer_init: + # Set parallelism attributes for layer norm parameters + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + + # Set parallelism attributes for linear weights + for weight in self.weight_names: + set_tensor_model_parallel_attributes( + tensor=getattr(self, weight), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + # Set parallelism attributes for linear biases + if self.use_bias: + for bias in self.bias_names: + if self.parallel_mode == "row": + setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, + fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """ + Apply layer normalization to the input followed by a linear transformation. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + is_first_microbatch : {True, False, None}, default = None + During training using either gradient accumulation or + pipeline parallelism a minibatch of data is further split + into microbatches. Between the microbatches of the same minibatch + the model weights are not updated. Setting this parameter indicates + whether the current microbatch is the first in a minibatch or not. 
+ When set, this parameter enables additional optimizations: + + * during FP8 training, it allows caching of the FP8 versions of + the weights + * it also allows skipping gradient accumulation during the + first microbatch (since it is the first gradient being + produced) + """ + if is_in_onnx_export_mode(): + return self.onnx_forward(inp, fp8_output) + + debug = self.is_debug_iter() + + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None + if skip_fp8_weight_update is not None: + is_first_microbatch = False + + if self.ub_overlap_rs_fprop: + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): + fp8_output = True + if self.ub_overlap_rs_dgrad: + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): + fp8_grad = True + + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( + inp, allow_non_contiguous=False # removed .contiguous from inside the layer + ) as inp: + + # Get concatenated weight and bias tensors + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad) + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if torch.is_grad_enabled(): + fwd_fn = _LayerNormLinear.apply + args = [] + else: + fwd_fn = _LayerNormLinear.forward + args = [None] + args += ( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + weight_tensor, + bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + torch.is_grad_enabled(), + self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_name, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + debug, + ) + out = fwd_fn(*args) + + if self.return_layernorm_output: + out, ln_out = out + + if self.gemm_bias_unfused_add: + out = out + cast_if_needed(bias_tensor, self.activation_dtype) + + if self.return_bias: + if self.return_layernorm_output: + return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out + return out, cast_if_needed(bias_tensor, self.activation_dtype) + if self.return_layernorm_output: + return out, ln_out + return out + + def _get_quantizers(self, fp8_output, fp8_grad): + if not self.fp8: + return [None] * 6 + grad_input_quantizer = None + grad_weight_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = 
self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = True + (weight_quantizer,) = self._get_weight_quantizers() + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) + ) + + def _get_weight_and_bias_tensors(self): + # Get concatenated weight and bias tensors + unfused_weights = self._get_weight_tensors() + + weight_tensor = noop_cat(unfused_weights) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + else: + bias_tensor = getattr(self, self.bias_names[0]) # Unused + return weight_tensor, bias_tensor + + def onnx_forward( + self, + inp: torch.Tensor, + fp8_output: bool, + ) -> torch.Tensor: + """ + ONNX-compatible version of the forward function that provides numerical equivalence + while only using operations that have defined ONNX symbolic translations. + This simplified implementation is designed specifically for inference scenarios. 
+ """ + from ..export import onnx_layernorm, onnx_gemm + + assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export" + assert_warmed_up(self) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + *_, + ) = self._get_quantizers(fp8_output, fp8_grad=False) + inp_dtype = inp.dtype + + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + ln_out, ln_out_return = onnx_layernorm( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + self.eps, + self.normalization, + self.zero_centered_gamma, + inp_dtype, + self.return_layernorm_output, + input_quantizer, + ) + + if weight_quantizer is not None: + weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor) + weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized) + weight_tensor = weight_tensor.to(inp_dtype) + + if bias_tensor is not None: + bias_tensor = bias_tensor.to(inp_dtype) + + output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None) + + if output_quantizer is not None: + raise NotImplementedError("ONNX export of quantized output is not supported") + if self.return_layernorm_output and self.return_bias: + return output, bias_tensor.to(inp_dtype), ln_out_return + if self.return_layernorm_output: + return output, ln_out_return + if self.return_bias: + return output, bias_tensor.to(inp_dtype) + return output + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert recipe.nvfp4(), 
"Incorrect recipe." + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + unfused_weights = [getattr(self, name) for name in self.weight_names] + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + unfused_weights = [w.dequantize() for w in unfused_weights] + return unfused_weights + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + if not self.fp8 and not self.fp8_calibration: + return [None] + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + return [weight_quantizer] + + def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on blockwise scaling recipe + layernorm_linear.""" + assert ( + recipe.float8_block_scaling() + ), "blockwise scaling recipe quantizer customization here" + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].all_gather_usage = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cf7f58947b..7d3f2a167d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,6 +7,7 @@ from functools import reduce from operator import mul as multiply_op import warnings +import os import torch @@ -53,6 +54,9 @@ ) from ..cpp_extensions import ( general_gemm, + ubsymm_request_allocator, + ubsymm_get_sym_tensor, + ubsymm_allreduce, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -118,6 +122,9 @@ def forward( symmetric_ar_type: str, save_original_input: bool = False, debug: Optional[bool] = False, + residual: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ln_weight: Optional[torch.Tensor] = None, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -301,6 +308,16 @@ def forward( out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) + symm_out = None + if symmetric_ar_type is not None and symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + out_shape_list = list(tuple(inp.shape)) + out_shape_list[-1] = out_features + symm_out = ubsymm_get_sym_tensor( + torch.Size(out_shape_list), + activation_dtype, + tp_group, + ) + assert symm_out is not None or symmetric_ar_type == "ubnext", "No symmetric pool out of 
space fallback for fused ops, increase NVTE_UB_SYMM_POOL_SIZE" # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -317,6 +334,7 @@ def forward( ub=ub_obj, ub_type=ub_type, extra_output=reduce_scatter_out, + out=symm_out, ) nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ @@ -344,7 +362,17 @@ def forward( out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + if symm_out is not None: + out = ubsymm_allreduce(symm_out,residual_global=residual,gamma=ln_weight,eps=eps) + else: + fallback_symmetric = ( + "multimem_all_reduce" + if symmetric_ar_type.startswith("ubnext") + else symmetric_ar_type + ) + out, _ = symmetric_all_reduce( + out, tp_group, all_reduce_type=fallback_symmetric + ) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1112,6 +1140,8 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + eps: Optional[float] = None, + ln_weight: Optional[torch.Tensor] = None, ) -> None: super().__init__() @@ -1200,7 +1230,17 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - + if self.symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + ubsymm_request_allocator( + self.tp_group, + ( + int(os.environ.get("NVTE_UB_MAXBATCH", 64)), + self.out_features, + ), + params_dtype, + ) + self.eps = eps + self.layer_norm_weight = ln_weight # in general expected to be filled with reference to layernorm_weight from next LayerNormLinear later # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1366,6 +1406,7 @@ def forward( is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, fp8_grad: Optional[bool] = False, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -1478,6 +1519,9 @@ def forward( self.symmetric_ar_type, self.save_original_input, debug, + residual, + self.eps, + self.layer_norm_weight, ) out = linear_fn(*args) if self.gemm_bias_unfused_add: diff --git a/transformer_engine/pytorch/module/linear.py.orig b/transformer_engine/pytorch/module/linear.py.orig new file mode 100644 index 0000000000..cf7f58947b --- /dev/null +++ b/transformer_engine/pytorch/module/linear.py.orig @@ -0,0 +1,1710 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
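Taken together, the linear.py changes above add `symmetric_ar_type` values starting with "ubnext", the constructor arguments `eps` and `ln_weight`, and a `residual` argument to `forward`, so that the row-parallel all-reduce can be fused with a residual add and RMSNorm. The sketch below shows how those arguments appear to be intended to be wired up, based only on this diff; `hidden_size`, the pre-initialized `tp_group`, and the helper name are assumptions, and the symmetric pool is sized via the `NVTE_UB_MAXBATCH` / `NVTE_UB_SYMM_POOL_SIZE` environment variables referenced above.

import torch
import transformer_engine.pytorch as te

def build_fused_row_parallel_proj(hidden_size, tp_group, next_rmsnorm_weight, eps=1e-5):
    """Row-parallel projection whose all-reduce is fused with residual add + RMSNorm."""
    return te.Linear(
        hidden_size,
        hidden_size,
        parallel_mode="row",
        tp_group=tp_group,
        tp_size=tp_group.size(),
        symmetric_ar_type="ubnext_rms",   # fused variant; plain "ubnext" is all-reduce only
        eps=eps,
        ln_weight=next_rmsnorm_weight,    # gamma of the following RMSNorm
    )

# proj = build_fused_row_parallel_proj(8192, tp_group, norm_weight)
# out = proj(hidden_states, residual=residual)   # fused all-reduce + residual add + RMSNorm
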
+ +"""Linear API""" +from typing import Callable, Dict, Optional, Tuple, Union, List +from functools import reduce +from operator import mul as multiply_op +import warnings + +import torch + +import transformer_engine_torch as tex + +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import torch_version + +from .base import ( + fill_userbuffers_buffer_for_all_gather, + get_dummy_wgrad, + get_ub, + get_workspace, + TransformerEngineBaseModule, + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, +) +from ._common import noop_cat, WeightGradStore, get_module_quantizers +from ..fp8 import FP8GlobalStateManager +from ..utils import ( + cast_if_needed, + clear_tensor_data, + divide, + init_method_constant, + requires_grad, + needs_quantized_gemm, + assert_dim_for_fp8_exec, + assert_dim_for_all_gather, + nvtx_range_pop, + nvtx_range_push, +) +from ..distributed import ( + set_tensor_model_parallel_attributes, + get_distributed_world_size, + allreduce, + symmetric_all_reduce, + reduce_scatter_along_first_dim, + gather_along_first_dim, + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, + _fsdp_scatter_tensors, + _fsdp_gather_tensors, +) +from ..cpp_extensions import ( + general_gemm, +) +from ..constants import GemmParallelModes, dist_group_type +from ..jit import no_torch_dynamo +from ..graph import is_graph_capturing +from ..tensor.quantized_tensor import ( + QuantizedTensor, + QuantizedTensorBase, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.utils import is_experimental +from ..export import is_in_onnx_export_mode, assert_warmed_up +from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...debug.pytorch.debug_state import TEDebugState + +__all__ = ["Linear"] + + +class _Linear(torch.autograd.Function): + """Linear semi-top level module + Calls custom cuda extensions. 
+ """ + + @staticmethod + def forward( + ctx, + weight: torch.Tensor, + inp: torch.Tensor, + bias: Optional[torch.Tensor], + is_first_microbatch: Union[bool, None], + fp8: bool, + fp8_calibration: bool, + wgrad_store: WeightGradStore, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + fuse_wgrad_accumulation: bool, + cpu_offloading: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, + sequence_parallel: bool, + tensor_parallel: bool, + activation_dtype: torch.dtype, + parallel_mode: Union[str, None], + is_grad_enabled: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, + ub_name: str, + fp8_output: bool, # pylint: disable=unused-argument + fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, + symmetric_ar_type: str, + save_original_input: bool = False, + debug: Optional[bool] = False, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + + # Make sure input dimensions are compatible + out_features, in_features = weight.shape + assert inp.shape[-1] == in_features, "GEMM not possible" + + # Configure tensor-parallel communication + tp_world_size = get_distributed_world_size(tp_group) + backward_needs_input = is_grad_enabled and weight.requires_grad + with_input_all_gather_nccl = ( + parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop + ) + + # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_rs_fprop = False + ub_overlap_ag_fprop = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False + ub_obj = None + ub_type = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop", fp8) + ub_type = tex.CommOverlapType.RS + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop", fp8) + ub_type = tex.CommOverlapType.AG + + # experimental recipe check + experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer) + + # ------------------------------------------------------ + # Prepare input tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + # ------------------------------------------------------ + nvtx_range_push(f"{nvtx_label}.input_cast_comm") + inputmat = inp # Input tensor to save for backward (maybe sharded) + inputmat_total = None # Input tensor to pass to GEMM (gathered) + own_quantized_input = False + if fp8: + assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) + if save_original_input: + assert not isinstance( + input_quantizer, Float8Quantizer + ), "DelayedScaling recipe is not supported with save_original_input" + + if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor + + # Cast local input tensor if needed + if fp8 or debug: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if not isinstance(inputmat, QuantizedTensorBase) and not experimental: + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, 
columnwise=backward_needs_input) + if isinstance( + input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + # All-gather is not supported with FP8 column-wise data + input_quantizer.set_usage(columnwise=False) + if save_original_input: + # No need for column-wise data since this + # tensor will not be cached for backward pass + input_quantizer.set_usage(columnwise=False) + own_quantized_input = False + inputmat = input_quantizer(inputmat) + else: + inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP + + # Initialize gathered input tensor + quantizer = None + if fp8 or debug: + quantizer = input_quantizer + quantizer.set_usage(rowwise=True, columnwise=False) + if with_input_all_gather_nccl: # Perform NCCL all-gather + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=quantizer, + ) + elif ub_overlap_ag_fprop: # Initialize Userbuffers all-gather + inputmat_total, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj, + inputmat, + quantizer, + tp_group, + ) + + else: # Do not all-gather input tensor + if fp8 or debug: + if isinstance(inputmat, QuantizedTensorBase): + inputmat.update_usage(rowwise_usage=True) + else: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage( + rowwise=True, columnwise=backward_needs_input and not save_original_input + ) + inputmat = input_quantizer(inputmat) + own_quantized_input = True + else: + inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP + inputmat_total = inputmat + nvtx_range_pop(f"{nvtx_label}.input_cast_comm") + # ------------------------------------------------------ + # Input tensor is ready for GEMM... + # ------------------------------------------------------ + + # ------------------------------------------------------ + # Prepare weight tensor + # ------------------------------------------------------ + weightmat = weight + if fp8 or debug: + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # Get quantized weight + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) + weightmat.update_usage(rowwise_usage=True) + + else: + weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP + # ------------------------------------------------------ + # Weight tensor is ready for GEMM... 
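
The branch above only decides how the (possibly quantized) input reaches the GEMM; for a plain high-precision tensor the sequence-parallel path boils down to an all-gather of the dim-0 shards. A small reference using public torch.distributed calls, illustrative only (the real code routes this through gather_along_first_dim or a Userbuffers buffer and additionally manages FP8 row-wise/column-wise usage):

import torch
import torch.distributed as dist

def gather_first_dim_reference(local_x: torch.Tensor, tp_group) -> torch.Tensor:
    # Sequence-parallel input shards [s/tp, b, h] are gathered to [s, b, h]
    # before the forward GEMM.
    world = dist.get_world_size(tp_group)
    out = torch.empty((world * local_x.shape[0], *local_x.shape[1:]),
                      dtype=local_x.dtype, device=local_x.device)
    dist.all_gather_into_tensor(out, local_x.contiguous(), group=tp_group)
    return out
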
+ # ------------------------------------------------------ + + # Cast bias to expected dtype + bias_dtype = activation_dtype + if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: + # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16 + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(inputmat_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + # Choose whether to use GEMM kernel with split accumulator + use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Output buffer for Userbuffers reduce-scatter + reduce_scatter_out = None + if ub_overlap_rs_fprop: + out_shape = list(inp.shape) + out_shape[0] //= tp_world_size + out_shape[-1] = out_features + reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) + + # ------------------------------------------------------ + # Forward GEMM + # Note: y = x * w^T + # ------------------------------------------------------ + nvtx_range_push(f"{nvtx_label}.gemm") + gemm_out, *_, reduce_scatter_out = general_gemm( + weightmat, + inputmat_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=activation_dtype, + bias=bias, + use_split_accumulator=use_split_accumulator, + ub=ub_obj, + ub_type=ub_type, + extra_output=reduce_scatter_out, + ) + nvtx_range_pop(f"{nvtx_label}.gemm") + # ------------------------------------------------------ + # Finished forward GEMM... + # ------------------------------------------------------ + + # Deallocate GEMM input tensor if no longer needed + # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. + if with_input_all_gather_nccl: + clear_tensor_data(inputmat_total) + inputmat_total = None + + # ------------------------------------------------------ + # Prepare output tensor + # Note: Perform tensor-parallel communication + # ------------------------------------------------------ + out = None + if ub_overlap_rs_fprop: + out = reduce_scatter_out + elif parallel_mode == "row" and tp_size > 1: + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") + out = gemm_out + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") + else: + out = gemm_out + # ------------------------------------------------------ + # Output tensor is ready to return... 
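
The row-parallel branch above relies on the usual tensor-parallel identity: each rank multiplies its slice of the contraction dimension, and the all-reduce (symmetric or NCCL) sums the partial outputs. A single-process reference of that identity (illustrative; the fused ubnext variant additionally applies the residual add and RMSNorm to the reduced result):

import torch

x = torch.randn(4, 8)          # [tokens, in_features]
w = torch.randn(6, 8)          # [out_features, in_features]
full = x @ w.t()               # y = x * w^T

tp = 2
x_shards = x.chunk(tp, dim=1)  # each rank holds in_features/tp columns of x
w_shards = w.chunk(tp, dim=1)  # ...and the matching columns of w
partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
torch.testing.assert_close(sum(partials), full)  # all-reduce(sum) of partials == full GEMM
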
+ # ------------------------------------------------------ + + # ------------------------------------------------------ + # Cache state for backward pass + # ------------------------------------------------------ + + if is_grad_enabled: + if save_original_input: + inputmat = inp + + ctx.weight_quantizer = weight_quantizer + + ctx.backward_input_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) + + # Discard unneeded data in input tensor + if ( + backward_needs_input + and own_quantized_input + and isinstance(inputmat, QuantizedTensorBase) + ): + if ( + ctx.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() + ): + # All-gather is not supported with FP8 column-wise data + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + # Discard row-wise data since it is not needed in backward pass + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Cached input tensor + saved_inputmat = None + if backward_needs_input: + saved_inputmat = inputmat + + # Weight with column-wise usage is needed for dgrad GEMM. + if inp.requires_grad: + if isinstance(weightmat, QuantizedTensorBase): + weightmat.update_usage(columnwise_usage=True) + + if cpu_offloading and saved_inputmat is not None: + mark_activation_offload(saved_inputmat) + + # Scatter intermediate/activation tensors saved for the backward pass + # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") + ctx.fsdp_group = fsdp_group + ctx.fsdp_shapes = _fsdp_scatter_tensors( + fsdp_group, + saved_inputmat, + weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None, + ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. 
Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_object = weight + + # TODO(ksivamani): Check memory usage + tensors_to_save, tensor_objects = prepare_for_saving( + saved_inputmat, + weightmat, + weight, + bias, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.input_quantizer = input_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if fuse_wgrad_accumulation and weight.requires_grad: + # This check is needed to ensure that main_grad is not created + # during the forward pass when using MCore FSDP as it creates + # the main_grad buffer lazily before backprop + if hasattr(weight, "__fsdp_param__"): + # MCore FSDP creates main_grad lazily before backward + ctx.main_grad_func = weight.get_main_grad + else: + ctx.main_grad_func = lambda: weight.main_grad + + ctx.debug = debug + ctx.experimental = experimental + ctx.cpu_offloading = cpu_offloading + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = bias is not None + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.inp_shape = inp.shape + ctx.parallel_mode = parallel_mode + ctx.tp_group = tp_group + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad + ctx.ub_name = ub_name + ctx.tp_size = tp_size + ctx.requires_dgrad = inp.requires_grad + ctx.requires_wgrad = weight.requires_grad + ctx.reduce_and_update_bwd_fp8_tensors = False + + ctx.owns_input = saved_inputmat is not inp + if ctx.fp8 and requires_grad(inp, weight, bias): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.wgrad_store = wgrad_store + + # ------------------------------------------------------ + # Cached state for backward pass is ready... + # ------------------------------------------------------ + + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + + with torch.cuda.nvtx.range("_Linear_backward"): + saved_tensors = ctx.saved_tensors + inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + restore_from_saved(ctx.tensor_objects, saved_tensors) + ) + + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. 
+ ctx.tensor_objects = None + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad_func() + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) + + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + weight = ctx.weight_object + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + weight.main_grad = main_grad + + # Gather intermediate/activation tensors if needed + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") + _fsdp_gather_tensors( + ctx.fsdp_group, + ctx.fsdp_shapes, + inputmat, + weight_fp8, + ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + + # Configure Userbuffers communication (comm+GEMM overlap) + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_type_wgrad = tex.CommOverlapType.RS + + # -------------------------------------------------- + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + # -------------------------------------------------- + + # Unmodified grad output tensor + grad_output_arg = grad_output + + # Configure quantizer for grad output tensor + # Note: dgrad GEMM requires row-wise usage, wgrad GEMM + # requires column-wise usage + if ctx.grad_output_quantizer is not None: + quantizer = ctx.grad_output_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.ub_overlap_ag: + # Userbuffers only supports communication for one + # tensor usage at a time. Configure quantizer with + # usage for only dgrad GEMM. + quantizer.set_usage(columnwise=False) + + # Adjust the quantization direction approach depending + # on whether wgrad calculations will be performed. + # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization + # results in `Assertion failed: output_tensor->has_data(). 
Quantizing in only the columnwise direction not supported yet!` + # NOTE: For `ctx.bias is True`, selected quantize kernel errors with + # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` + if ( + not ctx.use_bias + and not ctx.requires_wgrad + and ctx.grad_output_quantizer is not None + ): + ctx.grad_output_quantizer.set_usage(columnwise=False) + + # Prepare grad output tensor + nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") + ( + grad_output, + grad_bias, + ) = TransformerEngineBaseModule.grad_output_preprocess( + ctx, + grad_output, + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, + ) + nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") + + # -------------------------------------------------- + # Grad output tensor is ready for computing grad input... + # -------------------------------------------------- + + # -------------------------------------------------- + # Prepare input tensor + # Note: Input tensor is needed for wgrad GEMM. + # Tensor-parallel communication is overlapped with dgrad + # GEMM. + # -------------------------------------------------- + inputmat_total = None + inputmat_total_work = None + if ctx.requires_wgrad: + if ctx.fp8 or ctx.debug: + if isinstance(inputmat, QuantizedTensorBase): + # Input tensor is already quantized + pass + elif ctx.debug or ctx.experimental: + # Debug quantizer will be applied immediately before wgrad GEMM + pass + else: + # Quantize input tensor + quantizer = ctx.input_quantizer + if quantizer.supports_only_rowwise_all_gather(): + # All-gather is not supported with FP8 column-wise data + quantizer.set_usage( + rowwise=True, + columnwise=not ctx.backward_input_needs_gather, + ) + else: + quantizer.set_usage(rowwise=False, columnwise=True) + inputmat = quantizer(inputmat) + else: + if isinstance(inputmat, QuantizedTensorBase): + inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) + else: + inputmat = cast_if_needed(inputmat, ctx.activation_dtype) + if ctx.backward_input_needs_gather: + quantizer = None + if ctx.fp8 or ctx.debug: + quantizer = ctx.input_quantizer + if quantizer.supports_only_rowwise_all_gather(): + # If data is in FP8, we compute FP8 transposes manually + quantizer.set_usage(rowwise=True, columnwise=False) + else: + # wgrad GEMM requires input with column-wise usage + quantizer.set_usage(rowwise=False, columnwise=True) + if ctx.ub_bulk_dgrad: + inputmat_total, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_dgrad, + inputmat, + quantizer, + ctx.tp_group, + ) + else: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + inputmat_total, inputmat_total_work = gather_along_first_dim( + inputmat, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") + else: + inputmat_total = inputmat + # -------------------------------------------------- + # Input tensor is ready for computing grad weight... 
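
When the input does need to be re-gathered for wgrad, the code above launches the all-gather asynchronously so it can hide behind the dgrad GEMM, waiting only right before the wgrad GEMM consumes the result. A sketch of that overlap pattern with plain tensors and public APIs (dy is assumed to already be the full, gathered grad output; the real code works with quantized tensors and Userbuffers):

import torch
import torch.distributed as dist

def backward_gemms_with_overlap(dy, w, x_local, tp_group):
    world = dist.get_world_size(tp_group)
    x_total = torch.empty((world * x_local.shape[0], *x_local.shape[1:]),
                          dtype=x_local.dtype, device=x_local.device)
    handle = dist.all_gather_into_tensor(x_total, x_local.contiguous(),
                                         group=tp_group, async_op=True)
    dgrad = dy @ w            # dx = dy * w, runs while the all-gather is in flight
    handle.wait()             # x_total must be complete before wgrad reads it
    wgrad = dy.t() @ x_total  # dw = dy^T * x
    return dgrad, wgrad
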
+ # -------------------------------------------------- + + # -------------------------------------------------- + # Compute grad input tensor + # -------------------------------------------------- + + dgrad = None + dgrad_work = None + if ctx.requires_dgrad: + + # Make sure required data is available + if isinstance(grad_output, QuantizedTensorBase): + grad_output.update_usage(rowwise_usage=True) + if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase): + weight_fp8.update_usage(columnwise_usage=True) + + # Choose whether to use GEMM kernel with split accumulator + use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + + # Update grad input quantizer + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + # Output buffers for Userbuffers reduce-scatter + gemm_out = None + reduce_scatter_out = None + if ctx.ub_overlap_rs_dgrad: + reduce_scatter_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device + ) + elif ctx.ub_bulk_wgrad: + gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) + + # dgrad GEMM + # Note: dx = dy * w + + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + gemm_out, *_, reduce_scatter_out = general_gemm( + weight_fp8, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=gemm_out, + out_dtype=ctx.activation_dtype, + use_split_accumulator=use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=reduce_scatter_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # Prepare grad input tensor + # Note: Perform tensor-parallel communication + if ctx.ub_overlap_rs_dgrad: + dgrad = reduce_scatter_out + elif ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) + elif ctx.parallel_mode == "column" and ctx.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + dgrad = gemm_out + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, + ) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + else: + dgrad = gemm_out + + # -------------------------------------------------- + # Grad input tensor has been computed... + # -------------------------------------------------- + + # -------------------------------------------------- + # Compute grad weight + # -------------------------------------------------- + + wgrad = None + if ctx.requires_wgrad: + + # Prepare input tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if inputmat_total_work is not None: + inputmat_total_work.wait() + inputmat_total_work = None + if ctx.fp8 or ctx.debug: + if isinstance(inputmat_total, QuantizedTensorBase): + inputmat_total.update_usage(columnwise_usage=True) + else: + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + inputmat_total = ctx.input_quantizer(inputmat_total) + + # Prepare grad output tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + # UB does not support pipelined overlapping grad output + # all-gather with wgrad GEMM. 
Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() + + # This object is separate from the ub_obj_wgrad object which is passed to the GEMM + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. + with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, + grad_output_arg, + ctx.grad_output_quantizer, + ctx.tp_group, + ) + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) + + if ctx.fp8 or ctx.debug: + if isinstance(grad_output, QuantizedTensorBase): + grad_output.update_usage(columnwise_usage=True) + else: + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + grad_output = ctx.grad_output_quantizer(grad_output) + + # Figure out whether to use split accumulator + use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + + # Figure out whether to output wgrad GEMM directly into main grad + if ctx.is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + # Output buffer for overlapping FP8 grad input + # reduce-scatter with wgrad GEMM + reduce_scatter_out = None + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + reduce_scatter_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device + ) + + # Arguments to include in wgrad GEMM closure + wgrad_gemm_kwargs = { + "workspace": get_workspace(), + "out_dtype": ( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + "quantization_params": ctx.grad_weight_quantizer, + "accumulate": accumulate_wgrad_into_param_main_grad, + "layout": "NT", + "out": main_grad if ctx.fuse_wgrad_accumulation else None, + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "use_split_accumulator": use_split_accumulator, + "grad": True, + "ub": ub_obj_wgrad, + "ub_type": ub_type_wgrad, + "extra_output": reduce_scatter_out, + "bulk_overlap": ctx.ub_bulk_wgrad, + } + + def wgrad_gemm( + x: torch.Tensor, + dy: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform wgrad GEMM: dw = dy^T * x + + May be fused with bgrad computation. + + May be called outside of this function to enable + some advanced communication/compute overlapping. 
+ + """ + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") + return dw, db + + # Choose whether to call wgrad GEMM now or delay + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + if ( + wgrad_gemm_kwargs["ub"] is not None + or wgrad_gemm_kwargs["ub_type"] is not None + or wgrad_gemm_kwargs["extra_output"] is not None + or wgrad_gemm_kwargs["bulk_overlap"] + ): + raise NotImplementedError( + "Delayed weight grad computation is not supported " + "with Userbuffers (tensor-parallel communication overlapping)" + ) + ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm) + else: + + # Call wgrad GEMM now + wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output) + + # Update grad bias if needed + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate tensors if permitted + if ctx.owns_input: + # Input tensor is internal + clear_tensor_data(inputmat_total) + elif ctx.backward_input_needs_gather: + # Gathered input tensor is internal + clear_tensor_data(inputmat_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) + + # Update grad input if overlapping reduce-scatter with wgrad GEMM + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = reduce_scatter_out + else: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() + + # -------------------------------------------------- + # Grad weight has been computed... + # -------------------------------------------------- + + # Don't return grad bias if not needed + if not ctx.use_bias: + grad_bias = None + + # Make sure all tensor-parallel communication is finished + if inputmat_total_work is not None: + inputmat_total_work.wait() + inputmat_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None + + if ctx.requires_wgrad: + # Handle custom DDP from mcore. 
+ if ( + ctx.fuse_wgrad_accumulation + and weight is not None + and hasattr(weight, "grad_added_to_main_grad") + ): + weight.grad_added_to_main_grad = True + if getattr(weight, "zero_out_wgrad", False): + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, + ) + else: + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + ) + elif ctx.fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + + # Update FP8 scaling factors if needed + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") + + # Scatter fp8 weight buffers + if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + return ( + wgrad, + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + grad_bias, + None, # is_first_microbatch + None, # fp8 + None, # fp8_calibration + None, # wgrad_store + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer + None, # fuse_wgrad_accumulation + None, # cpu_offloading + None, # tp_group + None, # tp_size + None, # sequence_parallel + None, # tensor_parallel + None, # activation_dtype + None, # parallel_mode + None, # is_grad_enabled + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad + None, # ub_name + None, # fp8_output + None, # fsdp_group + None, # module + None, # skip_fp8_weight_update + None, # symmetric_ar_type + None, # save_original_input + None, # debug + ) + + +class Linear(TransformerEngineBaseModule): + """Applies a linear transformation to the incoming data :math:`y = xA^T + b` + + On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. + + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. + bias : bool, default = `True` + if set to `False`, the layer will not learn an additive bias. + init_method : Callable, default = `None` + used for initializing weights in the following way: `init_method(weight)`. + When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. + get_rng_state_tracker : Callable, default = `None` + used to get the random number generator state tracker for initializing weights. + rng_tracker_name : str, default = `None` + the param passed to get_rng_state_tracker to get the specific rng tracker. + parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None + Configuration for splitting the weight and bias tensors along dim 0 into + multiple PyTorch parameters. If a list or tuple of strings is provided, + they are used to make the names of equally-sized parameters. If a dict + (preferably an OrderedDict) is provided, the keys are used as names and + values as split sizes along dim 0. The resulting parameters will have + names that end in `_weight` or `_bias`, so trailing underscores are + stripped from any provided names. + device : Union[torch.device, str], default = "cuda" + The device on which the parameters of the model will be allocated. It is the user's + responsibility to ensure all parameters are moved to the GPU before running the + forward pass. 
+ name: str, default = `None` + name of the module, currently used for debugging purposes. + + Parallelism parameters + ---------------------- + sequence_parallel : bool, default = `False` + if set to `True`, uses sequence parallelism. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + tp_size : int, default = 1 + used as TP (tensor parallel) world size when TP groups are not formed during + initialization. In this case, users must call the + `set_tensor_parallel_group(tp_group)` method on the initialized module before the + forward pass to supply the tensor parallel group needed for tensor and sequence + parallel collectives. + parallel_mode : {None, 'column', 'row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here `_. + When set to `None`, no communication is performed. + + Optimization parameters + ----------------------- + fuse_wgrad_accumulation : bool, default = 'False' + if set to `True`, enables fusing of creation and accumulation of + the weight gradient. When enabled, it is assumed that the weights + have an additional `main_grad` attribute (used instead of the + regular `grad`) which is a pre-allocated buffer of the correct + size to accumulate gradients in. + return_bias : bool, default = `False` + when set to `True`, this module will not apply the additive bias itself, but + instead return the bias value during the forward pass together with the + output of the linear transformation :math:`y = xA^T`. This is useful when + the bias addition can be fused to subsequent operations. + params_dtype : torch.dtype, default = `torch.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. + delay_wgrad_compute : bool, default = `False` + Whether or not to delay weight gradient computation. If set to `True`, + it's the user's responsibility to call `module.backward_dw` to compute + weight gradients. + symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None + Type of symmetric memory all-reduce to use during the forward pass. + This can help in latency bound communication situations. + Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + is used. + save_original_input : bool, default = `False` + If set to `True`, always saves the original input tensor rather than the + cast tensor. In some scenarios, the input tensor is used by multiple modules, + and saving the original input tensor may reduce the memory usage. + Cannot work with FP8 DelayedScaling recipe. 
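
A concrete instance of the parameters_split behaviour documented above (a sketch; it needs a CUDA device since TE allocates on "cuda" by default, and the split sizes must sum to out_features):

import transformer_engine.pytorch as te

qkv = te.Linear(
    in_features=1024,
    out_features=1536,
    parameters_split={"query": 1024, "key": 256, "value": 256},  # sizes along dim 0
)
# Parameters are named query_weight / key_weight / value_weight (plus *_bias) and are
# views into one contiguous [1536, 1024] weight buffer.
print(qkv.query_weight.shape)  # torch.Size([1024, 1024])
print(qkv.key_weight.shape)    # torch.Size([256, 1024])
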
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + sequence_parallel: bool = False, + fuse_wgrad_accumulation: bool = False, + tp_group: Optional[dist_group_type] = None, + tp_size: int = 1, + get_rng_state_tracker: Optional[Callable] = None, + rng_tracker_name: Optional[str] = None, + init_method: Optional[Callable] = None, + bias: bool = True, + return_bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + parallel_mode: Optional[str] = None, + parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, + device: Union[torch.device, str] = "cuda", + ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, + ub_name: Optional[str] = None, + delay_wgrad_compute: bool = False, + symmetric_ar_type: Optional[str] = None, + save_original_input: bool = False, + name: Optional[str] = None, + ) -> None: + super().__init__() + + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.in_features = in_features + self.out_features = out_features + self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + self.use_bias = bias + self.return_bias = return_bias + self.apply_bias = bias and not return_bias + self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name + self.symmetric_ar_type = symmetric_ar_type + self.save_original_input = save_original_input + self.name = name + + self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) + + if device == "meta": + assert parameters_split is None, "Cannot split module parameters on 'meta' device." + if tp_group is None: + self.tp_size = tp_size + if tp_size == 1: + self.set_tensor_parallel_group(tp_group) + else: + self.tp_size = get_distributed_world_size(tp_group) + self.set_tensor_parallel_group(tp_group) + self.set_nccl_overlap_warning_if_tp() + + self.parallel_mode = parallel_mode + assert ( + self.parallel_mode in GemmParallelModes + ), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + + self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + + # Column parallel TP overlap options + self.ub_overlap_ag_fprop = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag + ) + self.ub_overlap_rs_dgrad = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_dgrad + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_wgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_wgrad + and not self.ub_overlap_rs_dgrad + ) + + # Row parallel TP overlap options + self.ub_overlap_rs_fprop = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs + ) + self.ub_overlap_ag_dgrad = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag + ) + + if any( + [ + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): + assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." 
+ self.ub_name = ub_name + + if self.symmetric_ar_type is not None: + assert torch_version() >= ( + 2, + 7, + 0, + ), "Torch version must be at least 2.7 to use symmetric memory" + + # Initialize params in FP8 + with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() + + # Contiguous buffers for params + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None + if self.use_bias: + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=params_dtype, + ) + + # Configure parameter splits + self.weight_names = [] + self.bias_names = [] + self.parameter_split_sizes = [] + if parameters_split is None: + # Split into a single parameter by default + self.weight_names = ["weight"] + self.bias_names = ["bias"] + self.parameter_split_sizes = [out_features] + elif not parameters_split: + raise ValueError("Cannot split weight buffer into 0 parameters") + elif isinstance(parameters_split, dict): + # Split parameters with provided sizes + for name, split_size in parameters_split.items(): + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + elif all(isinstance(name, str) for name in parameters_split): + # Split parameters evenly + split_size = out_features // len(parameters_split) + for name in parameters_split: + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + else: + raise TypeError("Invalid configuration for parameters split") + + # Make sure parameter splits are valid + if sum(self.parameter_split_sizes) != out_features: + raise ValueError( + f"Trying to split weight buffer ({out_features=}) " + f"with split sizes {self.parameter_split_sizes}" + ) + + # Adjust parameter splits for tensor-parallel distribution + if self.parallel_mode == "column": + for i, size in enumerate(self.parameter_split_sizes): + if size % self.tp_size != 0: + raise RuntimeError( + f"Attempting to distribute a parameter with out_features={size} " + f"between {self.tp_size} tensor-parallel processes" + ) + self.parameter_split_sizes[i] = size // self.tp_size + + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in Linear.parameters(). This makes it more likely + # that they will stay contiguous if the weights are + # manipulated externally, e.g. by FSDP. 
+ offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + + # Check if parameters are subviews of buffers + is_subview = (split_start, split_end) != (0, self.out_features) + if is_subview and with_fp8_params: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + + # Construct weight parameter + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) + + if with_fp8_params: + self.init_fp8_metadata() + + self.reset_parameters(defer_init=device == "meta") + + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.apply_bias: + self.gemm_bias_unfused_add = True + else: + self.gemm_bias_unfused_add = False + + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in self.weight_names or name in self.bias_names: + param.skip_backward_post_hook = True + + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) + # elif for other recipes (mxfp8, etc.) + + def reset_parameters(self, defer_init=False): + super().reset_parameters(defer_init=defer_init) + + if not defer_init: + # Set parallelism attributes for linear weights + for weight in self.weight_names: + set_tensor_model_parallel_attributes( + tensor=getattr(self, weight), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + # Set parallelism attributes for linear biases + if self.use_bias: + for bias in self.bias_names: + if self.parallel_mode == "row": + setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, + fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """ + Apply the linear transformation to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + is_first_microbatch : {True, False, None}, default = None + During training using either gradient accumulation or + pipeline parallelism a minibatch of data is further split + into microbatches. 
Between the microbatches of the same minibatch + the model weights are not updated. Setting this parameter indicates + whether the current microbatch is the first in a minibatch or not. + When set, this parameter enables additional optimizations: + + * during FP8 training, it allows caching of the FP8 versions of + the weights + * it also allows skipping gradient accumulation during the + first microbatch (since it is the first gradient being + produced) + """ + if is_in_onnx_export_mode(): + return self.onnx_forward(inp, fp8_output) + + debug = self.is_debug_iter() + + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None + if skip_fp8_weight_update is not None: + is_first_microbatch = False + + if self.ub_overlap_rs_fprop: + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): + fp8_output = True + if self.ub_overlap_rs_dgrad: + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): + fp8_grad = True + + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( + inp, + allow_non_contiguous=isinstance(inp, QuantizedTensor), + ) as inp: + + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad) + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if torch.is_grad_enabled(): + linear_fn = _Linear.apply + args = [] + else: + linear_fn = _Linear.forward + args = [None] + args += ( + weight_tensor, + inp, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + torch.is_grad_enabled(), + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, + fp8_output, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, + debug, + ) + out = linear_fn(*args) + if self.gemm_bias_unfused_add: + out = out + cast_if_needed(bias_tensor, self.activation_dtype) + + if self.return_bias: + return out, cast_if_needed(bias_tensor, self.activation_dtype) + return out + + def _get_quantizers(self, fp8_output, fp8_grad): + if not self.fp8: + return [None] * 6 + grad_input_quantizer = None + grad_weight_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = True + (weight_quantizer,) = self._get_weight_quantizers() + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + 
grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) + ) + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + """Get the weight tensors of the module.""" + unfused_weights = [getattr(self, name) for name in self.weight_names] + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + unfused_weights = [w.dequantize() for w in unfused_weights] + return unfused_weights + + def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Get concatenated weight and bias tensors + unfused_weights = self._get_weight_tensors() + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + warnings.warn( + "You are using quantized weights without quantized compute. " + "Please make sure this is intentional." + ) + unfused_weights = [w.dequantize() for w in unfused_weights] + + weight_tensor = noop_cat(unfused_weights) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + else: + bias_tensor = None + + return weight_tensor, bias_tensor + + def onnx_forward( + self, + inp: torch.Tensor, + fp8_output: bool, + ) -> torch.Tensor: + """ + ONNX-compatible version of the forward function that provides numerical equivalence + while only using operations that have defined ONNX symbolic translations. + This simplified implementation is designed specifically for inference scenarios. + """ + from ..export import onnx_gemm + + assert_warmed_up(self) + assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export." 
+ weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + ( + input_quantizer, + weight_quantizer, + output_quantizer, + *_, + ) = self._get_quantizers(fp8_output, False) + inp_dtype = inp.dtype + + if input_quantizer is not None: + inp_q = input_quantizer.onnx_quantize(inp) + inp = input_quantizer.onnx_dequantize(inp_q) + inp = inp.to(inp_dtype) + + if weight_quantizer is not None: + weight_q = weight_quantizer.onnx_quantize(weight_tensor) + weight_tensor = weight_quantizer.onnx_dequantize(weight_q) + if bias_tensor is not None: + bias_tensor = bias_tensor.to(inp_dtype) + weight_tensor = weight_tensor.to(inp_dtype) + + if self.apply_bias: + output = onnx_gemm(weight_tensor, inp, bias_tensor) + else: + output = onnx_gemm(weight_tensor, inp, None) + + if output_quantizer is not None: + raise NotImplementedError("ONNX export of quantized output is not supported") + + if self.return_bias: + return output, bias_tensor + + return output + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # paralle related + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + # set grad_output_quantizer with amax epsilon and power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert recipe.nvfp4(), "Incorrect recipe." 
+ if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + + def _get_weight_quantizers(self) -> List[Quantizer]: + """Get the weight quantizers of the module.""" + if not self.fp8 and not self.fp8_calibration: + return [None] + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + return [weight_quantizer] + + def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on blockwise scaling recipe + linear.""" + assert ( + recipe.float8_block_scaling() + ), "blockwise scaling recipe quantizer customization here" + + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set compact for inp tensor X + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].all_gather_usage = True + else: + if self.sequence_parallel and self.parallel_mode == "row": + # set compact for grad_output tensor dY + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].all_gather_usage = True From ffe78795843dc065f27b3eb20873ba71c78c05f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Sep 2025 21:31:43 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../distributed/test_fused_linear_comms.py | 54 +-- .../pytorch/distributed/test_linear_comms.py | 42 ++- .../include/transformer_engine/ubnext.h | 11 +- transformer_engine/common/ubnext.cu | 322 ++++++++++-------- .../common/util/pybind_helper.h | 33 +- .../pytorch/cpp_extensions/symm_allocator.py | 88 +++-- .../pytorch/module/layernorm_linear.py | 13 +- transformer_engine/pytorch/module/linear.py | 24 +- 8 files changed, 368 insertions(+), 219 deletions(-) diff --git a/tests/pytorch/distributed/test_fused_linear_comms.py b/tests/pytorch/distributed/test_fused_linear_comms.py index 7b230ea4f9..b2e07437da 100644 --- a/tests/pytorch/distributed/test_fused_linear_comms.py +++ b/tests/pytorch/distributed/test_fused_linear_comms.py @@ -27,7 +27,12 @@ def main(): "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)" ) parser.add_argument("--validate", action="store_true", help="Validate allreduce ubnext") - parser.add_argument("--comm_type", type=str, default="sym", help="Comm type: none,nccl,sym,ub,ubnext,ubnext_add,ubnext_rms") + parser.add_argument( + "--comm_type", + type=str, + default="sym", + help="Comm type: none,nccl,sym,ub,ubnext,ubnext_add,ubnext_rms", + ) parser.add_argument( "--sym_type", type=str, @@ -141,7 +146,11 @@ def main(): fc1 = te.LayerNormLinear( in_features=args.hidden_size, - out_features=args.hidden_size*args.fc_factor//tp_size if args.comm_type == "none" else args.hidden_size*args.fc_factor, + out_features=( + args.hidden_size * args.fc_factor // tp_size + if args.comm_type == 
"none" + else args.hidden_size * args.fc_factor + ), bias=False, device=device, params_dtype=torch.bfloat16, @@ -156,7 +165,7 @@ def main(): ub_overlap_ag=args.comm_type == "ub", ub_name="fc1" if args.comm_type == "ub" else None, ) - + if args.comm_type == "ubnext_add_rms": proj.layer_norm_weight = fc1.layer_norm_weight.data # Create CUDA stream @@ -167,22 +176,29 @@ def main(): for logbatch in range(int(math.log2(args.batch_size)) + 1): batch = 2**logbatch - if args.comm_type == "ub":# and batch < tp_size: - batch = args.batch_size#tp_size + if args.comm_type == "ub": # and batch < tp_size: + batch = args.batch_size # tp_size # Create input tensor torch.cuda.synchronize() torch.distributed.barrier(group=torch.distributed.group.WORLD) - residual = torch.randn(batch//tp_size if args.comm_type == "ub" else batch, args.hidden_size, device=device, dtype=torch.bfloat16) + residual = torch.randn( + batch // tp_size if args.comm_type == "ub" else batch, + args.hidden_size, + device=device, + dtype=torch.bfloat16, + ) inp = torch.randn(batch, args.hidden_size // tp_size, device=device, dtype=torch.bfloat16) - + # Warm-up run if not args.comm_type.startswith("ubnext_add"): - out_proj=proj(inp) + out_proj = proj(inp) out_proj.add_(residual) - out=fc1(out_proj) + out = fc1(out_proj) else: - out=fc1(proj(inp,residual=residual)) # this also allocates distributed internal residual - + out = fc1( + proj(inp, residual=residual) + ) # this also allocates distributed internal residual + torch.cuda.synchronize() if args.cuda_graph: with torch.cuda.stream(stream): @@ -191,11 +207,11 @@ def main(): with torch.cuda.graph(graph): if not args.comm_type.startswith("ubnext_add"): - out_proj=proj(inp) + out_proj = proj(inp) out_proj.add_(residual) - out=fc1(out_proj) + out = fc1(out_proj) else: - out=fc1(proj(inp)) + out = fc1(proj(inp)) # Warm-up the graph for _ in range(5): @@ -215,11 +231,11 @@ def main(): graph.replay() else: if not args.comm_type.startswith("ubnext_add"): - out_proj=proj(inp) + out_proj = proj(inp) out_proj.add_(residual) - out=fc1(out_proj) + out = fc1(out_proj) else: - out=fc1(proj(inp)) + out = fc1(proj(inp)) torch.cuda.synchronize() end_time = time.time() @@ -227,9 +243,7 @@ def main(): # Calculate and print elapsed time (only on rank 0) if myrank == 0: - print( - f"Batch{batch},{(elapsed/ args.iterations) * 1e6:.4f}" - ) + print(f"Batch{batch},{(elapsed/ args.iterations) * 1e6:.4f}") if args.cuda_graph: # needed or NCCL would hang del graph diff --git a/tests/pytorch/distributed/test_linear_comms.py b/tests/pytorch/distributed/test_linear_comms.py index 2e7e9c0bf6..17668bafc9 100644 --- a/tests/pytorch/distributed/test_linear_comms.py +++ b/tests/pytorch/distributed/test_linear_comms.py @@ -118,7 +118,11 @@ def main(): dtype=torch.bfloat16, zero_centered_gamma=False, ) - residual = torch.randn(args.batch_size, args.out_features, device=device, dtype=torch.bfloat16) if args.rmsnorm else None + residual = ( + torch.randn(args.batch_size, args.out_features, device=device, dtype=torch.bfloat16) + if args.rmsnorm + else None + ) ln_weight = modelnorm.weight.data if args.rmsnorm else None if ( @@ -185,23 +189,37 @@ def main(): torch.manual_seed(57) torch.cuda.manual_seed(57) residual = torch.randn(1, args.out_features, dtype=torch.bfloat16, device=device) - t = allocator.create_tensor((1,args.out_features,), dtype=torch.bfloat16) - #te.cpp_extensions.symm_allocator.ubsymm_free_residual(t) + t = allocator.create_tensor( + ( + 1, + args.out_features, + ), + dtype=torch.bfloat16, + ) + # 
te.cpp_extensions.symm_allocator.ubsymm_free_residual(t) t.fill_(myrank) t_in = t.clone() torch.distributed.all_reduce(t_in) t_in.add_(residual) - out1=modelnorm(t_in) - out2 = allocator.allreduce_simple(t,hidden_size=args.out_features,residual_in=residual,residual_out=residual,fuse_layernorm=True,eps=args.eps,gamma=modelnorm.weight.data) + out1 = modelnorm(t_in) + out2 = allocator.allreduce_simple( + t, + hidden_size=args.out_features, + residual_in=residual, + residual_out=residual, + fuse_layernorm=True, + eps=args.eps, + gamma=modelnorm.weight.data, + ) abs_diff = torch.abs(out1 - out2) max_delta = torch.max(abs_diff).item() num_different = torch.sum(out1 != out2).item() print(f"FUSED RMSNorm Max delta: {max_delta}, num different: {num_different}") - if(myrank== 0): + if myrank == 0: print(f"gamma: {modelnorm.weight.data}") print(f"FUSED RMSNorm output: {out1}") print(f"FUSED RMSNorm output: {out2}") - + # Test different tensor sizes from 64 to 1024*1024 elements all_max_deltas = [] all_num_different = [] @@ -277,9 +295,9 @@ def main(): batch, int(args.in_features / tp_size), device=device, dtype=torch.bfloat16 ) # Warm-up run - out=modelseq(inp) + out = modelseq(inp) modelnorm(out) - modelpar(inp,residual=residual) + modelpar(inp, residual=residual) torch.cuda.synchronize() if args.cuda_graph: with torch.cuda.stream(stream): @@ -289,10 +307,10 @@ def main(): with torch.cuda.graph(gseq): output = modelseq(inp) if args.rmsnorm: - output.add_(residual[:batch,:args.out_features]) - output=modelnorm(output) + output.add_(residual[:batch, : args.out_features]) + output = modelnorm(output) with torch.cuda.graph(gpar): - output = modelpar(inp,residual=residual) + output = modelpar(inp, residual=residual) # Warm-up the graph for _ in range(5): gseq.replay() diff --git a/transformer_engine/common/include/transformer_engine/ubnext.h b/transformer_engine/common/include/transformer_engine/ubnext.h index aa82a27c33..0e1b304a61 100644 --- a/transformer_engine/common/include/transformer_engine/ubnext.h +++ b/transformer_engine/common/include/transformer_engine/ubnext.h @@ -16,12 +16,17 @@ extern "C" { #endif void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, - void* mcptr_out, size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream); + void* mcptr_out, size_t bytes, void* residual_in, void* residual_out, + bool fuse_layernorm, void* gamma, float eps, const int hidden_size, + cudaStream_t stream); void allreduce_2shot_mc_lamport(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, void* mcptr_in, void* mcptr_out, void* clear_ptr, size_t bytes, - bool poisoned, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream); + bool poisoned, void* residual_in, void* residual_out, + bool fuse_layernorm, void* gamma, float eps, const int hidden_size, + cudaStream_t stream); void allreduce_2shot_uc(int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, - size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream); + size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, + void* gamma, float eps, const int hidden_size, cudaStream_t stream); #ifdef __cplusplus } diff --git a/transformer_engine/common/ubnext.cu b/transformer_engine/common/ubnext.cu index 6358b07c85..8aa5287f9e 100644 --- 
a/transformer_engine/common/ubnext.cu +++ b/transformer_engine/common/ubnext.cu @@ -56,53 +56,49 @@ #define FINAL_MASK 0xffffffff template -__inline__ __device__ T warpReduceSumV2(T* val) -{ +__inline__ __device__ T warpReduceSumV2(T *val) { #pragma unroll - for (int i = 0; i < NUM; i++) - { + for (int i = 0; i < NUM; i++) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T) (0.0f); + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); } template -__inline__ __device__ T blockReduceSumV2(T* val) -{ - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; +__inline__ __device__ T blockReduceSumV2(T *val) { + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; - warpReduceSumV2(val); + warpReduceSumV2(val); - if (lane == 0) - { + if (lane == 0) { #pragma unroll - for (int i = 0; i < NUM; i++) - { - shared[i][wid] = val[i]; - } + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; } + } - __syncthreads(); + __syncthreads(); - bool is_mask = threadIdx.x < (blockDim.x / 32.f); + bool is_mask = threadIdx.x < (blockDim.x / 32.f); #pragma unroll - for (int i = 0; i < NUM; i++) - { - val[i] = is_mask ? shared[i][lane] : (T) (0.0f); - } - warpReduceSumV2(val); - return (T) 0.0f; + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSumV2(val); + return (T)0.0f; } template __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int RANKS, const int myrank, const int mylines, int *uc_flagptr, int *mc_flagptr, uint4 *mc_ptr_in, - uint4 *mc_ptr_out, uint4 *residual_in, uint4 *residual_out, xhalf* gamma, float eps, const int hidden_size, bool fuse_layernorm) { + uint4 *mc_ptr_out, uint4 *residual_in, uint4 *residual_out, + xhalf *gamma, float eps, const int hidden_size, + bool fuse_layernorm) { // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 int reduce_id; __shared__ float s_variance; @@ -135,7 +131,7 @@ __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) const int loop_step0 = blockDim.x; const int loop_step = loop_step0 * UNROLL * gridDim.x; - const int start_elem = threadIdx.x + blockDim.x*blockIdx.x*UNROLL; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL; const int end_elem = max(start_elem, mylines); //const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; //const int end_aligned = start_elem + aligned_elem; @@ -145,42 +141,42 @@ __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) xhalf *x = reinterpret_cast(&val[0]); #pragma unroll for (int i = 0; i < UNROLL; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) - - if(residual_in!=nullptr) { + + if (residual_in != nullptr) { for (int i = 0; i < UNROLL; i++) { - uint4 resval = residual_in[line+i*loop_step0]; + uint4 resval = residual_in[line + i * loop_step0]; xhalf *y = reinterpret_cast(&resval); - #pragma unroll - for (int j = 0; j < 8; j++) - x[i*8+j] += y[j]; - if(residual_out!=nullptr) - residual_out[line+i*loop_step0]=val[i]; +#pragma unroll + for (int j = 0; j < 8; j++) x[i * 8 + j] += y[j]; + if (residual_out != nullptr) residual_out[line + i * loop_step0] = val[i]; } } - if(fuse_layernorm) { + if (fuse_layernorm) { float local_var_sum = 0.0f; - for (int j = 0; j < UNROLL*sizeof(int4) / sizeof(xhalf); j++) - local_var_sum 
+= (float)(x[j])*(float)(x[j]); + for (int j = 0; j < UNROLL * sizeof(int4) / sizeof(xhalf); j++) + local_var_sum += (float)(x[j]) * (float)(x[j]); float packed[1] = {local_var_sum}; blockReduceSumV2(packed); float variance = packed[0]; - if (threadIdx.x == 0) - { - variance = (variance / hidden_size); // Var[x] = E[x²] - s_variance = rsqrtf(variance + eps); + if (threadIdx.x == 0) { + variance = (variance / hidden_size); // Var[x] = E[x²] + s_variance = rsqrtf(variance + eps); } __syncthreads(); } - int i=0; + int i = 0; #pragma unroll for (int g = 0; g < UNROLL; g++) { - if(fuse_layernorm) { - #pragma unroll + if (fuse_layernorm) { +#pragma unroll for (int j = 0; j < sizeof(int4) / sizeof(xhalf); j++) { - x[i] = (xhalf)((float)(x[i]) * s_variance * (float) gamma[(threadIdx.x+g*loop_step0)*sizeof(int4)/sizeof(xhalf)+j]); + x[i] = + (xhalf)((float)(x[i]) * s_variance * + (float) + gamma[(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(xhalf) + j]); i++; } } @@ -241,7 +237,9 @@ template __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inplace_gpu_uc(const int myrank, const int numlines, const int lineoffset_in, const int lineoffset_out, - int *uc_flagptr, void **commbuff, uint4 *residual_in, uint4 *residual_out, xhalf* gamma, float eps, const int hidden_size, bool fuse_layernorm) { + int *uc_flagptr, void **commbuff, uint4 *residual_in, + uint4 *residual_out, xhalf *gamma, float eps, + const int hidden_size, bool fuse_layernorm) { // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 //NB! uc_flagptr is shifted by ranks*8 for easier flag offsets // while lineoffset is relative to start of reg0 @@ -305,14 +303,12 @@ __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) for (int j = 0; j < 8; j++) s[j] += x[j]; } - if(residual_in!=nullptr) { + if (residual_in != nullptr) { uint4 resval = residual_in[lineoffset_in + line]; xhalf *y = reinterpret_cast(&resval); - #pragma unroll - for (int j = 0; j < 8; j++) - s[j] += y[j]; - if(residual_out!=nullptr) - residual_out[lineoffset_in + line]=sum; +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += y[j]; + if (residual_out != nullptr) residual_out[lineoffset_in + line] = sum; } #pragma unroll @@ -363,9 +359,14 @@ __global__ void memset_int(uint32_t *data, int n, uint32_t val) { } template -__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inplace_gpu_mc_lamport( - const int RANKS, const int myrank, const int mylines, const int numlines, int *uc_flagptr, int *mc_flagptr, - uint4 *mc_ptr_in, uint4 *mc_ptr_out, uint4 *uc_ptr_out, uint4 *clear_ptr, uint4 *residual_in, uint4 *residual_out, xhalf* gamma, float eps, const int hidden_size, bool fuse_layernorm) { +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_mc_lamport(const int RANKS, const int myrank, + const int mylines, const int numlines, + int *uc_flagptr, int *mc_flagptr, uint4 *mc_ptr_in, + uint4 *mc_ptr_out, uint4 *uc_ptr_out, + uint4 *clear_ptr, uint4 *residual_in, + uint4 *residual_out, xhalf *gamma, float eps, + const int hidden_size, bool fuse_layernorm) { // flags[0,1,2]: reduce_id, sm_sync-local, flag-barrier // those go right after rank UC pointers, but its the CPU caller who should account for it int reduce_id; @@ -405,7 +406,7 @@ __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inpla const int loop_step0 = blockDim.x; const int loop_step = loop_step0 * UNROLL * gridDim.x; - const int start_elem = threadIdx.x + blockDim.x*blockIdx.x*UNROLL; + 
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL; const int end_elem = max(start_elem, mylines); for (int line = start_elem; line < end_elem; line += loop_step) { @@ -413,42 +414,42 @@ __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inpla xhalf *x = reinterpret_cast(&val[0]); #pragma unroll for (int i = 0; i < UNROLL; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) - - if(residual_in!=nullptr) { + + if (residual_in != nullptr) { for (int i = 0; i < UNROLL; i++) { - uint4 resval = residual_in[line+i*loop_step0]; + uint4 resval = residual_in[line + i * loop_step0]; xhalf *y = reinterpret_cast(&resval); - #pragma unroll - for (int j = 0; j < 8; j++) - x[i*8+j] += y[j]; - if(residual_out!=nullptr) - residual_out[line+i*loop_step0]=val[i]; +#pragma unroll + for (int j = 0; j < 8; j++) x[i * 8 + j] += y[j]; + if (residual_out != nullptr) residual_out[line + i * loop_step0] = val[i]; } } - if(fuse_layernorm) { + if (fuse_layernorm) { float local_var_sum = 0.0f; - for (int j = 0; j < UNROLL*sizeof(int4) / sizeof(xhalf); j++) - local_var_sum += (float)(x[j])*(float)(x[j]); + for (int j = 0; j < UNROLL * sizeof(int4) / sizeof(xhalf); j++) + local_var_sum += (float)(x[j]) * (float)(x[j]); float packed[1] = {local_var_sum}; blockReduceSumV2(packed); float variance = packed[0]; - if (threadIdx.x == 0) - { - variance = (variance / hidden_size); // Var[x] = E[x²] - s_variance = rsqrtf(variance + eps); + if (threadIdx.x == 0) { + variance = (variance / hidden_size); // Var[x] = E[x²] + s_variance = rsqrtf(variance + eps); } __syncthreads(); } - int i=0; + int i = 0; #pragma unroll for (int g = 0; g < UNROLL; g++) { - if(fuse_layernorm) { - #pragma unroll + if (fuse_layernorm) { +#pragma unroll for (int j = 0; j < sizeof(int4) / sizeof(xhalf); j++) { - x[i] = (xhalf)((float)(x[i]) * s_variance * (float) gamma[(threadIdx.x+g*loop_step0)*sizeof(int4)/sizeof(xhalf)+j]); + x[i] = + (xhalf)((float)(x[i]) * s_variance * + (float) + gamma[(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(xhalf) + j]); i++; } } @@ -499,70 +500,86 @@ __global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inpla namespace transformer_engine { - #define split_tokens(x) \ - const int elements = bytes/sizeof(half); \ - const int elements_per_thread = sizeof(uint4)/sizeof(half); \ - int nthreads=1024, nlines=4; \ - size_t total_bytes = bytes/ranks, start_bytes = myrank*total_bytes; \ - int sms=x; \ - if(hidden_size) { \ - assert(hidden_size<=32768); \ - assert(elements % hidden_size==0); \ - assert(hidden_size%elements_per_thread==0); \ - int ntokens = elements/hidden_size; \ - int my_tokens = ntokens / ranks; \ - int extra_tokens = ntokens % ranks; \ - int first_token = myrank*my_tokens; \ - first_token+= myrank1024) { \ - nlines++; \ - assert(nlines<=4); \ - if((hidden_size/elements_per_thread)%nlines==0) \ - nthreads=((hidden_size/elements_per_thread))/nlines; \ - } \ - if(sms>my_tokens) sms=my_tokens; \ - if (sms==0) sms=1; \ - } \ - bool residual_in_global = residual_in!=nullptr && residual_in!=residual_out && residual_out!=nullptr; // out residual is always local +#define split_tokens(x) \ + const int elements = bytes / sizeof(half); \ + const int elements_per_thread = sizeof(uint4) / sizeof(half); \ + int nthreads = 1024, nlines = 4; \ + size_t total_bytes = bytes / ranks, start_bytes = myrank * total_bytes; \ + int sms = x; \ + if (hidden_size) { \ + assert(hidden_size <= 32768); \ + assert(elements % hidden_size == 0); \ + assert(hidden_size % 
elements_per_thread == 0); \ + int ntokens = elements / hidden_size; \ + int my_tokens = ntokens / ranks; \ + int extra_tokens = ntokens % ranks; \ + int first_token = myrank * my_tokens; \ + first_token += myrank < extra_tokens ? myrank : extra_tokens; \ + if (myrank < extra_tokens) my_tokens++; \ + start_bytes = first_token * hidden_size * sizeof(half); \ + total_bytes = my_tokens * hidden_size * sizeof(half); \ + nthreads = hidden_size / elements_per_thread; \ + nlines = 1; \ + while (nthreads > 1024) { \ + nlines++; \ + assert(nlines <= 4); \ + if ((hidden_size / elements_per_thread) % nlines == 0) \ + nthreads = ((hidden_size / elements_per_thread)) / nlines; \ + } \ + if (sms > my_tokens) sms = my_tokens; \ + if (sms == 0) sms = 1; \ + } \ + bool residual_in_global = residual_in != nullptr && residual_in != residual_out && \ + residual_out != nullptr; // out residual is always local extern "C" void allreduce_2shot_mc(int ranks, int myrank, void *uc0ptr, void *mc0ptr, - void *mcptr_in, void *mcptr_out, size_t bytes, - void *residual_in, void *residual_out, bool fuse_layernorm, - void* gamma, float eps, const int hidden_size, - cudaStream_t stream) { + void *mcptr_in, void *mcptr_out, size_t bytes, void *residual_in, + void *residual_out, bool fuse_layernorm, void *gamma, float eps, + const int hidden_size, cudaStream_t stream) { split_tokens(32); SETUP_LAUNCH_CONFIG(sms, nthreads, stream, 4, 1); int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4); - void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in+start_bytes, - *arg7 = mcptr_out+start_bytes, *arg8 = residual_in_global?residual_in+start_bytes:residual_in, *arg9 = residual_out, *arg10 = gamma; - float arg11 = eps; int arg12 = hidden_size; bool arg13 = fuse_layernorm; - void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, - (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12, (void *)&arg13}; - #define call_mc_kernel(x,cond) \ - if(x==nlines || cond) {CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); return;} - call_mc_kernel(1,false); - call_mc_kernel(2,false); - call_mc_kernel(3,false); - call_mc_kernel(4,true); + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in + start_bytes, + *arg7 = mcptr_out + start_bytes, + *arg8 = residual_in_global ? 
residual_in + start_bytes : residual_in, *arg9 = residual_out, + *arg10 = gamma; + float arg11 = eps; + int arg12 = hidden_size; + bool arg13 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, + (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, + (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12, + (void *)&arg13}; +#define call_mc_kernel(x, cond) \ + if (x == nlines || cond) { \ + CUDACHECK( \ + cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); \ + return; \ + } + call_mc_kernel(1, false); + call_mc_kernel(2, false); + call_mc_kernel(3, false); + call_mc_kernel(4, true); } extern "C" void allreduce_2shot_uc(int ranks, int myrank, void *uc0ptr, void *ucptr_in, - void *ucptr_out, size_t bytes, void *residual_in, void *residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size, cudaStream_t stream) { + void *ucptr_out, size_t bytes, void *residual_in, + void *residual_out, bool fuse_layernorm, void *gamma, float eps, + const int hidden_size, cudaStream_t stream) { SETUP_LAUNCH_CONFIG(64, 1024, stream, 4, 1); int arg1 = myrank, arg2 = bytes / 16, arg3 = (int4 *)ucptr_in - (int4 *)uc0ptr, arg4 = (int4 *)ucptr_out - (int4 *)uc0ptr; - void *arg5 = uc0ptr + (ranks * 8), **arg6 = (void **)uc0ptr, *arg7 = residual_in, *arg8 = residual_out, *arg9 = gamma; - float arg10 = eps; int arg11 = hidden_size; bool arg12 = fuse_layernorm; - void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, - (void *)&arg4, (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12}; + void *arg5 = uc0ptr + (ranks * 8), **arg6 = (void **)uc0ptr, *arg7 = residual_in, + *arg8 = residual_out, *arg9 = gamma; + float arg10 = eps; + int arg11 = hidden_size; + bool arg12 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, + (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, + (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12}; #define call_uc_kernel(x) \ if (x == ranks) \ CUDACHECK( \ @@ -575,9 +592,9 @@ extern "C" void allreduce_2shot_uc(int ranks, int myrank, void *uc0ptr, void *uc extern "C" void allreduce_2shot_mc_lamport(int ranks, int myrank, void *uc0ptr, void *mc0ptr, void *ucptr_out, void *mcptr_in, void *mcptr_out, void *clear_ptr, size_t bytes, bool poisoned, - void *residual_in,void* residual_out, bool fuse_layernorm, - void* gamma, float eps, const int hidden_size, - cudaStream_t stream) { + void *residual_in, void *residual_out, + bool fuse_layernorm, void *gamma, float eps, + const int hidden_size, cudaStream_t stream) { if (!poisoned) { //user tells us destination was not pre-poisoned, so we need to do it before calling allreduce int threadsPerBlock = 512; @@ -586,23 +603,34 @@ extern "C" void allreduce_2shot_mc_lamport(int ranks, int myrank, void *uc0ptr, NVTE_UB_LAMPORT_INT); } split_tokens(64); - - SETUP_LAUNCH_CONFIG(64, nthreads, stream, 4, 1); - - int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4), arg3a = bytes / sizeof(uint4); - void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in+start_bytes, - *arg7 = mcptr_out+start_bytes, *arg8 = ucptr_out, *arg9 = clear_ptr, *arg10 = residual_in_global?residual_in+start_bytes:residual_in, *arg11 = residual_out, *arg12 = gamma; - float arg13 = eps; int arg14 = hidden_size; bool arg15 = fuse_layernorm; - void *kernelArgs[] = {(void *)&arg1, (void 
*)&arg2, (void *)&arg3, (void *)&arg3a, (void *)&arg4, (void *)&arg5, - (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12, (void *)&arg13, (void *)&arg14, (void *)&arg15}; - #define call_mc_lamport_kernel(x,cond) \ - if(x==nlines || cond) {CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc_lamport), kernelArgs)); return;} + SETUP_LAUNCH_CONFIG(64, nthreads, stream, 4, 1); - call_mc_lamport_kernel(1,false); - call_mc_lamport_kernel(2,false); - call_mc_lamport_kernel(3,false); - call_mc_lamport_kernel(4,true); + int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4), + arg3a = bytes / sizeof(uint4); + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in + start_bytes, + *arg7 = mcptr_out + start_bytes, *arg8 = ucptr_out, *arg9 = clear_ptr, + *arg10 = residual_in_global ? residual_in + start_bytes : residual_in, *arg11 = residual_out, + *arg12 = gamma; + float arg13 = eps; + int arg14 = hidden_size; + bool arg15 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg3a, + (void *)&arg4, (void *)&arg5, (void *)&arg6, (void *)&arg7, + (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, + (void *)&arg12, (void *)&arg13, (void *)&arg14, (void *)&arg15}; + +#define call_mc_lamport_kernel(x, cond) \ + if (x == nlines || cond) { \ + CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc_lamport), \ + kernelArgs)); \ + return; \ } + call_mc_lamport_kernel(1, false); + call_mc_lamport_kernel(2, false); + call_mc_lamport_kernel(3, false); + call_mc_lamport_kernel(4, true); +} + } // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6639b391f7..240a18e8a7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -142,30 +142,41 @@ m.def( \ "allreduce_2shot_mc", \ [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, void* mcptr_out, \ - size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ - transformer_engine::allreduce_2shot_mc(ranks, myrank, uc0ptr, mc0ptr, mcptr_in, mcptr_out, \ - bytes, residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, \ + float eps, const int hidden_size) { \ + transformer_engine::allreduce_2shot_mc( \ + ranks, myrank, uc0ptr, mc0ptr, mcptr_in, mcptr_out, bytes, residual_in, residual_out, \ + fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ }, \ py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ - py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("bytes"), py::arg("residual_in"), py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); \ + py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("bytes"), py::arg("residual_in"), \ + py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), \ + py::arg("hidden_size")); \ m.def( \ "allreduce_2shot_uc", \ - [](int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ - 
transformer_engine::allreduce_2shot_uc(ranks, myrank, uc0ptr, ucptr_in, ucptr_out, bytes, \ - residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + [](int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, size_t bytes, \ + void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, \ + const int hidden_size) { \ + transformer_engine::allreduce_2shot_uc( \ + ranks, myrank, uc0ptr, ucptr_in, ucptr_out, bytes, residual_in, residual_out, \ + fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ }, \ py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("ucptr_in"), \ - py::arg("ucptr_out"), py::arg("bytes"), py::arg("residual_in"), py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); \ + py::arg("ucptr_out"), py::arg("bytes"), py::arg("residual_in"), py::arg("residual_out"), \ + py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); \ m.def( \ "allreduce_2shot_mc_lamport", \ [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, void* mcptr_in, \ - void* mcptr_out, void* clear_ptr, size_t bytes, bool poisoned, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ + void* mcptr_out, void* clear_ptr, size_t bytes, bool poisoned, void* residual_in, \ + void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ transformer_engine::allreduce_2shot_mc_lamport( \ ranks, myrank, uc0ptr, mc0ptr, ucptr_out, mcptr_in, mcptr_out, clear_ptr, bytes, \ - poisoned, residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + poisoned, residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, \ + at::cuda::getCurrentCUDAStream()); \ }, \ py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ py::arg("ucptr_out"), py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("clear_ptr"), \ - py::arg("bytes"), py::arg("poisoned"), py::arg("residual_in"), py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); + py::arg("bytes"), py::arg("poisoned"), py::arg("residual_in"), py::arg("residual_out"), \ + py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); #endif diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py index 103d3884e1..46253affea 100644 --- a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -7,6 +7,7 @@ import torch.distributed._symmetric_memory as symm_mem from ctypes import pythonapi, c_void_p, py_object + def to_capsule(ptr): # Set the return type to py_object to get a Python object (PyCapsule) pythonapi.PyCapsule_New.restype = py_object @@ -82,7 +83,7 @@ def __init__(self, size_bytes: int, device: torch.device, dist_group: torch.dist else: alignment = 2 * 1024 * 1024 # memory is allocated in 2MB pages anyways self.pool_size = int((size_bytes + alignment - 1) / alignment) * alignment -# symm_mem.set_backend("NCCL") + # symm_mem.set_backend("NCCL") self.internal_pool = symm_mem.empty(self.pool_size, dtype=torch.uint8, device=device) self.hdl0 = symm_mem.rendezvous(self.internal_pool, dist_group) self.mc0_ptr = self.hdl0.multicast_ptr @@ -170,7 +171,16 @@ def 
create_tensor( self.tensors.add(tensor) return tensor - def allreduce_uc(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_in: Optional[torch.Tensor] = None,residual_out: Optional[torch.Tensor] = None, fuse_layernorm: bool = False, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> torch.Tensor: + def allreduce_uc( + self, + tensor_in: torch.Tensor, + hidden_size: int = 0, + residual_in: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + fuse_layernorm: bool = False, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ) -> torch.Tensor: """Performs in-place allreduce on the given SymmTensor using best algo""" assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" @@ -195,11 +205,20 @@ def allreduce_uc(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_i fuse_layernorm, to_capsule(gamma.data_ptr()) if gamma is not None else None, eps if eps is not None else 0.0, - hidden_size + hidden_size, ) return tensor_in - def allreduce_simple(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_in: Optional[torch.Tensor] = None,residual_out: Optional[torch.Tensor] = None, fuse_layernorm: bool = False, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> torch.Tensor: + def allreduce_simple( + self, + tensor_in: torch.Tensor, + hidden_size: int = 0, + residual_in: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + fuse_layernorm: bool = False, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ) -> torch.Tensor: """Performs in-place allreduce on the given SymmTensor using best algo""" assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" @@ -225,11 +244,20 @@ def allreduce_simple(self, tensor_in: torch.Tensor, hidden_size: int = 0, residu fuse_layernorm, to_capsule(gamma.data_ptr()) if gamma is not None else None, eps if eps is not None else 0.0, - hidden_size + hidden_size, ) return tensor_in - def allreduce_lamport(self, tensor_in: torch.Tensor, hidden_size: int = 0, residual_in: Optional[torch.Tensor] = None,residual_out: Optional[torch.Tensor] = None, fuse_layernorm: bool = False, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> torch.Tensor: + def allreduce_lamport( + self, + tensor_in: torch.Tensor, + hidden_size: int = 0, + residual_in: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + fuse_layernorm: bool = False, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ) -> torch.Tensor: """ Performs allreduce using 2-shot multicast Lamport variant: - Takes `tensor_in` as input (SymmTensor). 
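A reading aid for the allocator API touched in this hunk: the validation flow from tests/pytorch/distributed/test_linear_comms.py shown earlier in this patch series reduces to the sketch below. It is illustrative only and assumes torch.distributed is already initialized with one CUDA device per rank, and that `allocator` is an already-rendezvoused symmetric-memory pool exposing the create_tensor / allreduce_simple API added here; the unfused RMSNorm reference is written inline.

import torch

# Illustrative sketch, not a verbatim excerpt of the test. `allocator` is assumed to be
# the symmetric-memory pool set up by this patch for the tensor-parallel group.
hidden = 8192
eps = 1e-5
gamma = torch.randn(hidden, dtype=torch.bfloat16, device="cuda")
residual = torch.randn(1, hidden, dtype=torch.bfloat16, device="cuda")

t = allocator.create_tensor((1, hidden), dtype=torch.bfloat16)  # lives in the symmetric pool
t.fill_(torch.distributed.get_rank())

# Unfused reference: NCCL allreduce, residual add, then RMSNorm.
ref = t.clone()
torch.distributed.all_reduce(ref)
ref = ref + residual
inv_rms = torch.rsqrt(ref.float().pow(2).mean(dim=-1, keepdim=True) + eps)
ref = (ref.float() * inv_rms).to(torch.bfloat16) * gamma

# Fused ubnext path: allreduce + residual add + RMSNorm in one fused kernel.
out = allocator.allreduce_simple(
    t,
    hidden_size=hidden,
    residual_in=residual,
    residual_out=residual,
    fuse_layernorm=True,
    eps=eps,
    gamma=gamma,
)
print("max |ref - out| =", (ref - out).abs().max().item())

With fuse_layernorm=False the same entry point performs just the allreduce (plus the optional residual add), which is what the ubnext / ubnext_add comm_type settings in the tests exercise.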
@@ -239,7 +267,9 @@ def allreduce_lamport(self, tensor_in: torch.Tensor, hidden_size: int = 0, resid """ assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" if self.mc0_ptr is None or self.mc0_ptr == 0: - return self.allreduce_uc(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + return self.allreduce_uc( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) from transformer_engine_torch import allreduce_2shot_mc_lamport # Allocate output tensor of same shape/dtype @@ -253,7 +283,9 @@ def allreduce_lamport(self, tensor_in: torch.Tensor, hidden_size: int = 0, resid tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) poisonedout = False if tensor_out is None: - return self.allreduce_simple(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + return self.allreduce_simple( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) # alllcate potential output for next allreduce (speculative) and poison it now self.nextpoisoned = self.create_tensor(tensor_in.shape, tensor_in.dtype) @@ -285,7 +317,7 @@ def allreduce_lamport(self, tensor_in: torch.Tensor, hidden_size: int = 0, resid fuse_layernorm, to_capsule(gamma.data_ptr()) if gamma is not None else None, eps if eps is not None else 0.0, - hidden_size + hidden_size, ) return tensor_out @@ -338,7 +370,12 @@ def ubsymm_get_sym_tensor( return allocator.create_tensor(shape, dtype) -def ubsymm_allreduce(tensor_in: SymmTensor,residual_global: Optional[torch.Tensor] = None, gamma: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> SymmTensor: +def ubsymm_allreduce( + tensor_in: SymmTensor, + residual_global: Optional[torch.Tensor] = None, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, +) -> SymmTensor: """ Performs allreduce on the given SymmTensor using best algo Four modes: @@ -352,33 +389,47 @@ def ubsymm_allreduce(tensor_in: SymmTensor,residual_global: Optional[torch.Tenso fuse_layernorm = gamma is not None and eps is not None internal_residual = tensor_in._allocator.residual num_ranks = tensor_in._allocator.world_size - hidden_size = tensor_in.shape[-1] if fuse_layernorm or internal_residual is not None or residual_global is not None else tensor_in.numel() // num_ranks + hidden_size = ( + tensor_in.shape[-1] + if fuse_layernorm or internal_residual is not None or residual_global is not None + else tensor_in.numel() // num_ranks + ) num_tokens = tensor_in.numel() // hidden_size myrank = tensor_in._allocator.myrank - if residual_global is not None and (internal_residual is None or tensor_in._allocator.residual_tokens != num_tokens): + if residual_global is not None and ( + internal_residual is None or tensor_in._allocator.residual_tokens != num_tokens + ): my_tokens = num_tokens // num_ranks extra_tokens = num_tokens % num_ranks - first_token = myrank*my_tokens + first_token = myrank * my_tokens if myrank < extra_tokens: my_tokens += 1 first_token += myrank else: first_token += extra_tokens if my_tokens == 0: - my_tokens = 1 #avoid empty residual + my_tokens = 1 # avoid empty residual if tensor_in._allocator.residual is not None: del tensor_in._allocator.residual - tensor_in._allocator.residual = torch.empty(my_tokens*hidden_size, dtype=tensor_in.dtype, device=tensor_in.device) + tensor_in._allocator.residual = torch.empty( + my_tokens * hidden_size, dtype=tensor_in.dtype, device=tensor_in.device + ) tensor_in._allocator.residual_tokens = num_tokens 
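# Worked example (illustrative): the partition above matches the split_tokens macro in
# ubnext.cu. With num_tokens = 10 and num_ranks = 4:
#   my_tokens = 10 // 4 = 2, extra_tokens = 10 % 4 = 2
#   rank 0 -> tokens [0, 3) and rank 1 -> tokens [3, 6)  (each absorbs one extra token)
#   rank 2 -> tokens [6, 8) and rank 3 -> tokens [8, 10)
# so each rank keeps only its my_tokens * hidden_size slice of the internal residual.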
internal_residual = tensor_in._allocator.residual residual_in = residual_global if residual_global is not None else internal_residual - residual_out = internal_residual if fuse_layernorm else None #without layernorm new full residual is output of allreduce + residual_out = ( + internal_residual if fuse_layernorm else None + ) # without layernorm new full residual is output of allreduce if tensor_in.numel() > 1048576: - return tensor_in._allocator.allreduce_simple(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + return tensor_in._allocator.allreduce_simple( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) else: - return tensor_in._allocator.allreduce_lamport(tensor_in,hidden_size,residual_in,residual_out,fuse_layernorm, gamma, eps) + return tensor_in._allocator.allreduce_lamport( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) def ubsymm_free_residual(tensor_in: SymmTensor): @@ -386,4 +437,3 @@ def ubsymm_free_residual(tensor_in: SymmTensor): del tensor_in._allocator.residual tensor_in._allocator.residual_tokens = 0 tensor_in._allocator.residual = None - \ No newline at end of file diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b77cd0bcda..c2a9f68ef6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -346,7 +346,12 @@ def forward( out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None - if symmetric_ar_type is not None and symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + if ( + symmetric_ar_type is not None + and symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): out_shape_list = list(tuple(inp.shape)) out_shape_list[-1] = out_features symm_out = ubsymm_get_sym_tensor( @@ -1299,7 +1304,11 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - if self.symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + if ( + self.symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): ubsymm_request_allocator( self.tp_group, ( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7d3f2a167d..598433d34c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -309,7 +309,12 @@ def forward( reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None - if symmetric_ar_type is not None and symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + if ( + symmetric_ar_type is not None + and symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): out_shape_list = list(tuple(inp.shape)) out_shape_list[-1] = out_features symm_out = ubsymm_get_sym_tensor( @@ -317,7 +322,10 @@ def forward( activation_dtype, tp_group, ) - assert symm_out is not None or symmetric_ar_type == "ubnext", "No symmetric pool out of space fallback for fused ops, increase NVTE_UB_SYMM_POOL_SIZE" + assert symm_out is not None or symmetric_ar_type == "ubnext", ( + "No symmetric pool out of space fallback for fused ops, increase" + " NVTE_UB_SYMM_POOL_SIZE" + ) # ------------------------------------------------------ # Forward GEMM # 
Note: y = x * w^T @@ -363,7 +371,9 @@ def forward( elif tensor_parallel: if symmetric_ar_type is not None: if symm_out is not None: - out = ubsymm_allreduce(symm_out,residual_global=residual,gamma=ln_weight,eps=eps) + out = ubsymm_allreduce( + symm_out, residual_global=residual, gamma=ln_weight, eps=eps + ) else: fallback_symmetric = ( "multimem_all_reduce" @@ -1230,7 +1240,11 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - if self.symmetric_ar_type.startswith("ubnext") and parallel_mode == "row" and tp_size > 1: + if ( + self.symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): ubsymm_request_allocator( self.tp_group, ( @@ -1240,7 +1254,7 @@ def __init__( params_dtype, ) self.eps = eps - self.layer_norm_weight = ln_weight # in general expected to be filled with reference to layernorm_weight from next LayerNormLinear later + self.layer_norm_weight = ln_weight # in general expected to be filled with reference to layernorm_weight from next LayerNormLinear later # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() From 86b3757c1c038f4062d83c936c2beac7594e84ca Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Mon, 29 Sep 2025 14:47:45 -0700 Subject: [PATCH 3/3] merge cleanup --- transformer_engine/common/CMakeLists.txt | 3 +- transformer_engine/common/CMakeLists.txt.orig | 298 --- transformer_engine/common/CMakeLists.txt.rej | 12 - .../comm_gemm_overlap.cpp.orig | 1210 ----------- .../transformer_engine/comm_gemm_overlap.h | 3 + .../comm_gemm_overlap.h.orig | 327 --- .../comm_gemm_overlap.h.rej | 11 - .../common/util/pybind_helper.h.orig | 140 -- .../extensions/comm_gemm_overlap.cpp.orig | 320 --- .../pytorch/module/base.py.orig | 1597 -------------- .../pytorch/module/layernorm_linear.py.orig | 1827 ----------------- .../pytorch/module/linear.py.orig | 1710 --------------- 12 files changed, 5 insertions(+), 7453 deletions(-) delete mode 100644 transformer_engine/common/CMakeLists.txt.orig delete mode 100644 transformer_engine/common/CMakeLists.txt.rej delete mode 100644 transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig delete mode 100644 transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig delete mode 100644 transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej delete mode 100644 transformer_engine/common/util/pybind_helper.h.orig delete mode 100644 transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp.orig delete mode 100644 transformer_engine/pytorch/module/base.py.orig delete mode 100644 transformer_engine/pytorch/module/layernorm_linear.py.orig delete mode 100644 transformer_engine/pytorch/module/linear.py.orig diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a4915080e8..45cc445034 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -143,7 +143,8 @@ list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) + comm_gemm_overlap/comm_gemm_overlap.cpp + ubnext.cu) if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES diff --git a/transformer_engine/common/CMakeLists.txt.orig b/transformer_engine/common/CMakeLists.txt.orig deleted file mode 100644 index a4915080e8..0000000000 --- a/transformer_engine/common/CMakeLists.txt.orig +++ 
/dev/null @@ -1,298 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -cmake_minimum_required(VERSION 3.21) - -# Language options -if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) - elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) - else () - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) - endif() -endif() -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CUDA_STANDARD 17) -set(CMAKE_CUDA_STANDARD_REQUIRED ON) -if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") -endif() - -# Hide non-necessary symbols in shared object. -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") - -# Transformer Engine library -project(transformer_engine LANGUAGES CUDA CXX) - -# CUDA Toolkit -find_package(CUDAToolkit REQUIRED) -if (CUDAToolkit_VERSION VERSION_LESS 12.0) - message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") -endif() - -# cuDNN frontend API -set(CUDNN_FRONTEND_INCLUDE_DIR - "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") -if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() -include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) - -set(CUTLASS_INCLUDE_DIR - "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") -set(CUTLASS_TOOLS_INCLUDE_DIR - "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") - -# Python -find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) - -# NVIDIA MathDX include directory (from Python package install location) -if(NOT DEFINED MATHDX_INCLUDE_DIR) - execute_process( - COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx - OUTPUT_VARIABLE _PIP_SHOW_MATHDX - ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR - RESULT_VARIABLE _PIP_SHOW_MATHDX_RES - OUTPUT_STRIP_TRAILING_WHITESPACE) - if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) - message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") - endif() - string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") - if(NOT _MATHDX_LOC_MATCH) - message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") - endif() - set(MATHDX_LOCATION "${CMAKE_MATCH_1}") - set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") -endif() -if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") - message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") -endif() - -# Configure Transformer Engine library -include_directories(${PROJECT_SOURCE_DIR}/..) 
-set(transformer_engine_SOURCES) -list(APPEND transformer_engine_SOURCES - cudnn_utils.cpp - transformer_engine.cpp - common.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - transpose/cast_transpose.cu - transpose/transpose.cu - transpose/cast_transpose_fusion.cu - transpose/transpose_fusion.cu - transpose/multi_cast_transpose.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise.cu - transpose/swap_first_dims.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu - activation/gelu.cu - dropout/dropout.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - activation/relu.cu - activation/swiglu.cu - fused_attn/fused_attn_fp8.cu - fused_attn/fused_attn.cpp - fused_attn/utils.cu - gemm/config.cpp - gemm/cublaslt_gemm.cu - gemm/cutlass_grouped_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp - normalization/layernorm/ln_bwd_semi_cuda_kernel.cu - normalization/layernorm/ln_fwd_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_api.cpp - normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu - permutation/permutation.cu - util/cast.cu - util/padding.cu - util/cuda_driver.cpp - util/cuda_nvml.cpp - util/cuda_runtime.cpp - util/multi_stream.cpp - util/rtc.cpp - swizzle/swizzle.cu - fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - fused_rope/fused_rope.cu - fused_router/fused_moe_aux_loss.cu - fused_router/fused_score_for_moe_aux_loss.cu - fused_router/fused_topk_with_score_function.cu - recipe/current_scaling.cu - recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu - recipe/nvfp4.cu - hadamard_transform/hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) - -if (NVTE_WITH_CUBLASMP) -list(APPEND transformer_engine_SOURCES - comm_gemm/comm_gemm.cpp) -endif() - -add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) -target_include_directories(transformer_engine PUBLIC - "${CMAKE_CURRENT_SOURCE_DIR}/include") - -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() - -# Configure dependencies -target_link_libraries(transformer_engine PUBLIC - CUDA::cublas - CUDA::cudart - CUDNN::cudnn_all) - -target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) -target_include_directories(transformer_engine SYSTEM PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) -target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") -target_include_directories(transformer_engine PRIVATE - ${CUTLASS_INCLUDE_DIR} - ${CUTLASS_TOOLS_INCLUDE_DIR}) - -# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI -option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" 
OFF) -if (NVTE_UB_WITH_MPI) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) - target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) -endif() - -option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF) -if (NVTE_ENABLE_NVSHMEM) - add_subdirectory(nvshmem_api) - target_link_libraries(transformer_engine PUBLIC nvshmemapi) - target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) -endif() - -option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) -if (NVTE_WITH_CUBLASMP) - target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) - target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) - find_library(CUBLASMP_LIB - NAMES cublasmp libcublasmp - PATHS ${CUBLASMP_DIR} - PATH_SUFFIXES lib - REQUIRED) - find_library(NVSHMEM_HOST_LIB - NAMES nvshmem_host libnvshmem_host.so.3 - PATHS ${NVSHMEM_DIR} - PATH_SUFFIXES lib - REQUIRED) - target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) - message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") - message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") -endif() - -# Hack to enable dynamic loading in cuDNN frontend -target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) - -# Helper functions to make header files with C++ strings -function(make_string_header STRING STRING_NAME) - configure_file(util/string_header.h.in - "string_headers/${STRING_NAME}.h" - @ONLY) -endfunction() -function(make_string_header_from_file file_ STRING_NAME) - file(READ "${file_}" STRING) - configure_file(util/string_header.h.in - "string_headers/${STRING_NAME}.h" - @ONLY) -endfunction() - -# Header files with C++ strings -list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path) -make_string_header("${cuda_include_path}" - string_path_cuda_include) -make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu - string_code_transpose_rtc_cast_transpose_fusion_cu) -make_string_header_from_file(transpose/rtc/cast_transpose.cu - string_code_transpose_rtc_cast_transpose_cu) -make_string_header_from_file(transpose/rtc/transpose.cu - string_code_transpose_rtc_transpose_cu) -make_string_header_from_file(transpose/rtc/swap_first_dims.cu - string_code_transpose_rtc_swap_first_dims_cu) -make_string_header_from_file(utils.cuh - string_code_utils_cuh) -make_string_header_from_file(util/math.h - string_code_util_math_h) -target_include_directories(transformer_engine PRIVATE - "${CMAKE_CURRENT_BINARY_DIR}/string_headers") - -# Compiler options -set_source_files_properties(fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") -option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) -if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) - set_source_files_properties(activation/gelu.cu - activation/relu.cu - activation/swiglu.cu - util/cast.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") -endif() -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") -set(CMAKE_CUDA_FLAGS 
"${CMAKE_CUDA_FLAGS} -O3") - -# Number of parallel build jobs -if(ENV{MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") -elseif(ENV{NVTE_BUILD_MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") -else() - set(BUILD_JOBS_STR "max") -endif() -message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}") - -# Number of threads per parallel build job -set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) -if (NOT BUILD_THREADS_PER_JOB) - set(BUILD_THREADS_PER_JOB 1) -endif() -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") -message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") - -# Install library -install(TARGETS transformer_engine DESTINATION .) diff --git a/transformer_engine/common/CMakeLists.txt.rej b/transformer_engine/common/CMakeLists.txt.rej deleted file mode 100644 index faade11dac..0000000000 --- a/transformer_engine/common/CMakeLists.txt.rej +++ /dev/null @@ -1,12 +0,0 @@ ---- transformer_engine/common/CMakeLists.txt -+++ transformer_engine/common/CMakeLists.txt -@@ -109,7 +109,8 @@ list(APPEND transformer_engine_SOURCES - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu -- comm_gemm_overlap/comm_gemm_overlap.cpp) -+ comm_gemm_overlap/comm_gemm_overlap.cpp -+ ubnext.cu) - add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) - target_include_directories(transformer_engine PUBLIC - "${CMAKE_CURRENT_SOURCE_DIR}/include") diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig deleted file mode 100644 index 56369db27f..0000000000 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp.orig +++ /dev/null @@ -1,1210 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include -#include - -#include "common/common.h" -#include "common/util/cuda_driver.h" -#include "common/util/cuda_runtime.h" -#include "common/util/logging.h" -#include "common/util/system.h" -#include "userbuffers/userbuffers.h" - -#define HALF_BYTES 2 -#define UB_MAX_SM 32 - -using namespace std::placeholders; - -namespace transformer_engine { - -namespace { - -std::vector shape_to_vector(const NVTEShape &shape) { - return std::vector(shape.data, shape.data + shape.ndim); -} - -} // namespace - -/*************************************************************************************************** - * Comm+GEMM Overlap Common Core - **************************************************************************************************/ - -bool ubuf_built_with_mpi() { -#ifdef NVTE_UB_WITH_MPI - return true; -#else - return false; -#endif -} - -CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, int tp_size, ExtAllgatherOp allgather_handle, - ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm) { - // Initialize userbuf communicator - if (!_comm_created) { - if (myrank == 0) { - printf("!!! 
[UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else - create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - allgather_handle, barrier_handle, 1, 1, tp_size, 1); -#endif - _comm_created = true; - } - - initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, - num_comm_sm, set_sm_margin, use_ce, atomic_gemm); -} - -void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm) { - _use_ce = static_cast(use_ce); - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - if (gemm_priority == 0 && comm_priority == 0) { - transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority); - } else { - _gemm_priority = gemm_priority; - _comm_priority = comm_priority; - } - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { - cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); - _stream_compute.push_back(std::move(stream)); - } - - _num_splits = num_splits; - _rank = _ub_comm->myrank; - _tp_size = tp_size; - _tp_id = _rank % _tp_size; - - // Set the number of SMs for GEMM with margin - int sm_count = transformer_engine::cuda::sm_count(); - _math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - _atomic_gemm = atomic_gemm; - if (_atomic_gemm) { - void *counter_ptr; - size_t counter_bytes = _num_splits * 2 * sizeof(int32_t); - NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); - NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); - NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); - _counter = TensorWrapper(counter_ptr, std::vector{static_cast(_num_splits * 2)}, - DType::kInt32); - } - // CUDA event creation - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0)); - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0)); - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0)); - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0)); - - /* - Defining the launcher order between the communication and GEMM kernels - using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. - The event is used to schedule the communication kernel before the GEMM. - This is needed only for Hopper, which uses persistent CTA execution. 
- */ - int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); - int runtime_version = 0; - NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version)); - cudaDeviceProp deviceProp; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0)); - if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming)); - } else { - _comm_launch_event = 0; - } -} - -CommOverlapCore::~CommOverlapCore() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - if (_comm_launch_event) { - cudaEventDestroy(_comm_launch_event); - } - - if (_atomic_gemm) { - cudaFree(_counter.dptr()); - } - - for (size_t i = 0; i < _stream_compute.size(); i++) { - cudaStreamSynchronize(_stream_compute[i]); - cudaStreamDestroy(_stream_compute[i]); - } - - auto error = cudaGetLastError(); - if (error != cudaSuccess) { - NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); - } - - if (_comm_created) { - try { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - } catch (const std::exception &e) { - NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); - } - _comm_created = false; - } -} - -TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, - const std::vector &chunk_shape) { - const auto scaling_mode = source.scaling_mode(); - - // Tensor dimensions - std::vector shape = shape_to_vector(source.shape()); - auto flatten_shape_to_2d = [](const std::vector &shape) -> std::pair { - if (shape.empty()) { - return {1, 1}; - } - size_t height = 1; - for (size_t i = 0; i < shape.size() - 1; ++i) { - height *= shape[i]; - } - return {height, shape.back()}; - }; - size_t height, width, chunk_height, chunk_width; - std::tie(height, width) = flatten_shape_to_2d(shape); - std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape); - - // Check tensor dimensions -#define NVTE_DIM_CHECK(cond, message) \ - NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \ - ", chunk offset=", chunk_offset, ")") - NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor"); - NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk"); - NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width, - "Attempted to get out-of-bounds tensor chunk"); - if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { - // MXFP8 scale-inverses are padded to a 2D matrix with dims that - // are divisible by 128. UB doesn't handle this padding yet. 
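-  // A small worked example of the chunk arithmetic in this function (shapes are
-  // arbitrary, chosen only for illustration): a BF16 source of shape {4, 512, 1024}
-  // flattens to height = 4*512 = 2048 and width = 1024, so chunk_shape = {512, 1024}
-  // at chunk_offset = 512*1024 elements selects rows 512..1023 of the flattened view.
-  // For an MXFP8 source, the corresponding rowwise scale-inverse chunk starts at
-  // element offset chunk_offset/32 and has shape {512, 1024/32} = {512, 32}, matching
-  // the scale-inv branch further below; the 128-divisibility checks that follow
-  // guard this arithmetic.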
- NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0, - "Userbuffers requires MXFP8 tensor dims that are divisible by 128"); - NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0, - "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128"); - } -#undef NVTE_DIM_CHECK - - // Construct tensor chunk - TensorWrapper chunk(scaling_mode); - for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { - auto param_type = static_cast(param_id); - auto param = source.get_parameter(param_type); - auto param_dptr = reinterpret_cast(param.data_ptr); - auto param_dtype = static_cast(param.dtype); - auto param_shape = shape_to_vector(param.shape); - - if (param_dptr != nullptr) { - if (param_type == NVTETensorParam::kNVTERowwiseData || - param_type == NVTETensorParam::kNVTEColumnwiseData) { - // Offset data pointer - param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype); - param_shape = chunk_shape; - - if (param_type == NVTETensorParam::kNVTEColumnwiseData && - source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { - // Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front - auto last_dim = param_shape.back(); - param_shape.pop_back(); - param_shape.insert(param_shape.begin(), last_dim); - } - } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING && - (param_type == NVTETensorParam::kNVTERowwiseScaleInv || - param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { - // Calculate offset and size for MXFP8 scale-invs - size_t chunk_scale_height = chunk_height; - size_t chunk_scale_width = chunk_width; - if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) { - chunk_scale_width /= 32; - } else { - chunk_scale_height /= 32; - } - param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype); - param_shape = {chunk_scale_height, chunk_scale_width}; - } - - // Set chunked source parameters into the chunked tensor output - chunk.set_parameter(param_type, reinterpret_cast(param_dptr), param_dtype, - param_shape); - } - } - return chunk; -} - -TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source, - size_t chunk_offset, - const std::vector &chunk_shape) { - // Start with a chunk of the source tensor - auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); - - // Update chunk with offset data pointers from the communication buffer - auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + chunk_offset * _ubuf.element_size(); - if (chunk.dptr() != nullptr) { - chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); - } - if (chunk.columnwise_dptr() != nullptr) { - chunk.set_columnwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), - chunk.columnwise_shape()); - } - return chunk; -} - -/*************************************************************************************************** - * Comm+GEMM Overlap Base (Pipelined / Collective) - **************************************************************************************************/ - -CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, - int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, int tp_size, ExtAllgatherOp allgather_handle, - ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm, - bool rs_overlap_first_gemm) - : CommOverlapCore(myrank, numranks, mylocal, 
numlocal, mynode, numnodes, tp_size, - allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, - gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, - atomic_gemm) { - initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); -} - -void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, - bool rs_overlap_first_gemm) { - _rs_overlap_first_gemm = rs_overlap_first_gemm; - _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); - NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, - "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", - "or 2 (multi-atomic)."); - - NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); - size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - void *buffer_ptr; - _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) { - printf("!!! [UB] Register UBuf %d\n", _ub_reg); - } - _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); - - NVTE_CHECK_CUDA( - cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority)); - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); -} - -CommOverlapBase::~CommOverlapBase() { - cudaEventDestroy(_start_d2dcopy); - cudaStreamSynchronize(_stream_comm); - cudaStreamDestroy(_stream_comm); -} - -/* -** Bulk GEMM + COMM -** This function assumes the communication input is pre-copied to _ubuf -*/ -void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, TensorWrapper &rs_output, - cudaStream_t stream_main) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Catch up the default torch stream - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); - - // Communication: AG and RS - int comm_elements = _ubuf.bytes() / 2; // UBUF uses 2Byte element size - if (comm_type == CommOverlapType::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); - } else { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - comm_elements *= 2; - assert(rs_output.numel() == _ubuf.numel() / _tp_size); - assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); - assert(rs_output.element_size() == 2); - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, - comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); - } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); - } - } - - assert(pre_gelu_out.numel() == 0); - // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch - if (_comm_launch_event) - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0)); - nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, - grad, 
workspace.data(), accumulate, use_split_accumulator, _math_sms, - _stream_compute[0]); - - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); - -} // CommOverlapBase::bulk_overlap - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, - const TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions - size_t m = transa ? A.size(0) : A.size(1); - size_t k = transa ? A.size(1) : A.size(0); - size_t n = _ubuf.size(0); - size_t m_chunk = m / _num_splits; - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - - // Reset atomic counters - int *counter_ptr = reinterpret_cast(_counter.dptr()); - reset_counters(counter_ptr, _num_splits, false, stream_main); - - // Catch up the default torch stream - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); - - assert(pre_gelu_out.numel() == 0); - - auto output_d = get_buffer_chunk_like(D, 0, {n, m}); - auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); - nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), - _stream_compute[0]); - - for (int i = 0; i < _num_splits; i++) { - if (_rs_kernel_type == 1) { - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, - &counter_ptr[i], _ub_comm, _stream_comm);); - } else { - reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _num_splits, &counter_ptr[i], _ub_comm, - _stream_comm); - } - } else if (_rs_kernel_type == 2) { - if (_ubuf.element_size() == 1) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, - counter_ptr, _ub_comm, _stream_comm);); - } else { - reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, - _num_splits, counter_ptr, _ub_comm, - _stream_comm); - } - break; - } else { - consumer(counter_ptr, i, _stream_comm); - if (_ubuf.element_size() == 1) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - 
reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), - _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, _stream_comm);); - } else { - reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, _stream_comm); - } - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0])); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); -} // split_overlap_rs - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { - // Get GEMM dimensions - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - size_t m = transa ? A.size(0) : A.size(1); - size_t k = transa ? A.size(1) : A.size(0); - size_t n = _ubuf.size(0); - size_t m_chunk = m / _num_splits; - const std::vector input_a_chunk_shape = - (transa ? std::vector{m_chunk, k} : std::vector{k, m_chunk}); - const std::vector output_chunk_shape = {n, m_chunk}; - size_t input_a_chunk_size = m_chunk * k; - size_t output_chunk_size = n * m_chunk; - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - - // Helper function to get bias chunk if needed - auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper { - if (bias.dptr() == nullptr) { - return TensorWrapper(); - } - return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk}); - }; - - // Catch up the default torch stream - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); - } - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); - - assert(pre_gelu_out.numel() == 0); - - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - if (_rs_overlap_first_gemm) { - auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape); - auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape); - auto bias_chunk = maybe_get_bias_chunk(0); - auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), - pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, _stream_compute[0]); - - for (int i = 1; i < _num_splits; i++) { - input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape); - output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape); - bias_chunk = maybe_get_bias_chunk(i); - workspace_chunk = get_tensor_chunk( - workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), - pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), - accumulate, use_split_accumulator, _math_sms, - _stream_compute[i % _stream_compute.size()]); - - NVTE_CHECK_CUDA( - 
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); - - // Communication chunk - if (_ubuf.element_size() == 1) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, _stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, _stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); - - // Last communication chunk with max SM - _ub_comm->sms = UB_MAX_SM; - if (_ubuf.element_size() == 1) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, - n, m, _ub_comm, _stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, _stream_comm); - } - } else { - for (int i = 0; i < _num_splits; i++) { - auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape); - auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape); - auto bias_chunk = maybe_get_bias_chunk(i); - auto workspace_chunk = get_tensor_chunk( - workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), - pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), - accumulate, use_split_accumulator, _math_sms, - _stream_compute[i % _stream_compute.size()]); - - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); - - // Communication chunk. 
Uses MAX_SM at the last chunk - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, - _ub_comm, _stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, _stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - } - - _ub_comm->sms = ori_sms; - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); -} // CommOverlapBase::split_overlap_rs - -void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, - cudaStream_t stream_main) { - int comm_bytes = _ubuf.bytes(); - int comm_bytes_per_rank = comm_bytes / _tp_size; - - // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush - userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, - _ub_comm, send_stream); - userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, - _ub_comm, recv_stream); - - // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf - for (auto stream : {send_stream, recv_stream}) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); - } - - // Next we sync with the main stream - // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); -} - -/*************************************************************************************************** - * Comm+GEMM Overlap P2P Base (Ring-Exchange) - **************************************************************************************************/ - -CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, - int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm, bool aggregate) - : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, - allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, - gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm) { - initialize(buffer_shape, buffer_dtype, comm_type, aggregate); -} - -void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, - CommOverlapType comm_type, bool aggregate) { - _is_p2p = true; - _is_reduce_scatter = comm_type == CommOverlapType::RS; - _aggregate = aggregate; - - // Create workspace tensor with userbuffer - NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer 
shape must be 2-dimensional!"); - size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / _tp_size; - _num_ubuf_chunks = _tp_size; - if (_is_reduce_scatter) { - // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk - // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); - _num_ubuf_chunks = _tp_size * 2 - 1; - } - - void *buffer_ptr; - _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg); - _ubuf = TensorWrapper( - buffer_ptr, - std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, - buffer_dtype); - - // Create tensor chunks for easy management - char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); - for (int i = 0; i < _num_ubuf_chunks; i++) { - _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, - buffer_dtype)); - ubuf_byte_ptr += buffer_chunk_bytes; - } - - _rank_round_tp = (_rank / _tp_size) * _tp_size; - _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; - _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; - - _self_chunk_id = _tp_id; - if (_atomic_gemm && !_is_reduce_scatter) { - _use_multiatomic_ag = getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (_use_multiatomic_ag) { - _use_ce = 0; - _ub_comm->push = 1; - if (_rank == 0) { - printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); - } - } - _self_chunk_id = 0; - NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); - } - - for (int i = 0; i < _stream_compute.size(); i++) { - cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); - _stream_send.push_back(std::move(stream)); - } - NVTE_CHECK_CUDA( - cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); -} - -CommOverlapP2PBase::~CommOverlapP2PBase() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) { - cudaStreamDestroy(_stream_send[i]); - } -} - -void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, - bool local_chunk, bool rowwise) { - // Check element size - const size_t element_size = source.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); - - // Input data - const size_t source_size = source.numel(); - const void *src_ptr = (rowwise) ? 
source.dptr() : source.columnwise_dptr(); - - // Userbuffers data - void *dst_ptr; - if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); - dst_ptr = _ubufs[_tp_id].dptr(); - } else { - NVTE_CHECK(_ubuf.numel() == source_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); - dst_ptr = _ubuf.dptr(); - } - - // Copy data - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, - cudaMemcpyDeviceToDevice, stream)); -} - -TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, - size_t chunk_id) { - // Start with a chunk of the source tensor - auto chunk = get_tensor_chunk(source, 0, shape_to_vector(_ubufs[chunk_id].shape())); - - // Update chunk with offset data pointers from the communication buffer - if (chunk.dptr() != nullptr) { - chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); - } - if (chunk.columnwise_dptr() != nullptr) { - chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); - } - return chunk; -} - -/* -** Split AllGather + AtomicGEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG -** outputs in each rank to be in the contiguous memory space after all ring exchange phases. -*/ -void CommOverlapP2PBase::atomic_gemm_overlap_ag( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get GEMM dimensions between TN and NN input layouts - const size_t m = (transa) ? A.size(0) : A.size(1); - const size_t n_chunk = _ubufs[0].size(0); - assert(pre_gelu_out.numel() == 0); - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].bytes(); - - // Create an GEMM output buffer with N+1 chunks in a contiguous memory - void *D_buffer_ptr; - int D_chunk_bytes = n_chunk * m * D.element_size(); - NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); - auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), - D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); - - // Reset atomic counters - int *counter_ptr = reinterpret_cast(_counter.dptr()); - reset_counters(counter_ptr, _tp_size, true, stream_main); - - // Catch up the default torch stream - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - - auto input_b = get_buffer_chunk_like(B, 0, shape_to_vector(B.shape())); - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); - - for (int i = 0; i < _tp_size - 1; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = i; - int recv_chunk_id = i + 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - if (_use_multiatomic_ag) { - if (i == 0) { - _ub_comm->use_ce = 0; - userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, - true, _stream_recv); - } - } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, - _stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, - _stream_recv); - producer(counter_ptr, recv_chunk_id, _stream_recv); - } - if (i == 0) { - nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(), - pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), - accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false, - _counter.data(), stream_main); - } - } - - // Store the input activation for backprop - if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); - assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), - _ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); - } - - // Copy the first GEMM output chunk to the end chunk position of D_buffer - char *src_ptr = reinterpret_cast(D_buffer.dptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes, - cudaMemcpyDeviceToDevice, stream_main)); - - // Return the last N rows of D_buffer - NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(), - cudaMemcpyDeviceToDevice, stream_main)); - - // Clean up buffer allocation - NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main)); - - _ub_comm->sms = ori_sms; -} // CommOverlapP2PBase::atomic_gemm_overlap_ag - -/* -** Split AllGather + GEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG -** outputs in each rank to be in the contiguous memory space after all ring exchange phases. -*/ -void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, - const TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts - const size_t m = (transa) ? A.size(0) : A.size(1); - const size_t k = (transa) ? 
A.size(1) : A.size(0); - const size_t n_chunk = _ubufs[0].size(0); - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].bytes(); - const bool do_gelu = pre_gelu_out.numel() > 0; - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - - // Check B copy sizing - if (B_copy.numel() > 0) { - NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", - _ubuf.numel(), " elements but got ", B_copy.numel()); - NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), - "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, - "-bit data type but got ", B_copy.element_size() * 8, "-bit"); - } - - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); - } - if (_aggregate) { - const int num_steps = _tp_size / 2; - - // Chunk dims - std::vector input_b_chunk_shape = - (transb ? std::vector{k, 2 * n_chunk} : std::vector{2 * n_chunk, k}); - std::vector output_chunk_shape = {2 * n_chunk, m}; - size_t input_b_chunk_size = 2 * n_chunk * k; - size_t output_chunk_size = 2 * n_chunk * m; - - // Initial 1X input chunk exchange between neighboring peers - int send_chunk_id = _tp_id; - int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - _stream_send[0]); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - _stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); - - int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1; - const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; - const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; - - // Ring exchange of 2X inputs chunks - for (int i = 0; i < num_steps; i++) { - send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; - recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; - send_offset = comm_bytes * send_chunk_id; - recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - auto input_b_chunk = - get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); - auto output_chunk = - get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape); - auto aux_chunk = (do_gelu) - ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2, - {n_chunk * 2, k}) - : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); - auto workspace_chunk = get_tensor_chunk( - workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); - - nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), - aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, - _stream_compute[i % _stream_compute.size()]); - - if (i < num_steps - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, _stream_send[0]); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, _stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); - NVTE_CHECK_CUDA( - cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } - } - } else { - // Chunk dims - std::vector input_b_chunk_shape = - (transb ? std::vector{k, n_chunk} : std::vector{n_chunk, k}); - std::vector output_chunk_shape = {n_chunk, m}; - size_t input_b_chunk_size = n_chunk * k; - size_t output_chunk_size = n_chunk * m; - - for (int i = 0; i < _tp_size; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; - int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - auto input_b_chunk = - get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); - auto output_chunk = - get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape); - auto aux_chunk = - (do_gelu) - ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) - : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); - auto workspace_chunk = get_tensor_chunk( - workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); - - nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), - aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, - _stream_compute[i % _stream_compute.size()]); - - if (i < _tp_size - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, _stream_send[0]); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, _stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); - NVTE_CHECK_CUDA( - cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } - } - } - - // Copy all-gathered B from communication buffer into auxiliary output - if (B_copy.numel() > 0) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), - cudaMemcpyDeviceToDevice, _stream_send[0])); - } - - _ub_comm->sms = ori_sms; - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); -} // CommOverlapP2PBase::split_overlap_ag - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2PBase::atomic_gemm_overlap_rs( - const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get communication and GEMM input chunk sizes - const int comm_bytes = _ubufs[0].bytes(); - - // Reset counters - int *counter_ptr = reinterpret_cast(_counter.dptr()); - reset_counters(counter_ptr, _tp_size, false, stream_main); - - // Catch up the main stream - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - - // Atomic GEMM - // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
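-  // Sketch of the resulting schedule for tp_size = 4 (derived from the loop below,
-  // for illustration only): the atomic GEMM writes output chunks 0..3 into the
-  // userbuffer; for i = 1..3 the consumer waits on chunk i-1's counter, sends chunk
-  // i-1 to rank (tp_id - i) mod 4, and receives the matching partial result into
-  // chunk (i-1) + 4 from rank (tp_id + i) mod 4; the final reduction then sums
-  // chunks 3..6 into rs_output.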
- auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape())); - nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, true, _counter.data(), stream_main); - - // P2P communication chunk - for (int i = 1; i < _tp_size; i++) { - int send_chunk_id = i - 1; - int recv_chunk_id = send_chunk_id + _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - - consumer(counter_ptr, send_chunk_id, _stream_recv); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - _stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, - _stream_recv); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, - _ubufs[0].numel(), stream_main);); - } else { - reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); - } - _ub_comm->sms = ori_sms; -} - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, - const TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get communication and GEMM input chunk sizes - size_t m = transa ? A.size(0) : A.size(1); - size_t k = transa ? 
A.size(1) : A.size(0); - size_t n_chunk = _ubufs[0].size(0); - const int comm_bytes = _ubufs[0].bytes(); - - // Get input and workspace data pointers - size_t input_chunk_size = n_chunk * k; - size_t output_chunk_size = n_chunk * m; - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - - // Catch up the main stream - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - for (size_t i = 0; i < _stream_send.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0)); - } - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); - } - - // GEMM and send/recv chunks - for (int i = 0; i < _tp_size; i++) { - // GEMM chunk - int stream_id = i % _stream_compute.size(); - int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - - auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); - auto output_chunk = get_buffer_chunk_by_id(D, i); - - auto workspace_chunk = - get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); - - nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), - pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, _stream_compute[stream_id]); - - if (i > 0) { - // P2P communication chunk - int prev_stream_id = (i - 1) % _stream_compute.size(); - int send_offset = comm_bytes * (i - 1); - int recv_offset = comm_bytes * (i - 1 + _tp_size); - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - _stream_send[prev_stream_id]); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, - _stream_recv); - } - } - - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, - _ubufs[0].numel(), stream_main);); - } else { - reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); - } - - _ub_comm->sms = ori_sms; -} - -} // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h 
b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index cffc411a0d..572f4c52cd 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -225,6 +225,9 @@ class CommOverlapBase : public CommOverlapCore { void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, cudaStream_t stream_main) override; + + // initialize ubnext buffer and return multicast pointer for allreduce + uintptr_t init_ubnext(); }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig deleted file mode 100644 index cffc411a0d..0000000000 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.orig +++ /dev/null @@ -1,327 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ -#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ - -#include -#include -#include - -#include - -#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" - -#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 - -namespace transformer_engine { - -/* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. - * This can turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option. - * - * \return True if Userbuffers is built with MPI - */ -bool ubuf_built_with_mpi(); - -enum class CommOverlapType { RS = 0, AG = 1 }; - -enum class CommOverlapAlgo { - BULK_OVERLAP_AG = 0, - BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG_P2P = 2, - SPLIT_PIPELINED_RS = 3, - SPLIT_PIPELINED_RS_P2P = 4, - ATOMIC_GEMM_RS = 5, - ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7, - EXTERNAL_BULK_OVERLAP_AG = 8, -}; - -class CommOverlapCore { - protected: - static inline communicator *_ub_comm{nullptr}; - static inline bool _comm_created{false}; - - int _rank; - int _tp_id; - int _tp_size; - int _num_splits; - int _math_sms; - int _num_comm_sm; - int _cga_size; - int _use_ce; - int _ub_reg; - int _gemm_priority; - int _comm_priority; - bool _atomic_gemm{false}; - bool _is_p2p{false}; - - TensorWrapper _ubuf; - TensorWrapper _counter; - float *_ubuf_scale_inv; - bool _ubuf_scale_inv_initialized{false}; - - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; - - private: - void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, - int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm); - - public: - CommOverlapCore() {} // dummy constructor for exposing type to Python - - CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, - int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, - int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm); - - virtual ~CommOverlapCore(); - - void *get_ubuf_dptr() { return _ubuf.dptr(); } - - void set_ubuf_scale_inv(float *scale_inv) { - _ubuf_scale_inv = 
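  // Rough usage sketch for the init_ubnext() hook added in the hunk above; the
  // allreduce entry point named here is a placeholder for illustration only, since
  // the actual kernels (ubnext.cu / ubnext.h) are outside this hunk:
  //
  //   CommOverlapBase ub(buffer_shape, DType::kBFloat16, myrank, numranks, mylocal,
  //                      numlocal, mynode, numnodes, tp_size, allgather, barrier);
  //   uintptr_t mc_ptr = ub.init_ubnext();  // multicast pointer to the ubnext buffer
  //   // ubnext_allreduce(mc_ptr, data_ptr, nbytes, stream);  // placeholder name
  //
  // The returned multicast pointer is what the ubnext allreduce kernels consume.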
scale_inv; - _ubuf_scale_inv_initialized = true; - } - - virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, - bool rowwise = true) { - NVTE_ERROR("Operation is not implemented."); - } - - TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, - const std::vector &shape); - - TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, - const std::vector &shape); - - int get_tp_size() { return _tp_size; } - - bool is_atomic_gemm() { return _atomic_gemm; } - - bool is_p2p_overlap() { return _is_p2p; } - - bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } - - virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, - TensorWrapper &rs_output, cudaStream_t stream_main) { - NVTE_ERROR("Operation is not implemented."); - } - - virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { - NVTE_ERROR("Operation is not implemented."); - } - - virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { - NVTE_ERROR("Operation is not implemented."); - } - - virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { - NVTE_ERROR("Operation is not implemented."); - } - - virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) { - NVTE_ERROR("Operation is not implemented."); - } - - virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, - cudaStream_t stream_main) { - NVTE_ERROR("Operation is not implemented."); - } -}; // CommOverlapCore - -class CommOverlapBase : public CommOverlapCore { - protected: - int _rs_kernel_type; - bool _rs_overlap_first_gemm; - cudaStream_t _stream_comm; - cudaEvent_t _start_d2dcopy; - - private: - void initialize(const std::vector &buffer_shape, DType buffer_dtype, - bool rs_overlap_first_gemm); - - public: - CommOverlapBase() {} // dummy constructor for exposing type to Python - - CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, - int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, - bool set_sm_margin = true, bool 
atomic_gemm = false, - bool rs_overlap_first_gemm = false); - - virtual ~CommOverlapBase(); - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, - cudaStream_t stream_main) override; - - void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) override { - NVTE_ERROR("Operation not supported."); - } - - void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) override { - NVTE_ERROR("Operation not supported."); - } - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) override; - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) override; - - void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, - cudaStream_t stream_main) override; -}; // CommOverlapBase - -class CommOverlapP2PBase : public CommOverlapCore { - protected: - bool _is_reduce_scatter{false}; - bool _use_multiatomic_ag{false}; - bool _aggregate; - int _next_rank; - int _prev_rank; - int _rank_round_tp; - int _num_ubuf_chunks; - int _self_chunk_id; - std::vector _ubufs; - std::vector _stream_send; - cudaStream_t _stream_recv; - cudaEvent_t _stop_send, _stop_recv; - - private: - void initialize(const std::vector &buffer_shape, DType buffer_dtype, - CommOverlapType comm_type, bool aggregate); - - public: - CommOverlapP2PBase() {} // dummy constructor for exposing type to Python - - CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, - int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, - ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, - int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, - int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, - bool atomic_gemm = false, bool aggregate = false); - - virtual ~CommOverlapP2PBase(); - - void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, - bool rowwise = true) override; - - TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); - - void bulk_overlap(const 
TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, - cudaStream_t stream_main) override { - NVTE_ERROR("Operation not supported."); - } - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. - */ - void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) override; - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. - */ - void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) override; - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) override; - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) override; - - /* - ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. - ** The gemm for overlap_gemm is assumed to have been previously started. 
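The `external` bulk-overlap path described above reuses an all-gather that another communicator has already launched, so on the Python side a wgrad GEMM is simply configured to piggyback on its paired dgrad GEMM. A minimal sketch of such a pairing, assuming the layer names and method strings that appear in the userbuffer configuration code quoted further down in this patch:

```
# Illustrative ub_cfgs fragment (a sketch; "proj_dgrad"/"proj_wgrad" and the
# "ring_exchange"/"external" method names follow the defaults shown later in base.py).
ub_cfgs = {
    "proj_dgrad": {"method": "ring_exchange", "num_sm": 1},  # owns the all-gather
    "proj_wgrad": {"method": "external"},                    # reuses proj_dgrad's all-gather
}
```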
- */ - void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, - cudaStream_t stream_main) override { - NVTE_ERROR("Operation not supported."); - } -}; // CommOverlapP2PBase - -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej deleted file mode 100644 index f229f5eea6..0000000000 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h.rej +++ /dev/null @@ -1,11 +0,0 @@ ---- transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h -+++ transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h -@@ -198,6 +198,8 @@ class CommOverlapBase : public CommOverlapCore { - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main) override; -+ // initialize ubnext buffer and return multicast pointer for allreduce -+ uintptr_t init_ubnext(); - }; // CommOverlapBase - - class CommOverlapP2PBase : public CommOverlapCore { diff --git a/transformer_engine/common/util/pybind_helper.h.orig b/transformer_engine/common/util/pybind_helper.h.orig deleted file mode 100644 index bce124e705..0000000000 --- a/transformer_engine/common/util/pybind_helper.h.orig +++ /dev/null @@ -1,140 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ -#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ - -#include -#include -#include -#include - -#include "cuda_runtime.h" - -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType", pybind11::module_local()) \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ - .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) \ - .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \ - 
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \ - .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \ - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \ - .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ - .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ - .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ - pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ - .value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \ - .value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \ - .value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \ - .value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \ - .value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \ - .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ - .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_( \ - m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ - .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ - .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ - pybind11::enum_(m, "CommOverlapType", \ - pybind11::module_local()) \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo", \ - pybind11::module_local()) \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - 
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ - .value("EXTERNAL_BULK_OVERLAP_AG", \ - transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ - py::call_guard()) \ - .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ - py::call_guard()) \ - .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ - py::call_guard()) \ - .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ - py::call_guard()); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def( \ - "get_stream_priority_range", \ - [](int device_id = -1) { \ - int low_pri, high_pri; \ - transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ - return std::make_pair(low_pri, high_pri); \ - }, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); - -#endif diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp.orig b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp.orig deleted file mode 100644 index 38947c5a9d..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp.orig +++ /dev/null @@ -1,320 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
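The bindings above also expose a few capability probes that the Python layer queries before setting up comm+GEMM overlap. A short sketch of how they are called from Python, using the `transformer_engine_torch` module name imported elsewhere in this patch:

```
# Capability probes exposed by the pybind helper above.
import transformer_engine_torch as tex

if not tex.ubuf_built_with_mpi():                 # bootstrap goes through torch.distributed
    print("multicast supported:", tex.device_supports_multicast())
low_pri, high_pri = tex.get_stream_priority_range()  # stream priorities for GEMM vs. comm
```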
- ************************************************************************/ - -#include "../extensions.h" -#include "transformer_engine/transformer_engine.h" - -#define HALF_BYTES 2 -#define UB_MAX_SM 32 - -using namespace torch::indexing; -using namespace std::placeholders; - -namespace te = transformer_engine; - -/*************************************************************************************************** - * CommOverlapHelper - **************************************************************************************************/ - -CommOverlapHelper::CommOverlapHelper() { -#ifndef NVTE_UB_WITH_MPI - NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!"); -#endif -} // empty constructor for NVTE_UB_WITH_MPI=1 - -CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_domain_group) { -#ifndef NVTE_UB_WITH_MPI - pgs.insert({"world", world_group}); - myrank = pgs["world"]->getRank(); - numranks = pgs["world"]->getSize(); - c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); - backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - - if (intra_domain_group.has_value()) { - // Get local rank on node and number of local ranks - NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", pgs["world"]->getBackendName()); - pgs.insert({"intra", intra_domain_group.value()}); - mylocal = pgs["intra"]->getRank(); - numlocal = pgs["intra"]->getSize(); - - if (numlocal == numranks) { - // Intra-node group is same as the world group so there can only be 1 node - NVTE_CHECK( - mylocal == myrank, - "Internal TE error: Local rank must be equal to global rank when intra-node group size ", - "is equal to the world group size!"); - mynode = 0; - numnodes = 1; - } else { - // Get node ID and number of nodes - mynode = myrank / numlocal; - numnodes = numranks / numlocal; - } - } else { - // Intra-node group is not set so we assume there is only 1 node - mylocal = myrank; - numlocal = numranks; - pgs.insert({"intra", world_group}); - - mynode = 0; - numnodes = 1; - } - - initialized = true; -#else - NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); -#endif -} - -CommOverlapHelper::~CommOverlapHelper() { -#ifndef NVTE_UB_WITH_MPI - for (auto &pg : pgs) pg.second = nullptr; - backend_is_nccl = false; - initialized = false; -#endif -} - -void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, - size_t localbytes, ExtComm group) { -#ifndef NVTE_UB_WITH_MPI - NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", - "with valid process groups!"); - - auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; - auto globaltensor = - torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto globaltmp = (backend_is_nccl) ? 
globaltensor.cuda() : globaltensor; - - std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; - std::vector localchunk = {localtmp}; - auto work = pgs[group]->allgather(globalchunks, localchunk); - work->wait(); - - if (backend_is_nccl) { - globaltensor.copy_(globaltmp.cpu()); - globaltmp = torch::Tensor(); - localtmp = torch::Tensor(); - } -#else - NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", - "with NVTE_UB_WITH_MPI=1!"); -#endif -} - -void CommOverlapHelper::ub_barrier(ExtComm group) { -#ifndef NVTE_UB_WITH_MPI - NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", - "with valid process groups!"); - auto work = pgs[group]->barrier(); - work->wait(); -#else - NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", - "with NVTE_UB_WITH_MPI=1!"); -#endif -} - -/*************************************************************************************************** - * CommOverlap - **************************************************************************************************/ - -CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int gemm_priority, - int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, - bool rs_overlap_first_gemm) - : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), - helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, - helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, - set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} - -/* -** Helper function to copy input to _ubuf -*/ -void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) { - const auto &input_ = input.contiguous(); - - // Check element size - const size_t element_size = input.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); - - // Input data - const size_t input_size = input_.numel(); - const void *src_ptr = input_.data_ptr(); - - // Userbuffers data - const size_t ubuf_size = _ubuf.numel(); - void *dst_ptr = _ubuf.dptr(); - if (local_chunk) { - NVTE_CHECK(input_size * _tp_size == ubuf_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(input_size=", input_size, ", tensor_parallel_size=", _tp_size, - ", ubuf_size=", ubuf_size, ")"); - dst_ptr = (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); - } else { - NVTE_CHECK(input_size == ubuf_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")"); - } - - // Copy data - auto stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); -} - -at::Tensor 
CommOverlap::get_buffer(bool local_chunk, std::optional> shape) { - // Check buffer shape - const size_t ubuf_size = _ubuf.numel(); - if (shape) { - const size_t requested_size = transformer_engine::pytorch::product(*shape); - if (local_chunk) { - NVTE_CHECK(requested_size * _tp_size == ubuf_size, - "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, - ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")"); - } else { - NVTE_CHECK(requested_size == ubuf_size, - "Invalid shape for a Userbuffers buffer (requested shape=", *shape, - ", ubuf_size=", ubuf_size, ")"); - } - } else { - int64_t dim0 = _ubuf.size(0); - int64_t dim1 = _ubuf.size(1); - if (local_chunk) { - dim0 /= _tp_size; - } - shape = {dim0, dim1}; - } - - // Data pointer - void *ubuf_ptr = _ubuf.dptr(); - if (local_chunk) { - ubuf_ptr = (reinterpret_cast(ubuf_ptr) + - (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size()); - } - - // Construct PyTorch tensor - const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); - return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); -} - -std::pair CommOverlap::get_communication_stream() { - // Return the same stream for both send and recv - return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()), - at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())}; -} - -/*************************************************************************************************** - * CommOverlapP2P - **************************************************************************************************/ - -CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) - : te::CommOverlapP2PBase( - buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, - tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm, aggregate) {} - -/* -** Copy input to _ubufs[0] -*/ -void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) { - const auto &input_ = input.contiguous(); - - // Check element size - const size_t element_size = input.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); - - // Input data - const size_t input_size = input_.numel(); - const void *src_ptr = input_.data_ptr(); - - // Userbuffers data - void *dst_ptr; - if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); - dst_ptr = _ubufs[_tp_id].dptr(); - } else { - NVTE_CHECK(_ubuf.numel() == input_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); - dst_ptr = 
_ubuf.dptr(); - } - - // Copy data - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, - cudaMemcpyDeviceToDevice, - (cudaStream_t)at::cuda::getCurrentCUDAStream())); -} - -at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional> shape) { - // Check buffer shape - if (shape) { - const size_t requested_size = transformer_engine::pytorch::product(*shape); - if (local_chunk) { - NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(), - "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, - ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); - } else { - NVTE_CHECK(requested_size == _ubuf.numel(), - "Invalid shape for a Userbuffers buffer (requested shape=", *shape, - ", ubuf_size=", _ubuf.numel(), ")"); - } - } else { - int64_t dim0 = _ubuf.size(0); - int64_t dim1 = _ubuf.size(1); - if (local_chunk) { - dim0 /= _tp_size; - } - shape = {dim0, dim1}; - } - - // Data pointer - void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr(); - - // Construct PyTorch tensor - const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); - return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); -} - -std::pair CommOverlapP2P::get_communication_stream() { - return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()), - at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; -} - -void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( - CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { - auto main_stream = at::cuda::getCurrentCUDAStream(); - allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), - at::cuda::CUDAStream(recv_stream), main_stream); -} diff --git a/transformer_engine/pytorch/module/base.py.orig b/transformer_engine/pytorch/module/base.py.orig deleted file mode 100644 index bf4fb97d2d..0000000000 --- a/transformer_engine/pytorch/module/base.py.orig +++ /dev/null @@ -1,1597 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
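The `copy_into_buffer`/`get_buffer` pair above is how the Python modules stage an operand in the Userbuffers workspace and then view the gathered buffer as a regular torch tensor. A minimal sketch of that pattern, assuming `comm` is one of the communicator objects created during userbuffer initialization:

```
# Sketch of the staging pattern implemented above; `comm` is a placeholder for a
# tex.CommOverlap / tex.CommOverlapP2P object.
import torch

def stage_for_allgather(comm, x_local: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Copy this rank's shard into the UB buffer and return a view of the gathered tensor."""
    comm.copy_into_buffer(x_local, local_chunk=True)   # async D2D copy into this rank's chunk
    # The returned tensor aliases the communication buffer, so use it only around the overlap.
    return comm.get_buffer(shape=[tp_size * x_local.shape[0], x_local.shape[1]])
```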
- -"""Base modules and utilities for TransformerEngine PyTorch API""" -import io -import math -import os -import pickle -import warnings -from enum import Enum -from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union -from contextlib import contextmanager -import logging -from types import MethodType - -import torch -import torch.nn.functional as F - -import transformer_engine_torch as tex -from transformer_engine.common.recipe import Recipe - -from ._common import _ParameterInitMeta, noop_cat -from ..fp8 import ( - MXFP8BlockScalingRecipeState, - DelayedScalingRecipeState, - Float8CurrentScalingRecipeState, - Float8BlockScalingRecipeState, - NVFP4BlockScalingRecipeState, - FP8GlobalStateManager, - RecipeState, -) -from ..distributed import ( - gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, - _fsdp_gather_tensors, -) -from ..constants import dist_group_type -from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer -from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ..tensor.nvfp4_tensor import NVFP4Quantizer -from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from ...common.recipe import DelayedScaling, Recipe -from ...debug.pytorch.debug_state import TEDebugState -from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor -from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled - -__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] - -_2X_ACC_FPROP = False -_2X_ACC_DGRAD = True -_2X_ACC_WGRAD = True -_multi_stream_cublas_workspace = [] -_dummy_wgrads = {} -_cublas_workspace = None -_ub_communicators = None -_NUM_MAX_UB_STREAMS = 3 -_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None -layers_atomic_ring_exchange = [] - - -class UserBufferQuantizationMode(Enum): - """ - UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. 
- """ - - NONE = "none" - FP8 = "fp8" - - -def get_cublas_workspace_size_bytes() -> None: - """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - # 32 MiB for NVFP4 GEMM, plus 256 B for misc scales - return 32 * 1024 * 1024 + 256 - return 4_194_304 - - -def get_workspace() -> torch.Tensor: - """Returns workspace for cublas.""" - global _cublas_workspace - if _cublas_workspace is None: - _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" - ) - return _cublas_workspace - - -def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: - """Returns workspace for multi-stream cublas.""" - global _multi_stream_cublas_workspace - if not _multi_stream_cublas_workspace: - for _ in range(tex.get_num_cublas_streams()): - _multi_stream_cublas_workspace.append( - torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") - ) - return _multi_stream_cublas_workspace - - -def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: - """Returns a dummy tensor of given shape.""" - assert len(shape) == 2 - global _dummy_wgrads - if (shape[0], shape[1], dtype) not in _dummy_wgrads: - _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( - shape, - dtype=dtype, - device="cuda", - requires_grad=False, - ) - if zero: - _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) - return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() - - -def initialize_ub( - shape: list, - tp_size: int, - use_fp8: bool = False, - quantization_modes: List[UserBufferQuantizationMode] = None, - dtype: torch.dtype = torch.bfloat16, - ub_cfgs: Optional[Union[dict, List[dict]]] = None, - bootstrap_backend: Union[str, torch.distributed.Backend] = None, -) -> None: - r""" - Initialize the Userbuffers communicator for overlapping tensor-parallel communications with - GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. - - Parameters - ---------- - shape : list - shape of the communication buffer, typically set to be the same as the global shape of - the input tensor to a te.TransformerLayer forward pass, with the sequence and batch - dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` - tp_size : int - number of GPUs in the tensor-parallel process group - use_fp8 : bool = False - allocate the communication buffer for FP8 GEMM inputs/outputs. - DEPRECATED: Please use `quantization_modes` instead. - quantization_modes : List[UserBufferQuantizationMode] = None - if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. - falls back to the legacy `use_fp8` parameter if `None` is provided. - dtype : torch.dtype = torch.bfloat16 - non-FP8 data type of the communication buffer when `use_fp8 = False` - ub_cfgs: dict = None - Configuration dictionary with the structure - ``` - { - : { - "method": <"ring_exchange" or "pipeline">, - "is_reduce_scatter": bool, - "num_sm": int, - "cga_size": int, - "set_sm_margin": bool, - "num_splits": int, - "aggregate": bool, - "atomic_gemm": bool, - "use_ce": bool, - "fp8_buf": bool, - } - } - ``` - for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", - "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", - "fc2_fprop", "fc2_wgrad"]`. 
- a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` - bootstrap_backend : str = None - `torch.distributed` communication backend for the all-gather, broadcast and - barrier collectives during Userbuffers initialization. Not all backends are - valid for every cluster configuration and distributed launch method even if - they are available in PyTorch. When left unset, the initialization prefers - to use the MPI backend, falling back first on Gloo and then NCCL if MPI is - not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this - option and always initializes Userbuffers with direct MPI calls in C++, - which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. - """ - if not tex.device_supports_multicast(): - assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( - "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " - + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." - ) - - if not quantization_modes: - warnings.warn( - "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes" - " instead.", - DeprecationWarning, - ) - quantization_modes = [ - UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE - ] - else: - assert isinstance(quantization_modes, list), "quantization_modes must be a list" - assert all( - isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes - ), "quantization_modes must be a list of UserBufferQuantizationMode" - - if isinstance(ub_cfgs, dict) or ub_cfgs is None: - ub_cfgs = [ub_cfgs] * len(quantization_modes) - else: - assert len(ub_cfgs) == len( - quantization_modes - ), "Number of ub_cfgs settings must match number of quantization configurations" - - global _ub_communicators - assert _ub_communicators is None, "UB communicators are already initialized." - _ub_communicators = {} - - if tex.ubuf_built_with_mpi(): - # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force - # an MPI_Init() here by creating a new MPI process group... - assert torch.distributed.is_mpi_available() - _ = torch.distributed.new_group(backend="mpi") - helper = tex.CommOverlapHelper() - else: - # Bootstrapping with torch.distributed API, so check backend and construct - # intra/inter-node process groups... - assert ( - torch.distributed.is_initialized() - ), "torch.distributed must be initialized before Userbuffers" - if bootstrap_backend is None: - bootstrap_backend = "nccl" - if torch.distributed.is_mpi_available(): - bootstrap_backend = "mpi" - elif torch.distributed.is_gloo_available(): - bootstrap_backend = "gloo" - else: - assert bootstrap_backend in [ - "gloo", - "mpi", - "nccl", - ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" - assert torch.distributed.is_backend_available(bootstrap_backend), ( - f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " - f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." 
- ) - - world_group = torch.distributed.new_group(backend=bootstrap_backend) - world_rank = torch.distributed.get_rank(world_group) - world_size = torch.distributed.get_world_size(world_group) - - num_domains = world_size // tp_size - mydomain_idx = world_rank // tp_size - if num_domains > 1: - ranks_per_domain_list = [ - [i * tp_size + t for t in range(tp_size)] for i in range(num_domains) - ] - tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( - ranks_per_domain_list, backend=bootstrap_backend - ) - local_rank = torch.distributed.get_rank(tp_domain_group) - tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group) - - helper = tex.CommOverlapHelper(world_group, tp_domain_group) - else: - # TP model on single NVLink domain, no replication, no data-parallelism - mydomain_idx = 0 - local_rank = world_rank - tp_domain_ranks = list(range(world_size)) - - helper = tex.CommOverlapHelper(world_group) - - if world_rank == 0: - print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True) - if local_rank == 0: - print( - f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n", - end="", - flush=True, - ) - - # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls - global _cublas_workspace - if _cublas_workspace is None: - _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) - elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS: - # This ensures we don't do `.repeat()` on an already expanded workspace - _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" - ).repeat(_NUM_MAX_UB_STREAMS) - - # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe - layers_all_gather_overlap = [ - "qkv_fprop", - "qkv_dgrad", - "proj_dgrad", - "proj_wgrad", - "fc1_fprop", - "fc1_dgrad", - "fc2_dgrad", - "fc2_wgrad", - ] - layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] - dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] - # Default overlap methods for layers - methods = { - "ring_exchange": [ - "qkv_fprop", - "fc1_fprop", - "proj_dgrad", - "fc2_dgrad", - ], - "pipeline": ["proj_fprop", "fc2_fprop"], - "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], - "external": ["proj_wgrad", "fc2_wgrad"], - } - - # AG-RS overlap pairs of layers forming a tensor-parallel block - ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} - rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} - external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} - global layers_atomic_ring_exchange - layers_atomic_ring_exchange = [] - - def get_method(name): - for method, names in methods.items(): - if name in names: - return method - raise KeyError(f"Given layer name {name} does not exist.") - - def get_default_config(name): - global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY - method = get_method(name) - is_reduce_scatter = name in layers_reduce_scatter_overlap - if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: - _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() - default_cfg = { - "method": method, - "is_reduce_scatter": is_reduce_scatter, - "num_sm": 1 if method == "ring_exchange" else 16, - "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": not method == "ring_exchange", - "num_splits": tp_size if method == "ring_exchange" else 4, - "aggregate": False, - "atomic_gemm": 
False, - "use_ce": True, - "fp8_buf": name in layers_all_gather_overlap, - "comm_priority": _MAX_STREAM_PRIORITY, - "gemm_priority": _MIN_STREAM_PRIORITY, - "pipeline_rs_overlap_first_gemm": False, - } - return default_cfg - - def add_ub( - name: str, - quantization_mode: UserBufferQuantizationMode, - method: str, - is_reduce_scatter: bool, - num_sm: int = 16, - cga_size: int = 2, - set_sm_margin: bool = False, - num_splits: int = 0, - aggregate: bool = False, - atomic_gemm: bool = False, - use_ce: bool = True, - fp8_buf: bool = False, - comm_priority: int = 0, - gemm_priority: int = 0, - pipeline_rs_overlap_first_gemm: bool = False, - ) -> None: - if atomic_gemm: - warnings.warn( - "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." - ) - assert ( - quantization_mode == UserBufferQuantizationMode.FP8 - ), "Atomic GEMM overlap supported only for FP8 GEMM." - if method in ("bulk", "external"): - warnings.warn( - f"At {name}, atoimic GEMM not is supported for a bulk overlap." - "Defaulting to `atomic_gemm=False`." - ) - atomic_gemm = 0 - if not is_reduce_scatter and method == "pipeline": - raise ValueError( - f"At {name}, `pipeline` overlap method is not supported for AllGather." - ) - # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`. - # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality. - global layers_atomic_ring_exchange - if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs: - layers_atomic_ring_exchange += [name, ag_rs_pairs[name]] - if name in rs_ag_pairs: - assert_message = ( - f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk " - "outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and " - "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses " - "`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config " - "for functionality." 
- ) - if name in layers_atomic_ring_exchange: - assert atomic_gemm and method == "ring_exchange", assert_message - else: - if atomic_gemm and method == "ring_exchange": - assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message - - if name in external_gemm_to_overlap: - assert method == "external", ( - f"At {name}, `external` overlap method is specified, but the selected method is" - f" {method}" - ) - assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( - f"At {name}, `external` overlap method is specified, but the external gemm" - f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" - ) - - buffer_dtype = ( - torch.uint8 - if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) - else dtype - ) - if method == "ring_exchange": - ub_obj = tex.CommOverlapP2P( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type - helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, - num_max_streams=_NUM_MAX_UB_STREAMS, - comm_cga_size=cga_size, - num_comm_sm=num_sm, - set_sm_margin=set_sm_margin, - atomic_gemm=atomic_gemm, - use_ce=use_ce, - aggregate=aggregate, - gemm_priority=gemm_priority, - comm_priority=comm_priority, - ) - else: - ub_obj = tex.CommOverlap( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type - helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - num_splits=num_splits, - num_max_streams=_NUM_MAX_UB_STREAMS, - comm_cga_size=cga_size, - num_comm_sm=num_sm, - set_sm_margin=set_sm_margin, - atomic_gemm=atomic_gemm, - gemm_priority=gemm_priority, - comm_priority=comm_priority, - rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, - ) - _ub_communicators[(name, quantization_mode)] = ub_obj - - for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): - if user_ub_cfg is not None: - for name in dgrad_reduce_scatter_overlap: - if ( - name in user_ub_cfg - and "method" in user_ub_cfg[name] - and user_ub_cfg[name]["method"] != "bulk" - ): - wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in user_ub_cfg - layers_reduce_scatter_overlap.remove(wgrad_name) - layers_all_gather_overlap.remove(name) - layers_reduce_scatter_overlap.append(name) - methods["bulk"].remove(name) - new_method = user_ub_cfg[name]["method"] - methods[new_method].append(name) - - for name in ( - methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] - ): - ub_cfg = get_default_config(name) - if user_ub_cfg is not None and name in user_ub_cfg: - fp8_buf = (name in layers_all_gather_overlap) or ( - user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] - ) - ub_cfg.update(user_ub_cfg[name]) - ub_cfg["fp8_buf"] = fp8_buf - add_ub(name, quantization_mode, **ub_cfg) - - -def get_ub(name: str, use_fp8: bool): - """Get userbuffer communicator corresponding to give key.""" - # For now use `use_fp8` boolean input as it matches the current design in the modules - # So favour simplicity until the correct design becomes clear. 
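Taken together, `initialize_ub` and `get_ub` above are the entry points: one communicator is built per (layer name, quantization mode) pair and later looked up by the modules. A usage sketch, with placeholder sizes and an import path that mirrors the module in this hunk (it must run under a distributed launcher such as torchrun):

```
# Usage sketch for the API defined above (placeholder sizes; paths/values are assumptions).
import torch
import torch.distributed as dist
from transformer_engine.pytorch.module.base import (
    initialize_ub, get_ub, destroy_ub, UserBufferQuantizationMode,
)

dist.init_process_group(backend="nccl")          # e.g. launched via torchrun
seq_len, batch_size, hidden_size = 1024, 2, 8192
tp_size = torch.cuda.device_count()              # single-node TP domain in this sketch

initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=tp_size,
    quantization_modes=[UserBufferQuantizationMode.NONE],
    dtype=torch.bfloat16,
)
proj_fprop_ub = get_ub("proj_fprop", use_fp8=False)  # keyed by (name, quantization mode)
destroy_ub()
```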
- # This is mainly an internal API so we don't need to worry about future changes - key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) - assert _ub_communicators is not None, "UB manager is not initialized." - assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." - return _ub_communicators[key] - - -def destroy_ub(): - """Destroy all allocated userbuffer communicators.""" - global _ub_communicators - _ub_communicators = None - global layers_atomic_ring_exchange - layers_atomic_ring_exchange = [] - - -def fill_userbuffers_buffer_for_all_gather( - comm, - local_tensor: torch.Tensor, - quantizer: Optional[Quantizer], - process_group, -) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]: - """Fill local shard of Userbuffers buffer with data for all-gather - - Returns the full tensor and the local shard, both using the - Userbuffers buffer as their underlying data. These tensors should - be used carefully (e.g. only immediately before and after a - Userbuffers operation) since the underlying data may be - overwritten by other Userbuffers operations. - - May perform blocking communication if needed for the gathered - tensor's metadata, e.g. scaling factors. - - """ - - # Tensor dimensions - local_shape = local_tensor.size() - if not local_shape: - raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})") - process_group_size = torch.distributed.get_world_size(process_group) - global_shape = list(local_shape) - global_shape[0] *= process_group_size - - # Unquantized data - if quantizer is None: - if isinstance(local_tensor, QuantizedTensorBase): - local_tensor = local_tensor.dequantize() - if comm.is_fp8_ubuf(): - raise RuntimeError( - "Attempting to all-gather unquantized tensor, " - "but Userbuffers is initialized with FP8 buffers" - ) - comm.copy_into_buffer(local_tensor, local_chunk=True) - global_tensor = comm.get_buffer(shape=global_shape) - return global_tensor, local_tensor - - # FP8 data - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): - if not isinstance(local_tensor, Float8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): - local_tensor.dequantize() - quantizer.set_usage(rowwise=True, columnwise=False) - local_tensor = quantizer(local_tensor) - if not comm.is_fp8_ubuf(): - raise RuntimeError( - "Attempting to all-gather FP8 tensor, " - "but Userbuffers is not initialized with FP8 buffers" - ) - comm.copy_into_buffer(local_tensor._data, local_chunk=True) - global_tensor_data = comm.get_buffer(shape=global_shape) - global_tensor = Float8TensorBase( - data=global_tensor_data, - fp8_scale_inv=local_tensor._scale_inv, - fp8_dtype=local_tensor._fp8_dtype, - quantizer=quantizer, - ) - return global_tensor, local_tensor - - # MXFP8 data - if isinstance(quantizer, MXFP8Quantizer): - - # Cast to MXFP8 if needed - if not isinstance(local_tensor, MXFP8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): - local_tensor.dequantize() - local_tensor = quantizer(local_tensor) - if not comm.is_fp8_ubuf(): - raise RuntimeError( - "Attempting to all-gather MXFP8 tensor, " - "but Userbuffers is not initialized with FP8 buffers" - ) - - # Check which MXFP8 buffer to communicate - if quantizer.rowwise_usage == quantizer.columnwise_usage: - raise ValueError( - "Userbuffers can only communicate one MXFP8 buffer at a time, " - f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, " - 
f"columnwise_usage={quantizer.columnwise_usage}" - ) - with_rowwise_data = quantizer.rowwise_usage - - # Copy MXFP8 data to local chunk of Userbuffers buffer - local_data = ( - local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data - ) - comm.copy_into_buffer(local_data, local_chunk=True) - - # Gather scaling-inverses - if math.prod(local_shape[:-1]) % 128 != 0: - raise ValueError( - "Userbuffers requires MXFP8 tensor dims that are divisible by 128, " - f"but got MXFP8 tensor with shape={tuple(local_shape)}" - ) - local_scale_inv = ( - local_tensor._rowwise_scale_inv - if with_rowwise_data - else local_tensor._columnwise_scale_inv - ) - local_scale_inv_size = list(local_scale_inv.size()) - global_scale_inv = torch.empty( - [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:], - dtype=local_scale_inv.dtype, - device=local_scale_inv.device, - ) - torch.distributed.all_gather_into_tensor( - global_scale_inv, - local_scale_inv, - group=process_group, - ) - - # Construct MXFP8 tensor with Userbuffers buffer - rowwise_data, rowwise_scale_inv = None, None - columnwise_data, columnwise_scale_inv = None, None - global_data = comm.get_buffer(shape=global_shape) - if with_rowwise_data: - rowwise_data, rowwise_scale_inv = global_data, global_scale_inv - else: - columnwise_data, columnwise_scale_inv = global_data, global_scale_inv - global_tensor = MXFP8TensorBase( - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=local_tensor._fp8_dtype, - quantizer=quantizer, - ) - return global_tensor, local_tensor - - # Unsupported data format - raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})") - - -class TransformerEngineBaseModule(torch.nn.Module, ABC): - """Base TE module.""" - - def __init__(self) -> None: - super().__init__() - assert torch.cuda.is_available(), "TransformerEngine needs CUDA." - self.name = None - self.next_iter_when_debug_should_be_run = 0 - self.fp8_initialized = False - self.fp8 = False - self.fp8_calibration = False - self.fp8_meta = {} - self.fp8_meta["fp8_checkpoint"] = False - self.fp8_meta["fp8_group"] = None - self.fp8_meta_tensors_initialized = False - self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} - self.tp_group = None - self.tp_size = 1 - self.sequence_parallel = False - self.param_init_meta = {} - self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() - self.fsdp_wrapped = False - self.fsdp_group = None - self._fp8_workspaces: Dict[str, QuantizedTensor] = {} - self.activation_dtype: Optional[torch.dtype] = None - self.wgrad_accumulation_and_reduce_hooks = [] - - if not TEDebugState.debug_enabled: - TEDebugState.initialize() - - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } - - def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. 
- self.__dict__[name] = value - else: - # Default case - super().__setattr__(name, value) - - def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: - """ - Delayed scaling only. - - Increase or decrease size of amax history based on given `length`. - - .. warning:: - This changes the underlying amax memory location. - """ - if fwd is None: - fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") - else: - fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) - - for meta_key in fp8_meta_tensor_keys: - if meta_key not in self.fp8_meta: - # Handles non-parameter FP8 modules, e.g. DPA. - continue - curr_len = self.fp8_meta[meta_key].amax_history.shape[0] - if length == curr_len: - continue - if length < curr_len: - self.fp8_meta[meta_key].amax_history = ( - self.fp8_meta[meta_key].amax_history[:length].clone() - ) - elif length > curr_len: - extra_rows = length - curr_len - self.fp8_meta[meta_key].amax_history = F.pad( - self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) - ) - - # Update quantizers with new amax pointers. - self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() - # Make sure weight tensors has correct quantizers - self._update_weight_quantizers() - - # Update the global buffers with new amax and history pointers. - if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: - fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[ - FP8GlobalStateManager.get_buffer_info() - ] - for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - if buffer_key in FP8GlobalStateManager.global_amax_buffer: - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ - meta_key - ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( - self.fp8_meta[meta_key].amax_history - ) - - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: - """Init scales and amaxes for fwd | bwd.""" - fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" - - # Return early if recipe state matches recipe - if self.fp8_meta_tensors_initialized: - recipe_state = self.fp8_meta[fp8_meta_tensor_key] - if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): - self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) - return - if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): - return - if recipe.float8_current_scaling() and isinstance( - recipe_state, Float8CurrentScalingRecipeState - ): - return - if recipe.float8_block_scaling() and isinstance( - recipe_state, Float8BlockScalingRecipeState - ): - return - if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): - return - - # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and - # 2 (grad_output and grad_input) for bwd - num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 - - # Initialize recipe state and quantizers - recipe_state = RecipeState.create( - recipe, - mode=("forward" if fwd else "backward"), - num_quantizers=num_fp8_tensors, - ) - - self.fp8_meta[fp8_meta_tensor_key] = recipe_state - self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() - - def _update_weight_quantizers(self) -> None: - """Update the quantizers for the weight tensors.""" - weight_tensors = self._get_weight_tensors() - weight_quantizers = self._get_weight_quantizers() - assert len(weight_tensors) == len(weight_quantizers), ( - f"Number of weight tensors ({len(weight_tensors)}) and quantizers " - f"({len(weight_quantizers)}) must match" - ) - for weight, quantizer in zip(weight_tensors, weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): - weight.update_quantizer(quantizer) - - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: - """Get the weight tensors of the module.""" - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement _get_weight_tensors function" - ) - - def _get_weight_quantizers(self) -> List[Quantizer]: - """Get the weight quantizers of the module.""" - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement _get_weight_quantizers function" - ) - - def init_fp8_meta_tensors(self, recipe: Recipe) -> None: - """Init scales and amaxes.""" - self.set_meta_tensor(True, recipe) - self.set_meta_tensor(False, recipe) - - self.fp8_meta_tensors_initialized = True - - def get_fp8_meta_tensors(self) -> None: - """Get scales and amaxes.""" - fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" - if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta: - return None - - fp8_meta_tensors = {fwd_key: [], bwd_key: []} - with torch.no_grad(): - for key in (fwd_key, bwd_key): - fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) - fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) - return fp8_meta_tensors - - def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: - """Reset scales and amaxes.""" - - def reset(key): - if key in self.fp8_meta: - if fp8_meta_tensors is None: - self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) - self.fp8_meta[key].amax_history.copy_( - torch.zeros_like(self.fp8_meta[key].amax_history) - ) - else: - assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." - self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) - self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) - - with torch.no_grad(): - reset("scaling_fwd") - reset("scaling_bwd") - - def get_extra_state(self) -> torch.Tensor: - """Save before checkpointing.""" - - # This implementation is working around a few issues: - # - # (1) PyTorch's "extra state" infrastructure might be able to - # support any picklable type, but they make no guarantees. - # We have experienced problems (e.g. in ONNX export) with - # non-tensor extra state. - # (2) PyTorch's checkpointing infrastructure does not remap - # devices for "extra state" like it does for "state dict". - # Thus, we want to avoid putting extra state on the GPU - # since it may be loaded on the wrong device. - # (3) The extra state consists of many small tensors. 
If we - # want to copy them all to CPU, then we need to avoid the - # overhead of many GPU-CPU memory transfers. - # - # See: https://github.com/NVIDIA/TransformerEngine/pull/351 - # See: https://github.com/NVIDIA/TransformerEngine/pull/363 - - def to_cpu(src: torch.Tensor) -> torch.Tensor: - """Helper function to make CPU copy of tensor - - Memory transfer is asynchronous w.r.t. host, so GPU should - be synchronized before using result. - - """ - dst = torch.empty_like(src, device="cpu") - dst.copy_(src, non_blocking=True) - return dst - - # Store FP8 state if needed - state = None - fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - if not fp8_checkpoint: - return torch.empty(0, dtype=torch.uint8) - - # Copy tensors to CPU and store - state = {} - state["recipe"] = self.fp8_meta["recipe"] - if state["recipe"].delayed(): - state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) - state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) - state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) - state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) - - # Store other pickelable values - extra = {} - for k, v in self.fp8_meta.items(): - if k != "buffer_index_and_autocast_key" and isinstance( - v, (bool, int, float, str, tuple, list) - ): - extra[k] = v - state["extra_fp8_variables"] = extra - - # Serialize state into byte tensor - torch.cuda.synchronize() - state_serialized = bytearray(pickle.dumps(state)) - state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) - return state_serialized - - def set_extra_state(self, state: torch.Tensor) -> None: - """Load previous state.""" - - # Maintain backwards compatibility with older checkpoints. - if state is None: - return - - # Load state - if isinstance(state, torch.Tensor): - # No FP8 is indicated by an empty tensor we don't need to unpickle. - if state.numel() == 0: - return - # Default format: byte tensor with pickled data - state = pickle.loads(state.detach().cpu().numpy().tobytes()) - elif isinstance(state, io.BytesIO): - # Deprecated format with io.BytesIO - state.seek(0) - state = torch.load(state, map_location="cuda") - else: - raise RuntimeError("Unsupported checkpoint format.") - - if state is None: - return - - # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing - if "recipe" not in state: - # TE 1.x only supported delayed scaling, which was the default recipe - state["recipe"] = DelayedScaling() - # TE 1.x also saved scale_inv, which is not needed with Recipe object - state.pop("scale_inv_fwd", None) - state.pop("scale_inv_bwd", None) - - # Load extra items - self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"] = state["recipe"] - if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: - del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] - - # Initialize before loading - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - - def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: - """Helper function to copy tensor from CPU - - Memory transfer is asynchronous w.r.t. host, so GPU should - be synchronized before using result. 
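The `get_extra_state`/`set_extra_state` pair above serializes the FP8 state (recipe, scales, amax histories) into a single uint8 tensor so that ordinary `state_dict()`/`load_state_dict()` round trips carry it along. A minimal round-trip sketch; `te.Linear` and the sizes are only placeholders for some TE module:

```
# Round-trip sketch of the extra-state hooks defined above (placeholder module/sizes).
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
blob = layer.get_extra_state()   # empty uint8 tensor unless FP8 / calibration was used
layer.set_extra_state(blob)      # restores recipe, scales and amax histories when present
```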
- - """ - dst.copy_(src, non_blocking=True) - - # Load tensors - if self.fp8_meta["recipe"].delayed(): - copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) - copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) - copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) - copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) - torch.cuda.synchronize() - - def set_activation_dtype(self, inp: torch.Tensor) -> None: - """Get activation data type for AMP.""" - # Native AMP (`torch.autocast`) gets highest priority - if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() - return - - # All checks after this have already been performed once, thus skip - if self.activation_dtype == inp.dtype: - return - - dtype = inp.dtype - if not self.allow_different_data_and_param_types: - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - self.activation_dtype = dtype - - def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: - """ - Set the tensor parallel group for the given - module before executing the forward pass. - - Parameters - ---------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - """ - self.tp_group = tp_group - self.tp_group_initialized = True - - def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: - """returns the FP8 weights.""" - fp8_params = [] - for param in self.parameters(recurse=False): - if isinstance(param, QuantizedTensor) and param.requires_grad: - fp8_params.append(param) - if len(fp8_params) == 0: - return None - return fp8_params - - # This routine is shared across FP8 and FP8_calibration paths so should not actually - # assume FP8 execution. - def init_fp8_metadata(self, num_gemms: int = 1) -> None: - """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) - - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): - # FP8 init has already been run and recipe is the same, don't do anything. - return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - else: - # If fp8 isn't enabled, turn off and return. 
- self.fp8_initialized = False - return - - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - - if fp8_enabled: - # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - - # Set FP8_MAX per tensor according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd - - # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - self.fp8_initialized = True - - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - - _current_recipe = self.fp8_meta["recipe"] - if _original_recipe is not None and not ( - issubclass(_current_recipe.__class__, _original_recipe.__class__) - or issubclass(_original_recipe.__class__, _current_recipe.__class__) - ): - warnings.warn( - f"Recipe type changed from {_original_recipe.__class__.__name__} " - f"to {_current_recipe.__class__.__name__}. " - "This may affect model behavior." - ) - # Clear cached workspaces as they were created with the old recipe/quantizer type - self._fp8_workspaces.clear() - - @contextmanager - def prepare_forward( - self, - inp: torch.Tensor, - num_gemms: int = 1, - allow_non_contiguous: bool = False, - allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True - # Activation recomputation is used and this is the second forward phase. - if self.fp8 and in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) - else: - assert inp.is_cuda, "TransformerEngine needs CUDA." - - if self.tp_size > 1: - assert self.tp_group_initialized, "TP group not initialized." - - self.set_activation_dtype(inp) - self.init_fp8_metadata(num_gemms=num_gemms) - self._check_weight_tensor_recipe_correspondence() - - if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) - - if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) - - # Activation recomputation is used and this is the first forward phase. - if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): - FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - - with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp - - if self.fp8 and in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) - - def set_nccl_overlap_warning_if_tp(self) -> None: - """When using TP, the NCCL communication needs to be scheduled - before the GEMM for there to be a guaranteed overlap. 
From the - host side in TE, the comm calls are always launched first, but - to ensure that the GEMM isn't scheduled first, the environment - variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to - force a single channel. - """ - if self.tp_size == 1: - return - num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) - if num_cuda_work_queues != 1: - warnings.warn( - "To guarantee overlapping TP and SP collectives with the backward" - "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" - ) - - @staticmethod - def grad_output_preprocess( - ctx, - grad_output: torch.Tensor, - row_parallel_mode: bool, - quantizer: Optional[Quantizer], - ) -> Tuple[Union[torch.Tensor, None], ...]: - """Utility function for backward. - Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output`. - R2: bias gradient on R1. - - """ - grad_output = grad_output.reshape((-1, grad_output.shape[-1])) - grad_output = grad_output.contiguous() - gather_grad_output = row_parallel_mode and ctx.sequence_parallel - - # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8 and not ctx.debug: - if gather_grad_output: - if not ctx.ub_overlap_ag: # Perform NCCL all-gather - grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) - else: # Initialize Userbuffers all-gather - grad_output, _ = fill_userbuffers_buffer_for_all_gather( - ctx.ub_obj_gradout, - grad_output, - None, - ctx.tp_group, - ) - return grad_output, None - - # FP8 with all-gather: unfused bgrad, fused cast + transpose - # Also supports debug quantization, which is handled inside gather_along_first_dim. - if gather_grad_output: - grad_bias = None - if ctx.use_bias: - grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) - if ctx.ub_overlap_ag: - # Quantize the gradient if needed - if not isinstance( - grad_output, - ( - QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, - ), - ): - grad_output = quantizer(grad_output) - - # Copy into communication buffer, and replace original gradient with it - grad_output, _ = fill_userbuffers_buffer_for_all_gather( - ctx.ub_obj_gradout, - grad_output, - quantizer, - ctx.tp_group, - ) - else: - grad_output, _ = gather_along_first_dim( - grad_output, - ctx.tp_group, - quantizer=quantizer, - ) - return grad_output, grad_bias - - # Debug without all-gather: unfused cast and bgrad - # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None - if ctx.debug: - grad_output_ = quantizer(grad_output) - if ( - isinstance( - grad_output_.get_tensor(True), - ( - QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, - ), - ) - and ctx.use_bias - ): - grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) - else: - grad_bias = None - grad_output = grad_output_ - return grad_output, grad_bias - - # FP8 without all-gather: fused bgrad + cast + transpose - grad_bias = None - if ctx.use_bias: - if isinstance( - grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), - ): - grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) - else: - # TODO(ksivaman): Re-add fusion once kernel is available. - if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): - # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. 
- grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) - else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, QuantizedTensorBase): - grad_output = quantizer(grad_output) - return grad_output, grad_bias - - def register_parameter(self, name, param, **kwargs): - """ - Thin wrapper around PyTorch parameter registration to stash additional parameter - metedata used in deferred initialization. - """ - super().register_parameter(name, param) - self.param_init_meta[name] = _ParameterInitMeta(**kwargs) - - def reset_parameters(self, defer_init: Optional[bool] = False) -> None: - """ - Reset all module parameters to initial values. Unless deferred initialization - is specified, all parameters on a 'meta' device are also materialized on a real cuda - device before the values are reset to initial. - """ - if defer_init: - return - - for name, param in self.named_parameters(recurse=False): - # Ensure parameter is on a real device - if param.device == torch.device("meta"): - param = torch.empty_like(param, device="cuda") - - # Initialize the parameter values on device - init_fn = self.param_init_meta[name].init_fn - get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker - if get_rng_state_tracker is None: - init_fn(param) - else: - if hasattr(self, "rng_tracker_name") and self.rng_tracker_name: - with get_rng_state_tracker().fork(self.rng_tracker_name): - init_fn(param) - else: - with get_rng_state_tracker().fork(): - init_fn(param) - - # Wrap parameters in QuantizedTensor if needed - fp8_meta_index = self.param_init_meta[name].fp8_meta_index - high_precision_init_val = None - if self.primary_weights_in_fp8 and fp8_meta_index is not None: - - # Keep high-precision values on CPU if needed - if self.preserve_high_precision_init_val: - high_precision_init_val = param.detach().cpu() - - # Configure quantizer - quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] - if quantizer is None: - raise RuntimeError("Weight quantizer has not been initialized") - quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) - quantizer.internal = False - - # Quantize parameter - param = quantizer(param) - - # Redo parameter wrap in case we broke it above - # NOTE: Currently this can only be broken when primary weights are in Fp8 but - # re-applying the nn.Parameter() wrap is a no-op when the input is already - # a parameter so we always re-apply it just for extra safety. - param = torch.nn.Parameter(param) - - # Keep high-precision values on CPU if needed - if high_precision_init_val is not None: - - # - Master weights are initialized from model weights, if we use fp8 primary - # weights to initialize master weights, the numerical values of master weights - # are not consistent with the numerical values when we initialize them from - # bf16/fp16 weights. - # - So we add a `_high_precision_init_val` attribute to each model weight to store - # the original bf16/fp16 weight on cpu before casting it to fp8. And users can - # use `get_high_precision_init_val` to get this cpu tensor. - # - This cpu tensor is not needed once the master weight is initialized, so users - # should call `clear_high_precision_init_val` to remove it after master weight - # is initialized. 
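# How a training framework might consume the CPU-side high-precision copy described
# above when building FP32 master weights; a sketch only, and every name here except
# get_/clear_high_precision_init_val is hypothetical.
import torch

def build_master_weights(module: torch.nn.Module) -> list:
    master_weights = []
    for param in module.parameters():
        getter = getattr(param, "get_high_precision_init_val", None)
        if getter is not None and getter() is not None:
            # Seed the FP32 master weight from the original bf16/fp16 values kept on CPU
            master_weights.append(getter().float().cuda())
            # The CPU copy is no longer needed once the master weight exists
            param.clear_high_precision_init_val()
        else:
            master_weights.append(param.detach().float().clone())
    return master_weights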
- - def get(self): - if hasattr(self, "_high_precision_init_val"): - return self._high_precision_init_val - return None - - def clear(self): - if hasattr(self, "_high_precision_init_val"): - del self._high_precision_init_val - - param._high_precision_init_val = high_precision_init_val - param.get_high_precision_init_val = MethodType(get, param) - param.clear_high_precision_init_val = MethodType(clear, param) - - setattr(self, name, param) - - @abstractmethod - def forward(self): - """Needs override.""" - - def get_weight_workspace( - self, - *, - tensor: Optional[torch.Tensor] = None, - quantizer: Optional[Quantizer] = None, - cache_name: Optional[str] = None, - update_workspace: bool = True, - skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: Optional[dist_group_type] = None, - workspace_dtype: Optional[torch.dtype] = None, - ) -> QuantizedTensor: - """Get workspace buffer for weights and maybe update its values - - The workspace buffer may be cached for future function calls. - - Parameters - ---------- - tensor : torch.Tensor, optional - Values to copy into workspace. Required if the workspace - is being constructed or updated. - quantizer: Quantizer, optional - Quantizer used to cast the weights. Required if the - workspace is being constructed or updated. - cache_name: str, optional - Key for caching. - update_workspace: bool, default = `True` - Update workspace with values from `tensor`. - skip_update_flag: torch.Tensor, optional - GPU flag to skip updating the workspace. Take precedence - over `update_workspace` if provided. - fsdp_group: bool, default = None - FSDP process group that the weights are distributed over. - workspace_dtype: torch.dtype, default = None - If weight workspace contains high-precision tensor - for example - for debug quantization, this is dtype of the tensor. - """ - - # Handle case where weights are already quantized - # Note: Make sure weights have required usages, but do not - # destroy unnecessary usages since they may be used later. - if isinstance(tensor, QuantizedTensor): - update_rowwise_usage = True if quantizer.rowwise_usage else None - update_columnwise_usage = True if quantizer.columnwise_usage else None - tensor.update_usage( - rowwise_usage=update_rowwise_usage, - columnwise_usage=update_columnwise_usage, - ) - return tensor - - # Try getting workspace from cache - out = None - if cache_name is not None: - out = self._fp8_workspaces.get(cache_name, None) - - # Reset cache if workspace is invalid - if out is not None and quantizer is not None: - reset_cache = False - if isinstance(out, Float8TensorBase): - if ( - not is_non_tn_fp8_gemm_supported() - and quantizer.columnwise_usage - and out._transpose is None - ): - reset_cache = True - elif isinstance(out, MXFP8TensorBase): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): - reset_cache = True - if reset_cache: - out = None - del self._fp8_workspaces[cache_name] - - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. 
- if ( - out is not None - and tensor is not None - and fsdp_group is not None - and out.data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) - - # Construct workspace if needed - if out is None: - if tensor is None or quantizer is None: - raise ValueError( - "tensor and quantizer kwargs must be provided to construct FP8 workspace" - ) - - if cache_name is not None: - # Ensure the tensor in the cache is an instance of torch.Tensor, - # as it persists beyond a single forward pass. - # Setting internal=True would cause the data to be removed in prepare_for_saving(...). - quantizer_internal = quantizer.internal - quantizer.internal = False - out = quantizer.quantize(tensor, dtype=workspace_dtype) - if cache_name is not None: - quantizer.internal = quantizer_internal - - # Update cache - if cache_name is not None: - self._fp8_workspaces[cache_name] = out - return out - - # Update workspace if needed - if skip_update_flag is not None: - update_workspace = True - if update_workspace: - if tensor is None: - raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if hasattr(out, "quantize_"): - out.quantize_(tensor, noop_flag=skip_update_flag) - else: - tex.quantize(tensor, quantizer, out, skip_update_flag) - return out - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - """ - This function loads tensors and extra state including fp8 metadata. - This metadata is essential for copying fp8 tensors, as the copy_ function - uses the scale_inv parameter from fp8_meta to set the correct scaling factor - for the new tensor. - Hence, this extra state must be loaded before the tensor copying process, - not after, as is typically done in _load_from_state_dict. - Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True, - otherwise, this behavior is not required. - """ - if self.primary_weights_in_fp8: - extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX - if extra_state_key in state_dict: - self.set_extra_state(state_dict[extra_state_key]) - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): - """ - This method is used to manually control the weight gradient accumulation and reduce. - This method should be called before the backward() method. - Set the skip_wgrad_accumulation_and_reduce to True to skip the weight gradient accumulation - and reduce in backward(); - And register the wgrad_accumulation_and_reduce_func to be called in backward_dw() method. - """ - self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) - - def backward_dw(self): - """ - Execute the delayed weight gradient computation. - This method is called after the main backward pass to compute weight gradients. 
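# Intended call pattern for the delayed-wgrad path, shown as a minimal sketch
# (assumes a module constructed with delay_wgrad_compute=True; sizes are arbitrary).
import torch
import transformer_engine.pytorch as te

module = te.LayerNormLinear(1024, 1024, delay_wgrad_compute=True, params_dtype=torch.bfloat16)
inp = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = module(inp)
out.float().sum().backward()  # dgrad runs now; the wgrad GEMM closure is queued in wgrad_store
module.backward_dw()          # pops the closure and materializes the weight/bias gradients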
- """ - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): - return - with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): - (wgrad, bgrad), _ = self.wgrad_store.pop() - if not self.fuse_wgrad_accumulation: - weight_tensor = noop_cat(self._get_weight_tensors()) - weight_tensor.grad = wgrad.to(weight_tensor.dtype) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - if bias_tensor.grad is None: - bias_tensor.grad = bgrad.to(bias_tensor.dtype) - del wgrad - del bgrad - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() - - def is_debug_iter(self) -> bool: - """ - This function checks if the debug should be enabled for this layer. - """ - debug = TEDebugState.debug_enabled - if not debug: - return False - self._validate_name() - - # If layer is run first time in new iteration, - # we need to check if the debug should be enabled for this layer - - # maybe in previous iterations debug features returned information - # that no feature will be active for this layer for multiple next iterations. - started_new_iteration = TEDebugState.get_iteration() != getattr( - self, "debug_last_iteration", None - ) - if started_new_iteration: - if self.next_iter_when_debug_should_be_run is None: - debug = False - else: - debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() - return debug - - def no_debug_features_active(self, quantizers): - """ - Checks if any debug feature is active for this layer. - """ - run_current = any_feature_enabled(quantizers) - - # Sometimes features inform that they will not be enabled for particular layer - # for multiple next iterations. - self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) - - if not run_current: - return True - - if self.primary_weights_in_fp8: - raise RuntimeError("FP8 weights are not supported in debug mode.") - return False - - def _validate_name(self): - """ - Validate name passed to the module. - This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. - If no name is assigned, it creates a default name with layer count as the variable. - """ - if self.name is not None: - return - assert TEDebugState.debug_enabled - import nvdlfw_inspect.api as debug_api - - if self.name is None: - debug_api.log_message( - "Names are not provided to debug modules. ", - "Creating and using generic names. Pass names to debug modules for better" - " insight. ", - level=logging.WARNING, - ) - self.name = f"Layer_{TEDebugState.get_layer_count()}" - - def _check_weight_tensor_recipe_correspondence(self) -> None: - """ - Verify that the weight tensor types match their corresponding recipe type. - This is invoked in the forward(). - - This establishes a 1:1 correspondence between recipe types and tensor types: - - DelayedScaling → Float8Tensor - - Float8CurrentScaling → Float8Tensor - - MXFP8BlockScaling → MXFP8Tensor - - Float8BlockScaling → Float8BlockTensor - - Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), - but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). 
- """ - if not self.fp8 and not self.fp8_calibration: - return - if not hasattr(self, "weight_names") or not self.weight_names: - return - - recipe = self.fp8_meta["recipe"] - weight_tensors = [getattr(self, name) for name in self.weight_names] - for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, QuantizedTensorBase): - quantizer = tensor._get_quantizer() - if quantizer is None: - continue - compatible_recipe_class = quantizer._get_compatible_recipe() - if compatible_recipe_class is None: - continue - if not isinstance(recipe, compatible_recipe_class): - raise RuntimeError( - f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" - f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." - " Please check the recipes assigned during fp8_model_init() and" - " fp8_autocast() calls." - ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py.orig b/transformer_engine/pytorch/module/layernorm_linear.py.orig deleted file mode 100644 index 6dbbd335eb..0000000000 --- a/transformer_engine/pytorch/module/layernorm_linear.py.orig +++ /dev/null @@ -1,1827 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""LayerNormLinear API""" -import os -import warnings -from typing import Callable, Dict, Optional, Tuple, Union, List -from functools import reduce -from operator import mul as multiply_op - -import torch -from torch.nn import init - -import transformer_engine_torch as tex - -from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version -from transformer_engine.pytorch.tensor.utils import is_experimental -from .base import ( - fill_userbuffers_buffer_for_all_gather, - get_workspace, - get_ub, - TransformerEngineBaseModule, - get_dummy_wgrad, - _2X_ACC_FPROP, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, -) -from ..fp8 import FP8GlobalStateManager -from ..utils import ( - assert_dim_for_fp8_exec, - assert_dim_for_all_gather, - cast_if_needed, - clear_tensor_data, - divide, - get_default_init_method, - init_method_constant, - nvtx_range_pop, - nvtx_range_push, - requires_grad, - needs_quantized_gemm, -) -from ..distributed import ( - set_tensor_model_parallel_attributes, - get_distributed_world_size, - allreduce, - symmetric_all_reduce, - reduce_scatter_along_first_dim, - gather_along_first_dim, - in_fp8_activation_recompute_phase, - _fsdp_scatter_tensors, - _fsdp_gather_tensors, -) -from ..constants import GemmParallelModes, dist_group_type -from ..jit import no_torch_dynamo -from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers -from ..tensor.quantized_tensor import ( - QuantizedTensor, - QuantizedTensorBase, - Quantizer, - prepare_for_saving, - restore_from_saved, -) -from ...debug.pytorch.debug_state import TEDebugState -from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from ..export import is_in_onnx_export_mode, assert_warmed_up -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload - -from ..cpp_extensions import ( - general_gemm, -) - -__all__ = ["LayerNormLinear"] - - -class _LayerNormLinear(torch.autograd.Function): - """LayerNormLinear semi-top level module - Calls custom cuda extensions. 
- """ - - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - ln_weight: torch.Tensor, - ln_bias: Union[torch.Tensor, None], - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - is_first_microbatch: Union[bool, None], - fp8: bool, - fp8_calibration: bool, - wgrad_store: WeightGradStore, - fuse_wgrad_accumulation: bool, - input_quantizer: Optional[Quantizer], - weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], - grad_weight_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], - cpu_offloading: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - sequence_parallel: bool, - tensor_parallel: bool, - activation_dtype: torch.dtype, - parallel_mode: Union[str, None], - return_layernorm_output: bool, - return_layernorm_output_gathered: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - ub_overlap_ag_fprop: bool, - ub_overlap_rs_fprop: bool, - ub_overlap_ag_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_bulk_wgrad: bool, - ub_bulk_dgrad: bool, - ub_name: str, - fsdp_group: Union[dist_group_type, None], - module: torch.nn.Module, - skip_fp8_weight_update: bool, - symmetric_ar_type: str, - debug: Optional[bool] = False, - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: - # pylint: disable=missing-function-docstring - - # NVTX label for profiling - nvtx_label = "transformer_engine._LayerNormLinear.forward" - if ub_name is not None: - nvtx_label = f"{nvtx_label}.{ub_name}" - - with_input_all_gather = parallel_mode == "column" and sequence_parallel - - # Make sure input dimensions are compatible - out_features, in_features = weight.shape - inp_shape = inp.shape - inp_requires_grad = inp.requires_grad - assert inp_shape[-1] == in_features, "GEMM not possible" - inp = inp.view((-1, in_features)) - inputmat = inp - if fp8: - assert_dim_for_fp8_exec(inputmat, weight) - assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) - - # Cast for native AMP - nvtx_range_push(f"{nvtx_label}.norm_input_cast") - inputmat = cast_if_needed(inputmat, activation_dtype) - ln_weight = cast_if_needed(ln_weight, activation_dtype) - if ln_bias is not None: - ln_bias = cast_if_needed(ln_bias, activation_dtype) - nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - - tp_world_size = get_distributed_world_size(tp_group) - - weight_requires_grad = weight.requires_grad - backward_needs_input = is_grad_enabled and weight_requires_grad - - # Configure Userbuffers communication (comm+GEMM overlap) - if debug: # turn off userbuffers in debug mode - ub_overlap_ag_fprop = False - ub_overlap_rs_fprop = False - ub_overlap_ag_dgrad = False - ub_overlap_rs_dgrad = False - ub_bulk_wgrad = False - ub_bulk_dgrad = False - ub_obj = None - ub_type = None - ub_overlap_ag_fprop = ( - ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output - ) - if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop", fp8) - ub_type = tex.CommOverlapType.RS - elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop", fp8) - ub_type = tex.CommOverlapType.AG - - # Configure quantizer for norm output - if fp8: - if input_quantizer is None: - raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) - if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): - # All-gather is not supported with FP8 column-wise data - 
input_quantizer.set_usage(columnwise=False) - - # Avoid quantized norm kernel if norm output will be returned - # or if a gather of ln_out must be in high precision. - experimental = is_experimental(input_quantizer) - with_quantized_norm = ( - fp8 - and not debug - and not return_layernorm_output - and not return_layernorm_output_gathered - and not experimental - ) - - # Apply normalization - nvtx_range_push(f"{nvtx_label}.norm") - ln_out, mu, rsigma = apply_normalization( - inputmat, - None, # ln_out - ln_weight, - ln_bias, - eps, - input_quantizer if with_quantized_norm else None, - inputmat.dtype, - normalization, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - nvtx_range_pop(f"{nvtx_label}.norm") - - # Store unquantized layer norm output if we need to return it - ln_out_return = None - if return_layernorm_output or return_layernorm_output_gathered: - ln_out_return = ln_out - - # ------------------------------------------------------ - # Prepare GEMM input tensor - # Note: Cast to expected dtype and perform tensor-parallel communication - # ------------------------------------------------------ - nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") - ln_out_total = None - if with_input_all_gather: - if return_layernorm_output_gathered: - # Perform all-gather in high precision if gathered - # norm output will be returned - ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) - ln_out_return = ln_out_total - if fp8 or debug: - ln_out = input_quantizer(ln_out) - input_quantizer.set_usage(rowwise=True, columnwise=False) - if isinstance(input_quantizer, Float8BlockQuantizer): - input_quantizer.all_gather_usage = False - ln_out_total = input_quantizer(ln_out_total) - else: - quantizer = None - if fp8 or debug: - quantizer = input_quantizer - # experimental recipe doesn't need to support quantized AG - if not with_quantized_norm and not experimental: - ln_out = quantizer(ln_out) - quantizer.set_usage(rowwise=True, columnwise=False) - if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather - ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( - ub_obj, - ln_out, - quantizer, - tp_group, - ) - else: # Perform NCCL all-gather - ln_out_total, _ = gather_along_first_dim( - ln_out, - tp_group, - quantizer=quantizer, - ) - else: - if (fp8 or debug) and not with_quantized_norm: - ln_out = input_quantizer(ln_out) - ln_out_total = ln_out - nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") - # ------------------------------------------------------ - # GEMM input tensor is ready... 
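# Condensed decision flow for the gather/quantize block above; a sketch that ignores
# the Userbuffers, debug, and quantizer-usage details.
from transformer_engine.pytorch.distributed import gather_along_first_dim  # same helper this module imports

def prepare_gemm_input(ln_out, quantize, gather, return_gathered, input_quantizer, tp_group):
    if not gather:                                  # no input all-gather needed
        return input_quantizer(ln_out) if quantize else ln_out
    if return_gathered:                             # caller wants the gathered norm output
        ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)  # gather in high precision
        return input_quantizer(ln_out_total) if quantize else ln_out_total
    quantizer = input_quantizer if quantize else None
    if quantize:
        ln_out = quantizer(ln_out)                  # quantize first, then all-gather
    ln_out_total, _ = gather_along_first_dim(ln_out, tp_group, quantizer=quantizer)
    return ln_out_total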
- # ------------------------------------------------------ - - # ------------------------------------------------------ - # Prepare weight tensor - # ------------------------------------------------------ - weightmat = weight - quantized_weight = False - if fp8 or debug: - quantized_weight = not isinstance(weight, QuantizedTensorBase) - - # Configure quantizer - if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - - # Get quantized weight - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - workspace_dtype=activation_dtype, - ) - weightmat.update_usage(rowwise_usage=True) - - else: - weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP - # ------------------------------------------------------ - # Weight tensor is ready for GEMM... - # ------------------------------------------------------ - - # Cast bias to expected dtype - bias_dtype = activation_dtype - if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: - # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16 - bias_dtype = torch.bfloat16 - bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias - - # Calibrate quantizers if needed - if not fp8 and fp8_calibration: - if input_quantizer is not None: - input_quantizer.calibrate(ln_out_total) - if weight_quantizer is not None: - weight_quantizer.calibrate(weight) - - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_FPROP - if fp8: - recipe = FP8GlobalStateManager.get_fp8_recipe() - if hasattr(recipe, "fp8_gemm_fprop"): - use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - - # Configure output quantizer - if output_quantizer is not None: - output_quantizer.set_usage(rowwise=True, columnwise=False) - - # Output buffer for Userbuffers reduce-scatter - reduce_scatter_out = None - if ub_overlap_rs_fprop: - out_shape = list(inp_shape) - out_shape[0] //= tp_world_size - out_shape[-1] = out_features - reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) - - # ------------------------------------------------------ - # Forward GEMM - # Note: y = x * w^T - # ------------------------------------------------------ - nvtx_range_push(f"{nvtx_label}.gemm") - gemm_out, *_, reduce_scatter_out = general_gemm( - weightmat, - ln_out_total, - get_workspace(), - quantization_params=output_quantizer, - out_dtype=activation_dtype, - bias=bias, - use_split_accumulator=use_split_accumulator, - ub=ub_obj, - ub_type=ub_type, - extra_output=reduce_scatter_out, - ) - nvtx_range_pop(f"{nvtx_label}.gemm") - # ------------------------------------------------------ - # Finished forward GEMM... 
- # ------------------------------------------------------ - - # Deallocate GEMM input tensor if no longer needed - if not weight.requires_grad and not return_layernorm_output: - clear_tensor_data(ln_out, ln_out_total) - ln_out = ln_out_total = None - elif with_input_all_gather and not return_layernorm_output_gathered: - clear_tensor_data(ln_out_total) - ln_out_total = None - - # ------------------------------------------------------ - # Prepare output tensor - # Note: Perform tensor-parallel communication - # ------------------------------------------------------ - out = None - if ub_overlap_rs_fprop: - out = reduce_scatter_out - elif parallel_mode == "row" and tp_size > 1: - nvtx_range_push(f"{nvtx_label}.row_parallel_comm") - out = gemm_out - if sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif tensor_parallel: - if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) - else: - out, _ = allreduce(out, tp_group) - nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") - else: - out = gemm_out - out = out.view(-1, *inp_shape[1:-1], out_features) - # ------------------------------------------------------ - # Output tensor is ready to return... - # ------------------------------------------------------ - - # ------------------------------------------------------ - # Cache state for backward pass - # ------------------------------------------------------ - - if is_grad_enabled: - ctx.weight_quantizer = weight_quantizer - ctx.ln_out_needs_gather = ( - weight.requires_grad and parallel_mode == "column" and sequence_parallel - ) - - # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: - if isinstance(ln_out, QuantizedTensorBase): - # For sequence parallel in vanilla FP8, rowwise data is - # to gather the input. For MXFP8, columnwise only data - # can be allgathered. - if ( - isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) - or not ctx.ln_out_needs_gather - ): - ln_out.update_usage(rowwise_usage=False) - - # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(weightmat, QuantizedTensorBase): - weightmat.update_usage(columnwise_usage=True) - - if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) - - # Scatter intermediate/activation tensors saved for the backward pass - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already - # shards/unshards the base weights so we don't do it ourselves - nvtx_range_push(f"{nvtx_label}.fsdp_scatter") - ctx.fsdp_group = fsdp_group - ctx.fsdp_shapes = _fsdp_scatter_tensors( - fsdp_group, - mu, - rsigma, - weightmat if quantized_weight else None, - ln_out if weight.requires_grad else None, - ) - nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - - if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. 
Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_object = weight - - tensors_to_save, tensor_objects = prepare_for_saving( - inputmat, - weightmat, - weight, - bias, - ln_weight, - ln_out, - mu, - rsigma, - ) - ctx.save_for_backward(*tensors_to_save) - ctx.tensor_objects = tensor_objects - ctx.requires_dgrad = inp_requires_grad - ctx.requires_wgrad = weight.requires_grad - ctx.quantized_weight = quantized_weight - if fuse_wgrad_accumulation and weight.requires_grad: - # This check is needed to ensure that main_grad is not created - # during the forward pass when using MCore FSDP as it creates - # the main_grad buffer lazily before backprop - if hasattr(weight, "__fsdp_param__"): - # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_func = weight.get_main_grad - else: - ctx.main_grad_func = lambda: weight.main_grad - ctx.grad_input_quantizer = grad_input_quantizer - ctx.grad_weight_quantizer = grad_weight_quantizer - ctx.grad_output_quantizer = grad_output_quantizer - ctx.input_quantizer = input_quantizer - ctx.owns_input = inputmat is not inp - ctx.weight = weight - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.cpu_offloading = cpu_offloading - ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = bias is not None - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp_shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.return_layernorm_output = return_layernorm_output - ctx.return_layernorm_output_gathered = return_layernorm_output_gathered - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.ub_overlap_ag = ub_overlap_ag_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - ctx.ub_bulk_wgrad = ub_bulk_wgrad - ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_name = ub_name - ctx.requires_dgrad = inp_requires_grad - ctx.normalization = normalization - ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE - ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() - if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module - ctx.wgrad_store = wgrad_store - ctx.debug = debug - - # ------------------------------------------------------ - # Cached state for backward pass is ready... - # ------------------------------------------------------ - - if return_layernorm_output: - if return_layernorm_output_gathered: - shape = list(inp_shape) - shape[0] *= tp_size if with_input_all_gather else 1 - return out, ln_out_return.view(shape) - return out, ln_out_return.view(inp_shape) - return out - - @staticmethod - def backward( - ctx, *grad_outputs: Tuple[torch.Tensor, ...] 
- ) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - - # NVTX label for profiling - nvtx_label = "transformer_engine._LayerNormLinear.backward" - if ctx.ub_name is not None: - nvtx_label = f"{nvtx_label}.{ctx.ub_name}" - - with torch.cuda.nvtx.range("_LayerNormLinear_backward"): - saved_tensors = ctx.saved_tensors - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - weight, - origin_weight, - bias, - ln_weight, - ln_out, - mu, - rsigma, - ) = restore_from_saved(ctx.tensor_objects, saved_tensors) - # Delete the references to tensor objects once they've been consumed - # by the `restore_from_saved` method to construct back the actual tensors. - ctx.tensor_objects = None - - # Since main_grad can be modified inplace, it should not be a part of saved_tensors - main_grad = ( - ctx.main_grad_func() - if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad - else None - ) - - # Gather intermediate/activation tensors if needed - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already - # shards/unshards the base weights so we don't do it ourselves - nvtx_range_push(f"{nvtx_label}.fsdp_gather") - _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, - mu, - rsigma, - weight if ctx.fp8 and ctx.quantized_weight else None, - ln_out, - ) - nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - - # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, - # we need to connect them into one. - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - origin_weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - origin_weight.main_grad = main_grad - - # Configure Userbuffers communication (comm+GEMM overlap) - ctx.ub_obj_gradout = None - ub_obj_dgrad = None - ub_obj_wgrad = None - ub_type_dgrad = None - ub_type_wgrad = None - dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] - if ctx.ub_overlap_ag: - # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout - ub_type_dgrad = tex.CommOverlapType.AG - elif ctx.ub_overlap_rs_dgrad: - # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout - ub_type_dgrad = tex.CommOverlapType.RS - else: - if ctx.ub_bulk_dgrad: - # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout - ub_type_dgrad = tex.CommOverlapType.AG - if ctx.ub_bulk_wgrad: - # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) - ub_type_wgrad = tex.CommOverlapType.RS - - # -------------------------------------------------- - # Prepare grad output tensor - # Note: Cast to expected dtype and perform tensor-parallel communication - # -------------------------------------------------- - - # Configure quantizer for grad output tensor - # Note: dgrad GEMM requires row-wise usage, wgrad GEMM - # requires column-wise usage - if ctx.grad_output_quantizer is not None: - quantizer = ctx.grad_output_quantizer - quantizer.set_usage(rowwise=True, columnwise=True) - if ctx.ub_overlap_ag: - # Userbuffers only supports communication for one - # tensor usage at a time. Configure quantizer with - # usage for only dgrad GEMM. 
- quantizer.set_usage(columnwise=False) - - # Prepare grad output tensor - # Note: Cast to expected dtype and perform tensor-parallel communication - nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") - ( - grad_output, - grad_bias, - ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, - grad_outputs[0], - ctx.parallel_mode == "row", - ctx.grad_output_quantizer, - ) - nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") - - # -------------------------------------------------- - # Grad output tensor is ready for computing grad input... - # -------------------------------------------------- - - # -------------------------------------------------- - # Prepare GEMM input tensor - # Note: Input tensor is needed for wgrad GEMM. - # Tensor-parallel communication is overlapped with dgrad - # GEMM. - # -------------------------------------------------- - ln_out_total = None - ln_out_total_work = None - if ctx.ln_out_needs_gather: - quantizer = None - if ctx.input_quantizer is not None: - quantizer = ctx.input_quantizer - if quantizer.supports_only_rowwise_all_gather(): - # If data is in FP8, we compute FP8 transposes manually - quantizer.set_usage(rowwise=True, columnwise=False) - else: - # wgrad GEMM requires input with column-wise usage - quantizer.set_usage(rowwise=False, columnwise=True) - if ctx.ub_bulk_dgrad: - ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( - ub_obj_dgrad, - ln_out, - quantizer, - ctx.tp_group, - ) - else: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") - ln_out_total, ln_out_total_work = gather_along_first_dim( - ln_out, - ctx.tp_group, - async_op=True, - quantizer=quantizer, - ) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") - else: - ln_out_total = ln_out - # -------------------------------------------------- - # Input tensor is ready for computing grad weight... - # -------------------------------------------------- - - # -------------------------------------------------- - # Compute grad input tensor - # Note: Gradient w.r.t. GEMM input (i.e. norm output). 
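# Shape sanity check for the dgrad GEMM below (dx = dy * w, layout "NN"), using
# plain torch.matmul purely for illustration.
import torch

tokens, in_features, out_features = 32, 128, 256
dy = torch.randn(tokens, out_features)       # grad w.r.t. the GEMM output
w = torch.randn(out_features, in_features)   # weight stored as [out_features, in_features]
dx = torch.matmul(dy, w)                     # grad w.r.t. the GEMM input (the norm output)
assert dx.shape == (tokens, in_features)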
- # -------------------------------------------------- - - # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): - grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): - weight.update_usage(columnwise_usage=True) - - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator - - # Update grad input quantizer - if ctx.grad_input_quantizer is not None: - ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - - # Output buffers for Userbuffers reduce-scatter - gemm_out = None - reduce_scatter_out = None - if ctx.ub_overlap_rs_dgrad: - reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device - ) - elif ctx.ub_bulk_wgrad: - gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) - - # dgrad GEMM - # Note: dx = dy * w - nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - gemm_out, *_, reduce_scatter_out = general_gemm( - weight, - grad_output, - get_workspace(), - layout="NN", - grad=True, - quantization_params=ctx.grad_input_quantizer, - out=gemm_out, - out_dtype=ctx.activation_dtype, - use_split_accumulator=use_split_accumulator, - ub=ub_obj_dgrad, - ub_type=ub_type_dgrad, - extra_output=reduce_scatter_out, - bulk_overlap=ctx.ub_bulk_dgrad, - ) - nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") - - # Prepare grad input tensor - # Note: Perform tensor-parallel communication - dgrad = None - dgrad_work = None - if ctx.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out - elif ctx.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif ctx.parallel_mode == "column" and ctx.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") - dgrad = gemm_out - if ctx.sequence_parallel: - dgrad, dgrad_work = reduce_scatter_along_first_dim( - dgrad, - ctx.tp_group, - async_op=True, - ) - else: - dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") - else: - dgrad = gemm_out - - # -------------------------------------------------- - # Grad input tensor has been computed... - # -------------------------------------------------- - - # -------------------------------------------------- - # Compute grad weight - # -------------------------------------------------- - - wgrad = None - if ctx.requires_wgrad: - # Prepare grad output tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support pipelined overlapping grad output - # all-gather with wgrad GEMM. Also, we can't - # convert row-scaled MXFP8 to column-scaled, so we - # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around by explicitly - # overlapping the AG operation with the dgrad GEMM. - - # Get the communication stream from the dgrad GEMM to use for the AG - dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() - - # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) - - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - - # We use the send stream to copy into the userbuffers. 
- # This is the same stream that we will use to access the data in the AG, - # so we dont need to add any syncs yet. - with torch.cuda.stream(dgrad_send_stream): - grad_output, _ = fill_userbuffers_buffer_for_all_gather( - ub_obj_overlap_wgrad, - grad_outputs[0], - ctx.grad_output_quantizer, - ctx.tp_group, - ) - - # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm - tex.bulk_overlap_ag_with_external_gemm( - ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream - ) - - # Prepare input tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ln_out_total_work is not None: - ln_out_total_work.wait() - ln_out_total_work = None - if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): - ln_out_total.update_usage(columnwise_usage=True) - else: - ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.input_quantizer(ln_out_total) - - if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): - grad_output.update_usage(columnwise_usage=True) - else: - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output) - - # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator - - # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - # Output buffer for overlapping FP8 grad input - # reduce-scatter with wgrad GEMM - reduce_scatter_out = None - if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): - reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device - ) - - # Arguments to include in wgrad GEMM closure - wgrad_gemm_kwargs = { - "workspace": get_workspace(), - "out_dtype": ( - main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype - ), - "quantization_params": ctx.grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, - "layout": "NT", - "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), - "use_split_accumulator": use_split_accumulator, - "grad": True, - "ub": ub_obj_wgrad, - "ub_type": ub_type_wgrad, - "extra_output": reduce_scatter_out, - "bulk_overlap": ctx.ub_bulk_wgrad, - } - - def wgrad_gemm( - x: torch.Tensor, - dy: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform wgrad GEMM: dw = dy^T * x - - May be fused with bgrad computation. - - May be called outside of this function to enable - some advanced communication/compute overlapping. 
- - """ - nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) - nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") - return dw, db - - # Choose whether to call wgrad GEMM now or delay - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - if ( - wgrad_gemm_kwargs["ub"] is not None - or wgrad_gemm_kwargs["ub_type"] is not None - or wgrad_gemm_kwargs["extra_output"] is not None - or wgrad_gemm_kwargs["bulk_overlap"] - ): - raise NotImplementedError( - "Delayed weight grad computation is not supported " - "with Userbuffers (tensor-parallel communication overlapping)" - ) - ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm) - else: - - # Call wgrad GEMM now - wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) - - # Update grad bias if needed - if grad_bias is None: - grad_bias = grad_bias_ - del grad_bias_ - - # Deallocate input tensors if permitted - if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: - # Input tensors have not been exposed externally - clear_tensor_data(ln_out) - elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: - # Non-gathered input has not been exposed externally - clear_tensor_data(ln_out) - if ctx.ln_out_needs_gather: - # Gathered input is internal - clear_tensor_data(ln_out_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: - # Gathered grad output tensor is internal - clear_tensor_data(grad_output) - - # Update grad input if overlapping reduce-scatter with wgrad GEMM - if ctx.ub_bulk_wgrad: - if ub_obj_wgrad.is_fp8_ubuf(): - dgrad = reduce_scatter_out - else: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() - - # -------------------------------------------------- - # Grad weight has been computed... - # -------------------------------------------------- - - # Don't return grad bias if not needed - if not ctx.use_bias: - grad_bias = None - - # Synchronize tensor parallel communication - if ln_out_total_work is not None: - ln_out_total_work.wait() - ln_out_total_work = None - if dgrad_work is not None: - dgrad_work.wait() - dgrad_work = None - - # Residual gradient - dgrad = dgrad.view(inputmat.shape) - if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: - dgrad = dgrad + grad_outputs[1].view_as(dgrad) - - # Norm gradient - dgamma = None - dbeta = None - nvtx_range_push(f"{nvtx_label}.norm") - if ctx.normalization == "LayerNorm": - dgrad, dgamma, dbeta = tex.layernorm_bwd( - dgrad, - inputmat, - mu, - rsigma, - ln_weight, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - dgrad = dgrad.reshape(inputmat.size()) - elif ctx.normalization == "RMSNorm": - dgrad, dgamma = tex.rmsnorm_bwd( - dgrad, - inputmat, - rsigma, - ln_weight, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - dgrad = dgrad.reshape(inputmat.size()) - dbeta = None - nvtx_range_pop(f"{nvtx_label}.norm") - clear_tensor_data(mu) - clear_tensor_data(rsigma) - - if ctx.requires_wgrad: - # Handle custom DDP from mcore. 
- if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): - origin_weight.grad_added_to_main_grad = True - if getattr(origin_weight, "zero_out_wgrad", False): - wgrad = get_dummy_wgrad( - list(origin_weight.main_grad.shape), - origin_weight.dtype, - zero=True, - ) - else: - wgrad = get_dummy_wgrad( - list(origin_weight.main_grad.shape), - origin_weight.dtype, - ) - elif ctx.fuse_wgrad_accumulation: - wgrad = None - else: - wgrad = None - - if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): - nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") - - # Scatter fp8 weight buffers - # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): - # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) - - return ( - dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma, - dbeta, - wgrad, - grad_bias, - None, # eps - None, # is_first_microbatch - None, # fp8 - None, # fp8_calibration - None, # wgrad_store - None, # fuse_wgrad_accumulation - None, # input_quantizer - None, # weight_quantizer - None, # output_quantizer - None, # grad_input_quantizer - None, # grad_weight_quantizer - None, # grad_output_quantizer - None, # cpu_offloading - None, # tp_group - None, # tp_size - None, # sequence_parallel - None, # tensor_parallel - None, # activation_dtype - None, # parallel_mode - None, # return_layernorm_output - None, # return_layernorm_output_gathered - None, # is_grad_enabled - None, # fwd_ln_sm_margin - None, # bwd_ln_sm_margin - None, # zero_centered_gamma - None, # normalization - None, # ub_overlap_ag_fprop - None, # ub_overlap_rs_fprop - None, # ub_overlap_ag_dgrad - None, # ub_overlap_rs_dgrad - None, # ub_bulk_dgrad - None, # ub_bulk_wgrad - None, # ub_name - None, # fsdp_group - None, # debug - None, # module - None, # skip_fp8_weight_update - None, # symmetric_ar_type - ) - - -class LayerNormLinear(TransformerEngineBaseModule): - r""" - Applies layer normalization followed by linear transformation to the incoming data. - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - bias : bool, default = `True` - if set to `False`, the layer will not learn an additive bias. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - init_method : Callable, default = `None` - used for initializing weights in the following way: `init_method(weight)`. - When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module is - taken post layernorm. - return_layernorm_output_gathered : bool, default = `False` - if set to `True`, output of layernorm is returned after the all - gather operation. Ignored if return_layernorm_output is False. - Example use case: with sequence parallel, input to residual connection - for transformer module (e.g. LoRA) will need to be gathered. - Returning layernorm output gathered will prevent a redundant gather. 
- parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None - Configuration for splitting the weight and bias tensors along dim 0 into - multiple PyTorch parameters. If a list or tuple of strings is provided, - they are used to make the names of equally-sized parameters. If a dict - (preferably an OrderedDict) is provided, the keys are used as names and - values as split sizes along dim 0. The resulting parameters will have - names that end in `_weight` or `_bias`, so trailing underscores are - stripped from any provided names. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will be allocated. It is the user's - responsibility to ensure all parameters are moved to the GPU before running the - forward pass. - name: str, default = `None` - name of the module, currently used for debugging purposes. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - tp_size : int, default = 1 - used as TP (tensor parallel) world size when TP groups are not formed during - initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the - forward pass to supply the tensor parallel group needed for tensor and sequence - parallel collectives. - parallel_mode : {None, 'column', 'row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - return_bias : bool, default = `False` - when set to `True`, this module will not apply the additive bias itself, but - instead return the bias value during the forward pass together with the - output of the linear transformation :math:`y = xA^T`. This is useful when - the bias addition can be fused to subsequent operations. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. - delay_wgrad_compute : bool, default = `False` - Whether or not to delay weight gradient computation. If set to `True`, - it's the user's responsibility to call `module.backward_dw` to compute - weight gradients. - symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None - Type of symmetric memory all-reduce to use during the forward pass. - This can help in latency bound communication situations. - Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce - is used. 
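The parameters documented above are enough to sketch typical usage of the symmetric all-reduce path. A minimal, illustrative example follows; it assumes a torchrun-style launch, PyTorch >= 2.7 with symmetric-memory support, and that the process-group setup shown here (names such as `tp_group`, `LOCAL_RANK`, `world_size`) is provided by the caller rather than by this patch:

    import os
    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    # Illustrative setup (not part of the patch): one tensor-parallel group over all ranks.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()
    tp_group = dist.new_group(list(range(world_size)))

    layer = te.LayerNormLinear(
        8192, 8192,
        normalization="RMSNorm",
        parallel_mode="row",                      # row-parallel forward ends in an all-reduce
        tp_group=tp_group,
        tp_size=world_size,
        symmetric_ar_type="multimem_all_reduce",  # or "one_shot" / "two_shot"; None keeps the standard all-reduce
        params_dtype=torch.bfloat16,
    )
    # Row-parallel layers expect the input feature dimension sharded by tp_size.
    inp = torch.randn(2048, 8192 // world_size, device="cuda", dtype=torch.bfloat16)
    out = layer(inp)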
- """ - - def __init__( - self, - in_features: int, - out_features: int, - eps: float = 1e-5, - sequence_parallel: bool = False, - fuse_wgrad_accumulation: bool = False, - tp_group: Optional[dist_group_type] = None, - tp_size: int = 1, - get_rng_state_tracker: Optional[Callable] = None, - init_method: Optional[Callable] = None, - bias: bool = True, - normalization: str = "LayerNorm", - return_bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - parallel_mode: Optional[str] = None, - return_layernorm_output: bool = False, - return_layernorm_output_gathered: bool = False, - parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, - zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", - ub_overlap_ag: bool = False, - ub_overlap_rs: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_name: Optional[str] = None, - delay_wgrad_compute: bool = False, - symmetric_ar_type: Optional[str] = None, - name: str = None, - ) -> None: - super().__init__() - - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.in_features = in_features - self.out_features = out_features - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - self.use_bias = bias - self.return_bias = return_bias - self.apply_bias = self.use_bias and not return_bias - self.return_layernorm_output = return_layernorm_output - self.return_layernorm_output_gathered = ( - return_layernorm_output_gathered if return_layernorm_output else False - ) - self.zero_centered_gamma = zero_centered_gamma - self.symmetric_ar_type = symmetric_ar_type - - self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) - self.name = name - - if tp_group is None: - self.tp_size = tp_size - if tp_size == 1: - self.set_tensor_parallel_group(tp_group) - else: - self.tp_size = get_distributed_world_size(tp_group) - self.set_tensor_parallel_group(tp_group) - self.set_nccl_overlap_warning_if_tp() - - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - if init_method is None: - init_method = get_default_init_method() - - self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - - # Column-parallel overlaps - self.ub_overlap_ag_fprop = ( - ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "column" - ) - self.ub_overlap_rs_dgrad = ( - ub_overlap_rs_dgrad and self.sequence_parallel and self.parallel_mode == "column" - ) - self.ub_bulk_wgrad = ( - ub_bulk_wgrad - and self.sequence_parallel - and self.parallel_mode == "column" - and not self.ub_overlap_rs_dgrad - ) - self.ub_bulk_dgrad = ( - ub_bulk_dgrad - and self.sequence_parallel - and self.parallel_mode == "column" - and not self.ub_overlap_rs_dgrad - ) - - # Row-parallel overlaps - self.ub_overlap_rs_fprop = ( - ub_overlap_rs and self.sequence_parallel and self.parallel_mode == "row" - ) - self.ub_overlap_ag_dgrad = ( - ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "row" - ) - if any( - [ - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - 
self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - ] - ): - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name - - if self.symmetric_ar_type is not None: - assert torch_version() >= ( - 2, - 7, - 0, - ), "Torch version must be at least 2.7 to use symmetric memory" - - self.eps = eps - layer_norm_weight = torch.nn.Parameter( - torch.empty(self.in_features, device=device, dtype=params_dtype) - ) - self.register_parameter( - "layer_norm_weight", - layer_norm_weight, - init_fn=init_method_constant(float(not self.zero_centered_gamma)), - ) - if self.normalization != "RMSNorm": - layer_norm_bias = torch.nn.Parameter( - torch.empty(self.in_features, device=device, dtype=params_dtype) - ) - self.register_parameter( - "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) - ) - else: - self.layer_norm_bias = None - - # Initialize params in FP8 - with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() - - # Contiguous buffers for params - weight_tensor = torch.empty( - self.out_features, - self.in_features, - device=device, - dtype=params_dtype, - ) - bias_tensor = None - if self.use_bias: - bias_tensor = torch.empty( - self.out_features, - device=device, - dtype=params_dtype, - ) - - # Configure parameter splits - self.weight_names = [] - self.bias_names = [] - self.parameter_split_sizes = [] - if parameters_split is None: - # Split into a single parameter by default - self.weight_names = ["weight"] - self.bias_names = ["bias"] - self.parameter_split_sizes = [out_features] - elif not parameters_split: - raise ValueError("Cannot split weight buffer into 0 parameters") - elif isinstance(parameters_split, dict): - # Split parameters with provided sizes - for name, split_size in parameters_split.items(): - self.weight_names.append(f"{name.rstrip('_')}_weight") - self.bias_names.append(f"{name.rstrip('_')}_bias") - self.parameter_split_sizes.append(split_size) - elif all(isinstance(name, str) for name in parameters_split): - # Split parameters evenly - split_size = out_features // len(parameters_split) - for name in parameters_split: - self.weight_names.append(f"{name.rstrip('_')}_weight") - self.bias_names.append(f"{name.rstrip('_')}_bias") - self.parameter_split_sizes.append(split_size) - else: - raise TypeError("Invalid configuration for parameters split") - - # Make sure parameter splits are valid - if sum(self.parameter_split_sizes) != out_features: - raise ValueError( - f"Trying to split weight buffer ({out_features=}) " - f"with split sizes {self.parameter_split_sizes}" - ) - - # Adjust parameter splits for tensor-parallel distribution - if self.parallel_mode == "column": - for i, size in enumerate(self.parameter_split_sizes): - if size % self.tp_size != 0: - raise RuntimeError( - f"Attempting to distribute a parameter with out_features={size} " - f"between {self.tp_size} tensor-parallel processes" - ) - self.parameter_split_sizes[i] = size // self.tp_size - - # Construct weight parameters - # Note: Register weights together so that they are adjacent to - # each other in LayerNormLinear.parameters(). This makes it - # more likely that they will stay contiguous if the weights - # are manipulated externally, e.g. by FSDP. 
- offset = 0 - for i, split_size in enumerate(self.parameter_split_sizes): - split_start = offset - offset += split_size - split_end = offset - - # Check if parameters are subviews of buffers - is_subview = (split_start, split_end) != (0, self.out_features) - if is_subview and with_fp8_params: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - - # Construct weight parameter - self.register_parameter( - self.weight_names[i], - torch.nn.Parameter(weight_tensor[split_start:split_end]), - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - ) - - # Construct bias parameters if needed - if self.use_bias: - offset = 0 - for i, split_size in enumerate(self.parameter_split_sizes): - split_start = offset - offset += split_size - split_end = offset - self.register_parameter( - self.bias_names[i], - torch.nn.Parameter(bias_tensor[split_start:split_end]), - init_fn=init_method_constant(0.0), - ) - else: - for name in self.bias_names: - bias = torch.Tensor().to(dtype=params_dtype, device=device) - setattr(self, name, bias) - - if with_fp8_params: - self.init_fp8_metadata() - - self.reset_parameters(defer_init=device == "meta") - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.apply_bias: - self.gemm_bias_unfused_add = True - else: - self.gemm_bias_unfused_add = False - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - if self.wgrad_store.delay_wgrad_compute(): - for name, param in self.named_parameters(): - if name in self.weight_names or name in self.bias_names: - param.skip_backward_post_hook = True - - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: - """Init scales and amaxes for fwd | bwd.""" - super().set_meta_tensor(fwd, recipe) - - # customize quantizers based on each recipe & layer configs - recipe = FP8GlobalStateManager.get_fp8_recipe() - if recipe.float8_current_scaling(): - self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.float8_block_scaling(): - self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) - # elif other recipes (mxfp8, etc) - - def reset_layer_norm_parameters(self) -> None: - """Init LN params""" - warnings.warn( - "This method will be deprecated in an upcoming release. 
" - "Update your code to use LayerNormLinear.reset_parameters() instead.", - DeprecationWarning, - stacklevel=2, - ) - if not self.zero_centered_gamma: - init.ones_(self.layer_norm_weight) - else: - init.zeros_(self.layer_norm_weight) - if self.layer_norm_bias is not None: - init.zeros_(self.layer_norm_bias) - - def reset_parameters(self, defer_init=False): - super().reset_parameters(defer_init=defer_init) - - if not defer_init: - # Set parallelism attributes for layer norm parameters - setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) - if self.normalization != "RMSNorm": - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) - - # Set parallelism attributes for linear weights - for weight in self.weight_names: - set_tensor_model_parallel_attributes( - tensor=getattr(self, weight), - is_parallel=True, - dim=1 if self.parallel_mode == "row" else 0, - stride=1, - ) - - # Set parallelism attributes for linear biases - if self.use_bias: - for bias in self.bias_names: - if self.parallel_mode == "row": - setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) - elif self.parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - - @no_torch_dynamo() - def forward( - self, - inp: torch.Tensor, - is_first_microbatch: Optional[bool] = None, - fp8_output: Optional[bool] = False, - fp8_grad: Optional[bool] = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - - Parameters - ---------- - inp : torch.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. 
- When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - * it also allows skipping gradient accumulation during the - first microbatch (since it is the first gradient being - produced) - """ - if is_in_onnx_export_mode(): - return self.onnx_forward(inp, fp8_output) - - debug = self.is_debug_iter() - - if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() - else: - skip_fp8_weight_update = None - if skip_fp8_weight_update is not None: - is_first_microbatch = False - - if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): - fp8_output = True - if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): - fp8_grad = True - - with torch.cuda.device( - getattr(self, list(self.named_parameters())[0][0]).device - ), self.prepare_forward( - inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: - - # Get concatenated weight and bias tensors - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if torch.is_grad_enabled(): - fwd_fn = _LayerNormLinear.apply - args = [] - else: - fwd_fn = _LayerNormLinear.forward - args = [None] - args += ( - inp, - self.layer_norm_weight, - self.layer_norm_bias, - weight_tensor, - bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - torch.is_grad_enabled(), - self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_name, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - debug, - ) - out = fwd_fn(*args) - - if self.return_layernorm_output: - out, ln_out = out - - if self.gemm_bias_unfused_add: - out = out + cast_if_needed(bias_tensor, self.activation_dtype) - - if self.return_bias: - if self.return_layernorm_output: - return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out - return out, cast_if_needed(bias_tensor, self.activation_dtype) - if self.return_layernorm_output: - return out, ln_out - return out - - def _get_quantizers(self, fp8_output, fp8_grad): - if not self.fp8: - return [None] * 6 - grad_input_quantizer = None - grad_weight_quantizer = None - grad_output_quantizer = None - output_quantizer = None - input_quantizer = 
self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = True - (weight_quantizer,) = self._get_weight_quantizers() - if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - if torch.is_grad_enabled(): - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - grad_output_quantizer.internal = True - if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] - - return ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) - - def _get_debug_quantizers(self, fp8_output, fp8_grad): - original_quantizers = self._get_quantizers(fp8_output, fp8_grad) - assert TEDebugState.debug_enabled - from ...debug.pytorch.debug_quantization import DebugQuantizer - - names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] - return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) - for name, q in zip(names, original_quantizers) - ) - - def _get_weight_and_bias_tensors(self): - # Get concatenated weight and bias tensors - unfused_weights = self._get_weight_tensors() - - weight_tensor = noop_cat(unfused_weights) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused - return weight_tensor, bias_tensor - - def onnx_forward( - self, - inp: torch.Tensor, - fp8_output: bool, - ) -> torch.Tensor: - """ - ONNX-compatible version of the forward function that provides numerical equivalence - while only using operations that have defined ONNX symbolic translations. - This simplified implementation is designed specifically for inference scenarios. 
- """ - from ..export import onnx_layernorm, onnx_gemm - - assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export" - assert_warmed_up(self) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - *_, - ) = self._get_quantizers(fp8_output, fp8_grad=False) - inp_dtype = inp.dtype - - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - ln_out, ln_out_return = onnx_layernorm( - inp, - self.layer_norm_weight, - self.layer_norm_bias, - self.eps, - self.normalization, - self.zero_centered_gamma, - inp_dtype, - self.return_layernorm_output, - input_quantizer, - ) - - if weight_quantizer is not None: - weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor) - weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized) - weight_tensor = weight_tensor.to(inp_dtype) - - if bias_tensor is not None: - bias_tensor = bias_tensor.to(inp_dtype) - - output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None) - - if output_quantizer is not None: - raise NotImplementedError("ONNX export of quantized output is not supported") - if self.return_layernorm_output and self.return_bias: - return output, bias_tensor.to(inp_dtype), ln_out_return - if self.return_layernorm_output: - return output, ln_out_return - if self.return_bias: - return output, bias_tensor.to(inp_dtype) - return output - - def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + layernorm_linear.""" - assert ( - recipe.float8_current_scaling() - ), "current scaling recipe quantizer customization here" - if fwd: - # set configs about amax epsilon and power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon - # also set weight quantizer with same amax_epsilon & power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT - ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT - ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "column": - # set input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + layernorm_linear.""" - assert recipe.nvfp4(), 
"Incorrect recipe." - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - # set input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: - """Get the weight tensors of the module.""" - unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, QuantizedTensor) for w in unfused_weights): - if self.fp8: - if len(unfused_weights) != 1: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - else: - warnings.warn( - "You are using quantized weights without quantized compute. " - "Please make sure this is intentional." - ) - unfused_weights = [w.dequantize() for w in unfused_weights] - return unfused_weights - - def _get_weight_quantizers(self) -> List[Quantizer]: - """Get the weight quantizers of the module.""" - if not self.fp8 and not self.fp8_calibration: - return [None] - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - weight_quantizer.internal = True - return [weight_quantizer] - - def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on blockwise scaling recipe + layernorm_linear.""" - assert ( - recipe.float8_block_scaling() - ), "blockwise scaling recipe quantizer customization here" - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].all_gather_usage = True diff --git a/transformer_engine/pytorch/module/linear.py.orig b/transformer_engine/pytorch/module/linear.py.orig deleted file mode 100644 index cf7f58947b..0000000000 --- a/transformer_engine/pytorch/module/linear.py.orig +++ /dev/null @@ -1,1710 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Linear API""" -from typing import Callable, Dict, Optional, Tuple, Union, List -from functools import reduce -from operator import mul as multiply_op -import warnings - -import torch - -import transformer_engine_torch as tex - -from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version - -from .base import ( - fill_userbuffers_buffer_for_all_gather, - get_dummy_wgrad, - get_ub, - get_workspace, - TransformerEngineBaseModule, - _2X_ACC_FPROP, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, -) -from ._common import noop_cat, WeightGradStore, get_module_quantizers -from ..fp8 import FP8GlobalStateManager -from ..utils import ( - cast_if_needed, - clear_tensor_data, - divide, - init_method_constant, - requires_grad, - needs_quantized_gemm, - assert_dim_for_fp8_exec, - assert_dim_for_all_gather, - nvtx_range_pop, - nvtx_range_push, -) -from ..distributed import ( - set_tensor_model_parallel_attributes, - get_distributed_world_size, - allreduce, - symmetric_all_reduce, - reduce_scatter_along_first_dim, - gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, - _fsdp_scatter_tensors, - _fsdp_gather_tensors, -) -from ..cpp_extensions import ( - general_gemm, -) -from ..constants import GemmParallelModes, dist_group_type -from ..jit import no_torch_dynamo -from ..graph import is_graph_capturing -from ..tensor.quantized_tensor import ( - QuantizedTensor, - QuantizedTensorBase, - Quantizer, - prepare_for_saving, - restore_from_saved, -) -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer -from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.utils import is_experimental -from ..export import is_in_onnx_export_mode, assert_warmed_up -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...debug.pytorch.debug_state import TEDebugState - -__all__ = ["Linear"] - - -class _Linear(torch.autograd.Function): - """Linear semi-top level module - Calls custom cuda extensions. 
- """ - - @staticmethod - def forward( - ctx, - weight: torch.Tensor, - inp: torch.Tensor, - bias: Optional[torch.Tensor], - is_first_microbatch: Union[bool, None], - fp8: bool, - fp8_calibration: bool, - wgrad_store: WeightGradStore, - input_quantizer: Optional[Quantizer], - weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], - grad_weight_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], - fuse_wgrad_accumulation: bool, - cpu_offloading: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - sequence_parallel: bool, - tensor_parallel: bool, - activation_dtype: torch.dtype, - parallel_mode: Union[str, None], - is_grad_enabled: bool, - ub_overlap_rs_fprop: bool, - ub_overlap_ag_dgrad: bool, - ub_overlap_ag_fprop: bool, - ub_overlap_rs_dgrad: bool, - ub_bulk_dgrad: bool, - ub_bulk_wgrad: bool, - ub_name: str, - fp8_output: bool, # pylint: disable=unused-argument - fsdp_group: Union[dist_group_type, None], - module: torch.nn.Module, - skip_fp8_weight_update: bool, - symmetric_ar_type: str, - save_original_input: bool = False, - debug: Optional[bool] = False, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # NVTX label for profiling - nvtx_label = "transformer_engine._Linear.forward" - if ub_name is not None: - nvtx_label = f"{nvtx_label}.{ub_name}" - - # Make sure input dimensions are compatible - out_features, in_features = weight.shape - assert inp.shape[-1] == in_features, "GEMM not possible" - - # Configure tensor-parallel communication - tp_world_size = get_distributed_world_size(tp_group) - backward_needs_input = is_grad_enabled and weight.requires_grad - with_input_all_gather_nccl = ( - parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop - ) - - # Configure Userbuffers communication (comm+GEMM overlap) - if debug: # turn off userbuffers in debug mode - ub_overlap_rs_fprop = False - ub_overlap_ag_fprop = False - ub_overlap_rs_dgrad = False - ub_bulk_wgrad = False - ub_bulk_dgrad = False - ub_obj = None - ub_type = None - if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop", fp8) - ub_type = tex.CommOverlapType.RS - elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop", fp8) - ub_type = tex.CommOverlapType.AG - - # experimental recipe check - experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer) - - # ------------------------------------------------------ - # Prepare input tensor - # Note: Cast to expected dtype and perform tensor-parallel communication - # ------------------------------------------------------ - nvtx_range_push(f"{nvtx_label}.input_cast_comm") - inputmat = inp # Input tensor to save for backward (maybe sharded) - inputmat_total = None # Input tensor to pass to GEMM (gathered) - own_quantized_input = False - if fp8: - assert_dim_for_fp8_exec(inputmat, weight) - assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) - if save_original_input: - assert not isinstance( - input_quantizer, Float8Quantizer - ), "DelayedScaling recipe is not supported with save_original_input" - - if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor - - # Cast local input tensor if needed - if fp8 or debug: - if input_quantizer is None: - raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase) and not experimental: - own_quantized_input = True - input_quantizer.set_usage(rowwise=True, 
columnwise=backward_needs_input) - if isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): - # All-gather is not supported with FP8 column-wise data - input_quantizer.set_usage(columnwise=False) - if save_original_input: - # No need for column-wise data since this - # tensor will not be cached for backward pass - input_quantizer.set_usage(columnwise=False) - own_quantized_input = False - inputmat = input_quantizer(inputmat) - else: - inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP - - # Initialize gathered input tensor - quantizer = None - if fp8 or debug: - quantizer = input_quantizer - quantizer.set_usage(rowwise=True, columnwise=False) - if with_input_all_gather_nccl: # Perform NCCL all-gather - inputmat_total, _ = gather_along_first_dim( - inputmat, - tp_group, - quantizer=quantizer, - ) - elif ub_overlap_ag_fprop: # Initialize Userbuffers all-gather - inputmat_total, _ = fill_userbuffers_buffer_for_all_gather( - ub_obj, - inputmat, - quantizer, - tp_group, - ) - - else: # Do not all-gather input tensor - if fp8 or debug: - if isinstance(inputmat, QuantizedTensorBase): - inputmat.update_usage(rowwise_usage=True) - else: - if input_quantizer is None: - raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input - ) - inputmat = input_quantizer(inputmat) - own_quantized_input = True - else: - inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP - inputmat_total = inputmat - nvtx_range_pop(f"{nvtx_label}.input_cast_comm") - # ------------------------------------------------------ - # Input tensor is ready for GEMM... - # ------------------------------------------------------ - - # ------------------------------------------------------ - # Prepare weight tensor - # ------------------------------------------------------ - weightmat = weight - if fp8 or debug: - # Configure quantizer - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: - columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - - # Get quantized weight - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - workspace_dtype=activation_dtype, - ) - weightmat.update_usage(rowwise_usage=True) - - else: - weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP - # ------------------------------------------------------ - # Weight tensor is ready for GEMM... 
- # ------------------------------------------------------ - - # Cast bias to expected dtype - bias_dtype = activation_dtype - if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: - # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16 - bias_dtype = torch.bfloat16 - bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias - - # Calibrate quantizers if needed - if not fp8 and fp8_calibration: - if input_quantizer is not None: - input_quantizer.calibrate(inputmat_total) - if weight_quantizer is not None: - weight_quantizer.calibrate(weight) - - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_FPROP - if fp8: - recipe = FP8GlobalStateManager.get_fp8_recipe() - if hasattr(recipe, "fp8_gemm_fprop"): - use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - - # Configure output quantizer - if output_quantizer is not None: - output_quantizer.set_usage(rowwise=True, columnwise=False) - - # Output buffer for Userbuffers reduce-scatter - reduce_scatter_out = None - if ub_overlap_rs_fprop: - out_shape = list(inp.shape) - out_shape[0] //= tp_world_size - out_shape[-1] = out_features - reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) - - # ------------------------------------------------------ - # Forward GEMM - # Note: y = x * w^T - # ------------------------------------------------------ - nvtx_range_push(f"{nvtx_label}.gemm") - gemm_out, *_, reduce_scatter_out = general_gemm( - weightmat, - inputmat_total, - get_workspace(), - quantization_params=output_quantizer, - out_dtype=activation_dtype, - bias=bias, - use_split_accumulator=use_split_accumulator, - ub=ub_obj, - ub_type=ub_type, - extra_output=reduce_scatter_out, - ) - nvtx_range_pop(f"{nvtx_label}.gemm") - # ------------------------------------------------------ - # Finished forward GEMM... - # ------------------------------------------------------ - - # Deallocate GEMM input tensor if no longer needed - # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically - # deallocated by GC. Manually deallocating is a temporary hack. - if with_input_all_gather_nccl: - clear_tensor_data(inputmat_total) - inputmat_total = None - - # ------------------------------------------------------ - # Prepare output tensor - # Note: Perform tensor-parallel communication - # ------------------------------------------------------ - out = None - if ub_overlap_rs_fprop: - out = reduce_scatter_out - elif parallel_mode == "row" and tp_size > 1: - nvtx_range_push(f"{nvtx_label}.row_parallel_comm") - out = gemm_out - if sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif tensor_parallel: - if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) - else: - out, _ = allreduce(out, tp_group) - nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") - else: - out = gemm_out - # ------------------------------------------------------ - # Output tensor is ready to return... 
- # ------------------------------------------------------ - - # ------------------------------------------------------ - # Cache state for backward pass - # ------------------------------------------------------ - - if is_grad_enabled: - if save_original_input: - inputmat = inp - - ctx.weight_quantizer = weight_quantizer - - ctx.backward_input_needs_gather = ( - weight.requires_grad and parallel_mode == "column" and sequence_parallel - ) - - # Discard unneeded data in input tensor - if ( - backward_needs_input - and own_quantized_input - and isinstance(inputmat, QuantizedTensorBase) - ): - if ( - ctx.backward_input_needs_gather - and weight_quantizer.supports_only_rowwise_all_gather() - ): - # All-gather is not supported with FP8 column-wise data - inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) - else: - # Discard row-wise data since it is not needed in backward pass - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) - - # Cached input tensor - saved_inputmat = None - if backward_needs_input: - saved_inputmat = inputmat - - # Weight with column-wise usage is needed for dgrad GEMM. - if inp.requires_grad: - if isinstance(weightmat, QuantizedTensorBase): - weightmat.update_usage(columnwise_usage=True) - - if cpu_offloading and saved_inputmat is not None: - mark_activation_offload(saved_inputmat) - - # Scatter intermediate/activation tensors saved for the backward pass - # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights - nvtx_range_push(f"{nvtx_label}.fsdp_scatter") - ctx.fsdp_group = fsdp_group - ctx.fsdp_shapes = _fsdp_scatter_tensors( - fsdp_group, - saved_inputmat, - weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None, - ) - nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - - if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. 
Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_object = weight - - # TODO(ksivamani): Check memory usage - tensors_to_save, tensor_objects = prepare_for_saving( - saved_inputmat, - weightmat, - weight, - bias, - ) - ctx.save_for_backward(*tensors_to_save) - ctx.tensor_objects = tensor_objects - - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.input_quantizer = input_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.grad_weight_quantizer = grad_weight_quantizer - ctx.grad_output_quantizer = grad_output_quantizer - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - if fuse_wgrad_accumulation and weight.requires_grad: - # This check is needed to ensure that main_grad is not created - # during the forward pass when using MCore FSDP as it creates - # the main_grad buffer lazily before backprop - if hasattr(weight, "__fsdp_param__"): - # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_func = weight.get_main_grad - else: - ctx.main_grad_func = lambda: weight.main_grad - - ctx.debug = debug - ctx.experimental = experimental - ctx.cpu_offloading = cpu_offloading - ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = bias is not None - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_bulk_wgrad = ub_bulk_wgrad - ctx.ub_name = ub_name - ctx.tp_size = tp_size - ctx.requires_dgrad = inp.requires_grad - ctx.requires_wgrad = weight.requires_grad - ctx.reduce_and_update_bwd_fp8_tensors = False - - ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE - ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() - if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module - ctx.wgrad_store = wgrad_store - - # ------------------------------------------------------ - # Cached state for backward pass is ready... - # ------------------------------------------------------ - - return out - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - - # NVTX label for profiling - nvtx_label = "transformer_engine._Linear.backward" - if ctx.ub_name is not None: - nvtx_label = f"{nvtx_label}.{ctx.ub_name}" - - with torch.cuda.nvtx.range("_Linear_backward"): - saved_tensors = ctx.saved_tensors - inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking - restore_from_saved(ctx.tensor_objects, saved_tensors) - ) - - # Delete the references to tensor objects once they've been consumed - # by the `restore_from_saved` method to construct back the actual tensors. 
- ctx.tensor_objects = None - - # Since main_grad can be modified inplace, it should not be a part of saved_tensors - main_grad = ( - ctx.main_grad_func() - if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad - else None - ) - - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - weight.main_grad = main_grad - - # Gather intermediate/activation tensors if needed - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already - # shards/unshards the base weights so we don't do it ourselves - nvtx_range_push(f"{nvtx_label}.fsdp_gather") - _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, - inputmat, - weight_fp8, - ) - nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - - # Configure Userbuffers communication (comm+GEMM overlap) - ctx.ub_obj_gradout = None - ub_obj_dgrad = None - ub_obj_wgrad = None - ub_type_dgrad = None - ub_type_wgrad = None - dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] - if ctx.ub_overlap_ag: - # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout - ub_type_dgrad = tex.CommOverlapType.AG - elif ctx.ub_overlap_rs_dgrad: - # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout - ub_type_dgrad = tex.CommOverlapType.RS - else: - if ctx.ub_bulk_dgrad: - # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout - ub_type_dgrad = tex.CommOverlapType.AG - if ctx.ub_bulk_wgrad: - # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) - ub_type_wgrad = tex.CommOverlapType.RS - - # -------------------------------------------------- - # Prepare grad output tensor - # Note: Cast to expected dtype and perform tensor-parallel communication - # -------------------------------------------------- - - # Unmodified grad output tensor - grad_output_arg = grad_output - - # Configure quantizer for grad output tensor - # Note: dgrad GEMM requires row-wise usage, wgrad GEMM - # requires column-wise usage - if ctx.grad_output_quantizer is not None: - quantizer = ctx.grad_output_quantizer - quantizer.set_usage(rowwise=True, columnwise=True) - if ctx.ub_overlap_ag: - # Userbuffers only supports communication for one - # tensor usage at a time. Configure quantizer with - # usage for only dgrad GEMM. - quantizer.set_usage(columnwise=False) - - # Adjust the quantization direction approach depending - # on whether wgrad calculations will be performed. - # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization - # results in `Assertion failed: output_tensor->has_data(). 
Quantizing in only the columnwise direction not supported yet!` - # NOTE: For `ctx.bias is True`, selected quantize kernel errors with - # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` - if ( - not ctx.use_bias - and not ctx.requires_wgrad - and ctx.grad_output_quantizer is not None - ): - ctx.grad_output_quantizer.set_usage(columnwise=False) - - # Prepare grad output tensor - nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") - ( - grad_output, - grad_bias, - ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, - grad_output, - ctx.parallel_mode == "row", - ctx.grad_output_quantizer, - ) - nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") - - # -------------------------------------------------- - # Grad output tensor is ready for computing grad input... - # -------------------------------------------------- - - # -------------------------------------------------- - # Prepare input tensor - # Note: Input tensor is needed for wgrad GEMM. - # Tensor-parallel communication is overlapped with dgrad - # GEMM. - # -------------------------------------------------- - inputmat_total = None - inputmat_total_work = None - if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: - if isinstance(inputmat, QuantizedTensorBase): - # Input tensor is already quantized - pass - elif ctx.debug or ctx.experimental: - # Debug quantizer will be applied immediately before wgrad GEMM - pass - else: - # Quantize input tensor - quantizer = ctx.input_quantizer - if quantizer.supports_only_rowwise_all_gather(): - # All-gather is not supported with FP8 column-wise data - quantizer.set_usage( - rowwise=True, - columnwise=not ctx.backward_input_needs_gather, - ) - else: - quantizer.set_usage(rowwise=False, columnwise=True) - inputmat = quantizer(inputmat) - else: - if isinstance(inputmat, QuantizedTensorBase): - inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) - else: - inputmat = cast_if_needed(inputmat, ctx.activation_dtype) - if ctx.backward_input_needs_gather: - quantizer = None - if ctx.fp8 or ctx.debug: - quantizer = ctx.input_quantizer - if quantizer.supports_only_rowwise_all_gather(): - # If data is in FP8, we compute FP8 transposes manually - quantizer.set_usage(rowwise=True, columnwise=False) - else: - # wgrad GEMM requires input with column-wise usage - quantizer.set_usage(rowwise=False, columnwise=True) - if ctx.ub_bulk_dgrad: - inputmat_total, _ = fill_userbuffers_buffer_for_all_gather( - ub_obj_dgrad, - inputmat, - quantizer, - ctx.tp_group, - ) - else: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") - inputmat_total, inputmat_total_work = gather_along_first_dim( - inputmat, - ctx.tp_group, - async_op=True, - quantizer=quantizer, - ) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") - else: - inputmat_total = inputmat - # -------------------------------------------------- - # Input tensor is ready for computing grad weight... 
- # -------------------------------------------------- - - # -------------------------------------------------- - # Compute grad input tensor - # -------------------------------------------------- - - dgrad = None - dgrad_work = None - if ctx.requires_dgrad: - - # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): - grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase): - weight_fp8.update_usage(columnwise_usage=True) - - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator - - # Update grad input quantizer - if ctx.grad_input_quantizer is not None: - ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - - # Output buffers for Userbuffers reduce-scatter - gemm_out = None - reduce_scatter_out = None - if ctx.ub_overlap_rs_dgrad: - reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device - ) - elif ctx.ub_bulk_wgrad: - gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) - - # dgrad GEMM - # Note: dx = dy * w - - nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, - grad_output, - get_workspace(), - layout="NN", - grad=True, - quantization_params=ctx.grad_input_quantizer, - out=gemm_out, - out_dtype=ctx.activation_dtype, - use_split_accumulator=use_split_accumulator, - ub=ub_obj_dgrad, - ub_type=ub_type_dgrad, - extra_output=reduce_scatter_out, - bulk_overlap=ctx.ub_bulk_dgrad, - ) - nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") - - # Prepare grad input tensor - # Note: Perform tensor-parallel communication - if ctx.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out - elif ctx.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif ctx.parallel_mode == "column" and ctx.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") - dgrad = gemm_out - if ctx.sequence_parallel: - dgrad, dgrad_work = reduce_scatter_along_first_dim( - dgrad, - ctx.tp_group, - async_op=True, - ) - else: - dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") - else: - dgrad = gemm_out - - # -------------------------------------------------- - # Grad input tensor has been computed... - # -------------------------------------------------- - - # -------------------------------------------------- - # Compute grad weight - # -------------------------------------------------- - - wgrad = None - if ctx.requires_wgrad: - - # Prepare input tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if inputmat_total_work is not None: - inputmat_total_work.wait() - inputmat_total_work = None - if ctx.fp8 or ctx.debug: - if isinstance(inputmat_total, QuantizedTensorBase): - inputmat_total.update_usage(columnwise_usage=True) - else: - ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - inputmat_total = ctx.input_quantizer(inputmat_total) - - # Prepare grad output tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support pipelined overlapping grad output - # all-gather with wgrad GEMM. 
Also, we can't - # convert row-scaled MXFP8 to column-scaled, so we - # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around by explicitly - # overlapping the AG operation with the dgrad GEMM. - - # Get the communication stream from the dgrad GEMM to use for the AG - dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() - - # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) - - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - - # We use the send stream to copy into the userbuffers. - # This is the same stream that we will use to access the data in the AG, - # so we dont need to add any syncs yet. - with torch.cuda.stream(dgrad_send_stream): - grad_output, _ = fill_userbuffers_buffer_for_all_gather( - ub_obj_overlap_wgrad, - grad_output_arg, - ctx.grad_output_quantizer, - ctx.tp_group, - ) - - # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm - tex.bulk_overlap_ag_with_external_gemm( - ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream - ) - - if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): - grad_output.update_usage(columnwise_usage=True) - else: - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output) - - # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator - - # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - # Output buffer for overlapping FP8 grad input - # reduce-scatter with wgrad GEMM - reduce_scatter_out = None - if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): - reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device - ) - - # Arguments to include in wgrad GEMM closure - wgrad_gemm_kwargs = { - "workspace": get_workspace(), - "out_dtype": ( - main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype - ), - "quantization_params": ctx.grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, - "layout": "NT", - "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), - "use_split_accumulator": use_split_accumulator, - "grad": True, - "ub": ub_obj_wgrad, - "ub_type": ub_type_wgrad, - "extra_output": reduce_scatter_out, - "bulk_overlap": ctx.ub_bulk_wgrad, - } - - def wgrad_gemm( - x: torch.Tensor, - dy: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform wgrad GEMM: dw = dy^T * x - - May be fused with bgrad computation. - - May be called outside of this function to enable - some advanced communication/compute overlapping. 
- - """ - nvtx_range_push(f"{nvtx_label}.wgrad_gemm") - dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) - nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") - return dw, db - - # Choose whether to call wgrad GEMM now or delay - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - if ( - wgrad_gemm_kwargs["ub"] is not None - or wgrad_gemm_kwargs["ub_type"] is not None - or wgrad_gemm_kwargs["extra_output"] is not None - or wgrad_gemm_kwargs["bulk_overlap"] - ): - raise NotImplementedError( - "Delayed weight grad computation is not supported " - "with Userbuffers (tensor-parallel communication overlapping)" - ) - ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm) - else: - - # Call wgrad GEMM now - wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output) - - # Update grad bias if needed - if grad_bias is None: - grad_bias = grad_bias_ - del grad_bias_ - - # Deallocate tensors if permitted - if ctx.owns_input: - # Input tensor is internal - clear_tensor_data(inputmat_total) - elif ctx.backward_input_needs_gather: - # Gathered input tensor is internal - clear_tensor_data(inputmat_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: - # Gathered grad output tensor is internal - clear_tensor_data(grad_output) - - # Update grad input if overlapping reduce-scatter with wgrad GEMM - if ctx.ub_bulk_wgrad: - if ub_obj_wgrad.is_fp8_ubuf(): - dgrad = reduce_scatter_out - else: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() - - # -------------------------------------------------- - # Grad weight has been computed... - # -------------------------------------------------- - - # Don't return grad bias if not needed - if not ctx.use_bias: - grad_bias = None - - # Make sure all tensor-parallel communication is finished - if inputmat_total_work is not None: - inputmat_total_work.wait() - inputmat_total_work = None - if dgrad_work is not None: - dgrad_work.wait() - dgrad_work = None - - if ctx.requires_wgrad: - # Handle custom DDP from mcore. 
- if ( - ctx.fuse_wgrad_accumulation - and weight is not None - and hasattr(weight, "grad_added_to_main_grad") - ): - weight.grad_added_to_main_grad = True - if getattr(weight, "zero_out_wgrad", False): - wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, - zero=True, - ) - else: - wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, - ) - elif ctx.fuse_wgrad_accumulation: - wgrad = None - else: - wgrad = None - - # Update FP8 scaling factors if needed - if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): - nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") - - # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): - _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) - return ( - wgrad, - dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, - grad_bias, - None, # is_first_microbatch - None, # fp8 - None, # fp8_calibration - None, # wgrad_store - None, # input_quantizer - None, # weight_quantizer - None, # output_quantizer - None, # grad_input_quantizer - None, # grad_weight_quantizer - None, # grad_output_quantizer - None, # fuse_wgrad_accumulation - None, # cpu_offloading - None, # tp_group - None, # tp_size - None, # sequence_parallel - None, # tensor_parallel - None, # activation_dtype - None, # parallel_mode - None, # is_grad_enabled - None, # ub_overlap_rs_fprop - None, # ub_overlap_ag_dgrad - None, # ub_overlap_ag_fprop - None, # ub_overlap_rs_dgrad - None, # ub_bulk_dgrad - None, # ub_bulk_wgrad - None, # ub_name - None, # fp8_output - None, # fsdp_group - None, # module - None, # skip_fp8_weight_update - None, # symmetric_ar_type - None, # save_original_input - None, # debug - ) - - -class Linear(TransformerEngineBaseModule): - """Applies a linear transformation to the incoming data :math:`y = xA^T + b` - - On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - bias : bool, default = `True` - if set to `False`, the layer will not learn an additive bias. - init_method : Callable, default = `None` - used for initializing weights in the following way: `init_method(weight)`. - When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initializing weights. - rng_tracker_name : str, default = `None` - the param passed to get_rng_state_tracker to get the specific rng tracker. - parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None - Configuration for splitting the weight and bias tensors along dim 0 into - multiple PyTorch parameters. If a list or tuple of strings is provided, - they are used to make the names of equally-sized parameters. If a dict - (preferably an OrderedDict) is provided, the keys are used as names and - values as split sizes along dim 0. The resulting parameters will have - names that end in `_weight` or `_bias`, so trailing underscores are - stripped from any provided names. - device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will be allocated. It is the user's - responsibility to ensure all parameters are moved to the GPU before running the - forward pass. 
- name: str, default = `None` - name of the module, currently used for debugging purposes. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - tp_size : int, default = 1 - used as TP (tensor parallel) world size when TP groups are not formed during - initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the - forward pass to supply the tensor parallel group needed for tensor and sequence - parallel collectives. - parallel_mode : {None, 'column', 'row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - return_bias : bool, default = `False` - when set to `True`, this module will not apply the additive bias itself, but - instead return the bias value during the forward pass together with the - output of the linear transformation :math:`y = xA^T`. This is useful when - the bias addition can be fused to subsequent operations. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. - delay_wgrad_compute : bool, default = `False` - Whether or not to delay weight gradient computation. If set to `True`, - it's the user's responsibility to call `module.backward_dw` to compute - weight gradients. - symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None - Type of symmetric memory all-reduce to use during the forward pass. - This can help in latency bound communication situations. - Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce - is used. - save_original_input : bool, default = `False` - If set to `True`, always saves the original input tensor rather than the - cast tensor. In some scenarios, the input tensor is used by multiple modules, - and saving the original input tensor may reduce the memory usage. - Cannot work with FP8 DelayedScaling recipe. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - sequence_parallel: bool = False, - fuse_wgrad_accumulation: bool = False, - tp_group: Optional[dist_group_type] = None, - tp_size: int = 1, - get_rng_state_tracker: Optional[Callable] = None, - rng_tracker_name: Optional[str] = None, - init_method: Optional[Callable] = None, - bias: bool = True, - return_bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - parallel_mode: Optional[str] = None, - parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, - device: Union[torch.device, str] = "cuda", - ub_overlap_ag: bool = False, - ub_overlap_rs: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_bulk_wgrad: bool = False, - ub_name: Optional[str] = None, - delay_wgrad_compute: bool = False, - symmetric_ar_type: Optional[str] = None, - save_original_input: bool = False, - name: Optional[str] = None, - ) -> None: - super().__init__() - - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.in_features = in_features - self.out_features = out_features - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - self.use_bias = bias - self.return_bias = return_bias - self.apply_bias = bias and not return_bias - self.get_rng_state_tracker = get_rng_state_tracker - self.rng_tracker_name = rng_tracker_name - self.symmetric_ar_type = symmetric_ar_type - self.save_original_input = save_original_input - self.name = name - - self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) - - if device == "meta": - assert parameters_split is None, "Cannot split module parameters on 'meta' device." - if tp_group is None: - self.tp_size = tp_size - if tp_size == 1: - self.set_tensor_parallel_group(tp_group) - else: - self.tp_size = get_distributed_world_size(tp_group) - self.set_tensor_parallel_group(tp_group) - self.set_nccl_overlap_warning_if_tp() - - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - - # Column parallel TP overlap options - self.ub_overlap_ag_fprop = ( - self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag - ) - self.ub_overlap_rs_dgrad = ( - self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad - ) - self.ub_bulk_dgrad = ( - self.parallel_mode == "column" - and self.sequence_parallel - and ub_bulk_dgrad - and not self.ub_overlap_rs_dgrad - ) - self.ub_bulk_wgrad = ( - self.parallel_mode == "column" - and self.sequence_parallel - and ub_bulk_wgrad - and not self.ub_overlap_rs_dgrad - ) - - # Row parallel TP overlap options - self.ub_overlap_rs_fprop = ( - self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs - ) - self.ub_overlap_ag_dgrad = ( - self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag - ) - - if any( - [ - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - ] - ): - assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." 
- self.ub_name = ub_name - - if self.symmetric_ar_type is not None: - assert torch_version() >= ( - 2, - 7, - 0, - ), "Torch version must be at least 2.7 to use symmetric memory" - - # Initialize params in FP8 - with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() - - # Contiguous buffers for params - weight_tensor = torch.empty( - self.out_features, - self.in_features, - device=device, - dtype=params_dtype, - ) - bias_tensor = None - if self.use_bias: - bias_tensor = torch.empty( - self.out_features, - device=device, - dtype=params_dtype, - ) - - # Configure parameter splits - self.weight_names = [] - self.bias_names = [] - self.parameter_split_sizes = [] - if parameters_split is None: - # Split into a single parameter by default - self.weight_names = ["weight"] - self.bias_names = ["bias"] - self.parameter_split_sizes = [out_features] - elif not parameters_split: - raise ValueError("Cannot split weight buffer into 0 parameters") - elif isinstance(parameters_split, dict): - # Split parameters with provided sizes - for name, split_size in parameters_split.items(): - self.weight_names.append(f"{name.rstrip('_')}_weight") - self.bias_names.append(f"{name.rstrip('_')}_bias") - self.parameter_split_sizes.append(split_size) - elif all(isinstance(name, str) for name in parameters_split): - # Split parameters evenly - split_size = out_features // len(parameters_split) - for name in parameters_split: - self.weight_names.append(f"{name.rstrip('_')}_weight") - self.bias_names.append(f"{name.rstrip('_')}_bias") - self.parameter_split_sizes.append(split_size) - else: - raise TypeError("Invalid configuration for parameters split") - - # Make sure parameter splits are valid - if sum(self.parameter_split_sizes) != out_features: - raise ValueError( - f"Trying to split weight buffer ({out_features=}) " - f"with split sizes {self.parameter_split_sizes}" - ) - - # Adjust parameter splits for tensor-parallel distribution - if self.parallel_mode == "column": - for i, size in enumerate(self.parameter_split_sizes): - if size % self.tp_size != 0: - raise RuntimeError( - f"Attempting to distribute a parameter with out_features={size} " - f"between {self.tp_size} tensor-parallel processes" - ) - self.parameter_split_sizes[i] = size // self.tp_size - - # Construct weight parameters - # Note: Register weights together so that they are adjacent to - # each other in Linear.parameters(). This makes it more likely - # that they will stay contiguous if the weights are - # manipulated externally, e.g. by FSDP. 
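# Each split parameter registered below is a view into the contiguous `weight_tensor`
# buffer. A split that does not span the full buffer (a subview) cannot be combined
# with FP8-initialized parameters, because a QuantizedTensor cannot be sliced into
# independent params; the loop checks for this case and raises a RuntimeError.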
- offset = 0 - for i, split_size in enumerate(self.parameter_split_sizes): - split_start = offset - offset += split_size - split_end = offset - - # Check if parameters are subviews of buffers - is_subview = (split_start, split_end) != (0, self.out_features) - if is_subview and with_fp8_params: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - - # Construct weight parameter - self.register_parameter( - self.weight_names[i], - torch.nn.Parameter(weight_tensor[split_start:split_end]), - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - ) - - # Construct bias parameters if needed - if self.use_bias: - offset = 0 - for i, split_size in enumerate(self.parameter_split_sizes): - split_start = offset - offset += split_size - split_end = offset - self.register_parameter( - self.bias_names[i], - torch.nn.Parameter(bias_tensor[split_start:split_end]), - init_fn=init_method_constant(0.0), - ) - else: - for name in self.bias_names: - bias = torch.Tensor().to(dtype=params_dtype, device=device) - setattr(self, name, bias) - - if with_fp8_params: - self.init_fp8_metadata() - - self.reset_parameters(defer_init=device == "meta") - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.apply_bias: - self.gemm_bias_unfused_add = True - else: - self.gemm_bias_unfused_add = False - - if self.wgrad_store.delay_wgrad_compute(): - for name, param in self.named_parameters(): - if name in self.weight_names or name in self.bias_names: - param.skip_backward_post_hook = True - - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: - """Init scales and amaxes for fwd | bwd.""" - super().set_meta_tensor(fwd, recipe) - - # customize quantizers based on each recipe & layer configs - recipe = FP8GlobalStateManager.get_fp8_recipe() - if recipe.float8_current_scaling(): - self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.float8_block_scaling(): - self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) - # elif for other recipes (mxfp8, etc.) - - def reset_parameters(self, defer_init=False): - super().reset_parameters(defer_init=defer_init) - - if not defer_init: - # Set parallelism attributes for linear weights - for weight in self.weight_names: - set_tensor_model_parallel_attributes( - tensor=getattr(self, weight), - is_parallel=True, - dim=1 if self.parallel_mode == "row" else 0, - stride=1, - ) - - # Set parallelism attributes for linear biases - if self.use_bias: - for bias in self.bias_names: - if self.parallel_mode == "row": - setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) - elif self.parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - - @no_torch_dynamo() - def forward( - self, - inp: torch.Tensor, - is_first_microbatch: Optional[bool] = None, - fp8_output: Optional[bool] = False, - fp8_grad: Optional[bool] = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - """ - Apply the linear transformation to the input. - - Parameters - ---------- - inp : torch.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. 
Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - * it also allows skipping gradient accumulation during the - first microbatch (since it is the first gradient being - produced) - """ - if is_in_onnx_export_mode(): - return self.onnx_forward(inp, fp8_output) - - debug = self.is_debug_iter() - - if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() - else: - skip_fp8_weight_update = None - if skip_fp8_weight_update is not None: - is_first_microbatch = False - - if self.ub_overlap_rs_fprop: - if get_ub( - self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): - fp8_output = True - if self.ub_overlap_rs_dgrad: - if get_ub( - self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() - ).is_fp8_ubuf(): - fp8_grad = True - - with torch.cuda.device( - getattr(self, list(self.named_parameters())[0][0]).device - ), self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: - - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if torch.is_grad_enabled(): - linear_fn = _Linear.apply - args = [] - else: - linear_fn = _Linear.forward - args = [None] - args += ( - weight_tensor, - inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - torch.is_grad_enabled(), - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, - fp8_output, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, - debug, - ) - out = linear_fn(*args) - if self.gemm_bias_unfused_add: - out = out + cast_if_needed(bias_tensor, self.activation_dtype) - - if self.return_bias: - return out, cast_if_needed(bias_tensor, self.activation_dtype) - return out - - def _get_quantizers(self, fp8_output, fp8_grad): - if not self.fp8: - return [None] * 6 - grad_input_quantizer = None - grad_weight_quantizer = None - grad_output_quantizer = None - output_quantizer = None - input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = True - (weight_quantizer,) = self._get_weight_quantizers() - if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - if torch.is_grad_enabled(): - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - 
grad_output_quantizer.internal = True - if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] - return ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) - - def _get_debug_quantizers(self, fp8_output, fp8_grad): - original_quantizers = self._get_quantizers(fp8_output, fp8_grad) - assert TEDebugState.debug_enabled - from ...debug.pytorch.debug_quantization import DebugQuantizer - - names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] - return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) - for name, q in zip(names, original_quantizers) - ) - - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: - """Get the weight tensors of the module.""" - unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, QuantizedTensor) for w in unfused_weights): - if self.fp8: - if len(unfused_weights) != 1: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - else: - warnings.warn( - "You are using quantized weights without quantized compute. " - "Please make sure this is intentional." - ) - unfused_weights = [w.dequantize() for w in unfused_weights] - return unfused_weights - - def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Get concatenated weight and bias tensors - unfused_weights = self._get_weight_tensors() - if any(isinstance(w, QuantizedTensor) for w in unfused_weights): - if self.fp8: - if len(unfused_weights) != 1: - raise RuntimeError( - "Splitting QuantizedTensor into multiple params is not supported" - ) - else: - warnings.warn( - "You are using quantized weights without quantized compute. " - "Please make sure this is intentional." - ) - unfused_weights = [w.dequantize() for w in unfused_weights] - - weight_tensor = noop_cat(unfused_weights) - if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) - else: - bias_tensor = None - - return weight_tensor, bias_tensor - - def onnx_forward( - self, - inp: torch.Tensor, - fp8_output: bool, - ) -> torch.Tensor: - """ - ONNX-compatible version of the forward function that provides numerical equivalence - while only using operations that have defined ONNX symbolic translations. - This simplified implementation is designed specifically for inference scenarios. - """ - from ..export import onnx_gemm - - assert_warmed_up(self) - assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export." 
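# The export path below emulates quantized numerics using only ONNX-exportable ops:
# inputs and weights are quantized and immediately dequantized (a QDQ pair), so the
# exported graph reproduces the quantization rounding while the GEMM itself runs in
# the original dtype via onnx_gemm.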
- weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - ( - input_quantizer, - weight_quantizer, - output_quantizer, - *_, - ) = self._get_quantizers(fp8_output, False) - inp_dtype = inp.dtype - - if input_quantizer is not None: - inp_q = input_quantizer.onnx_quantize(inp) - inp = input_quantizer.onnx_dequantize(inp_q) - inp = inp.to(inp_dtype) - - if weight_quantizer is not None: - weight_q = weight_quantizer.onnx_quantize(weight_tensor) - weight_tensor = weight_quantizer.onnx_dequantize(weight_q) - if bias_tensor is not None: - bias_tensor = bias_tensor.to(inp_dtype) - weight_tensor = weight_tensor.to(inp_dtype) - - if self.apply_bias: - output = onnx_gemm(weight_tensor, inp, bias_tensor) - else: - output = onnx_gemm(weight_tensor, inp, None) - - if output_quantizer is not None: - raise NotImplementedError("ONNX export of quantized output is not supported") - - if self.return_bias: - return output, bias_tensor - - return output - - def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + linear.""" - assert ( - recipe.float8_current_scaling() - ), "current scaling recipe quantizer customization here" - if fwd: - # set configs about amax epsilon and power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon - # also set weight quantizer with same amax_epsilon & power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT - ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT - ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # paralle related - if self.sequence_parallel and self.parallel_mode == "column": - # customize input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - # set grad_output_quantizer with amax epsilon and power_2_scale - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + linear.""" - assert recipe.nvfp4(), "Incorrect recipe." 
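# With sequence parallelism each rank only sees its shard of the activation (or of the
# incoming gradient), so the amax driving current-scaling and NVFP4 quantization must be
# reduced across the tensor-parallel group before a scale is derived. Conceptually this
# is a max all-reduce, e.g.
#     torch.distributed.all_reduce(amax, op=torch.distributed.ReduceOp.MAX, group=tp_group)
# which is what with_amax_reduction / amax_reduction_group request from the quantizer.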
-        if fwd:
-            if self.sequence_parallel and self.parallel_mode == "column":
-                # customize input_quantizer with amax reduction TP group
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].with_amax_reduction = True
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].amax_reduction_group = self.tp_group
-        else:
-            if self.sequence_parallel and self.parallel_mode == "row":
-                # customize grad_output_quantizer with amax reduction TP group
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].with_amax_reduction = True
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].amax_reduction_group = self.tp_group
-
-    def _get_weight_quantizers(self) -> List[Quantizer]:
-        """Get the weight quantizers of the module."""
-        if not self.fp8 and not self.fp8_calibration:
-            return [None]
-        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-        weight_quantizer.internal = True
-        return [weight_quantizer]
-
-    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
-        """Customize quantizers based on blockwise scaling recipe + linear."""
-        assert (
-            recipe.float8_block_scaling()
-        ), "blockwise scaling recipe quantizer customization here"
-
-        if fwd:
-            if self.sequence_parallel and self.parallel_mode == "column":
-                # set compact for inp tensor X
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].all_gather_usage = True
-        else:
-            if self.sequence_parallel and self.parallel_mode == "row":
-                # set compact for grad_output tensor dY
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].all_gather_usage = True
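For reference, the snippet below sketches how the row-parallel path documented above might be exercised end to end. It is illustrative only: the sizes, the use of the default WORLD group as the tensor-parallel group, the script name, and the choice of symmetric_ar_type are assumptions rather than values taken from this patch, and it is expected to be launched under torchrun (for example "torchrun --nproc-per-node=8 linear_rowparallel_example.py") so that RANK, LOCAL_RANK, and WORLD_SIZE are set.

import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes a torchrun launch with one GPU per rank.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
tp_size = dist.get_world_size()

# Row-parallel projection: each rank holds an (out_features, in_features // tp_size)
# weight shard, and the partial GEMM outputs are summed across the TP group.
proj = te.Linear(
    8192,
    8192,
    bias=False,
    parallel_mode="row",
    tp_group=dist.group.WORLD,
    tp_size=tp_size,
    # Low-latency all-reduce path; assumed available (requires torch >= 2.7).
    symmetric_ar_type="multimem_all_reduce",
)

x = torch.randn(16, 8192 // tp_size, device="cuda")
with torch.no_grad():
    y = proj(x)  # full 8192-wide output on every rank after the all-reduce
print(local_rank, y.shape)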