diff --git a/tests/pytorch/distributed/test_fused_linear_comms.py b/tests/pytorch/distributed/test_fused_linear_comms.py new file mode 100644 index 0000000000..b2e07437da --- /dev/null +++ b/tests/pytorch/distributed/test_fused_linear_comms.py @@ -0,0 +1,258 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +import torch.distributed._symmetric_memory as symm_mem +import time +import argparse +import os +import uuid +import math + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser( + description=( + "Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism" + ) + ) + parser.add_argument("--hidden_size", type=int, default=8192, help="Input feature size") + parser.add_argument("--batch_size", type=int, default=2048, help="Batch size") + parser.add_argument("--fc_factor", type=int, default=4, help="MLP hidden layer factor") + parser.add_argument( + "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)" + ) + parser.add_argument("--validate", action="store_true", help="Validate allreduce ubnext") + parser.add_argument( + "--comm_type", + type=str, + default="sym", + help="Comm type: none,nccl,sym,ub,ubnext,ubnext_add,ubnext_rms", + ) + parser.add_argument( + "--sym_type", + type=str, + default="multimem_all_reduce", + help="pytorch sym type: one_shot, two_shot, multimem_all_reduce", + ) + parser.add_argument("--iterations", type=int, default=1000, help="Number of iterations") + parser.add_argument( + "--tp_size", + type=int, + default=None, + help="Tensor parallelism size (defaults to number of GPUs)", + ) + parser.add_argument("--eps", type=float, default=1e-5, help="Epsilon") + args = parser.parse_args() + + # Check CUDA availability and get device count + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Test requires NVIDIA GPUs.") + + num_devices = torch.cuda.device_count() + if num_devices == 0: + raise RuntimeError("No CUDA devices found.") + + # Set tensor parallelism size + tp_size = ( + args.tp_size if args.tp_size is not None else int(os.environ.get("WORLD_SIZE", num_devices)) + ) + + # Initialize distributed environment for each GPU + myrank = int(os.environ.get("RANK", 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + num_nodes = world_size // local_size + if num_nodes > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." 
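+    # RANK/LOCAL_RANK/WORLD_SIZE above are expected to come from the launcher.
+    # Illustrative invocation (an assumption, not part of this test; any launcher
+    # that sets these variables works):
+    #   torchrun --nproc_per_node=<num_gpus> \
+    #       tests/pytorch/distributed/test_fused_linear_comms.py \
+    #       --comm_type ubnext --hidden_size 8192 --batch_size 2048 --cuda_graph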
+ # Set device + device = torch.device(f"cuda:{local_rank}") + # Initialize torch.distributed for tensor parallelism + # Only set defaults if not already set by torchrun + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + torch.cuda.set_device(device) + + torch.distributed.init_process_group( + backend="nccl", world_size=tp_size, rank=myrank % tp_size, device_id=device + ) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + # Transformer Engine handles tensor parallelism internally when distributed is initialized + # Set environment variable for tensor parallelism size + os.environ["NVTE_TP_SIZE"] = str(tp_size) + + ub_cfgs = { + "proj_fprop": { + "method": "pipeline", + "num_splits": 4, + "is_reduce_scatter": True, + "num_sm": 32, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 4, + "set_sm_margin": True, + "fp8_buf": False, + "use_ce": False, + }, + "fc1_fprop": { + "method": "ring_exchange", + "num_splits": 1, + "is_reduce_scatter": False, + "num_sm": 1, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 1, + "set_sm_margin": False, + "fp8_buf": False, + "use_ce": True, + }, + } + + # Initialize model with BF16 precision + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT") or args.comm_type == "ub": + te.module.base.initialize_ub( + [args.batch_size, args.hidden_size], + tp_size, + use_fp8=False, + dtype=torch.bfloat16, + bootstrap_backend="nccl", + ub_cfgs=ub_cfgs, + ) + + proj = te.Linear( + in_features=args.hidden_size // tp_size if args.comm_type == "none" else args.hidden_size, + out_features=args.hidden_size, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size=tp_size if args.comm_type != "none" else 1, + parallel_mode="row" if args.comm_type != "none" else None, + tp_group=torch.distributed.group.WORLD if args.comm_type != "none" else None, + symmetric_ar_type=args.sym_type if args.comm_type == "sym" else args.comm_type, + sequence_parallel=args.comm_type == "ub", + ub_overlap_rs=args.comm_type == "ub", + ub_name="proj" if args.comm_type == "ub" else None, + eps=args.eps if args.comm_type == "ubnext_add_rms" else None, + ) + + fc1 = te.LayerNormLinear( + in_features=args.hidden_size, + out_features=( + args.hidden_size * args.fc_factor // tp_size + if args.comm_type == "none" + else args.hidden_size * args.fc_factor + ), + bias=False, + device=device, + params_dtype=torch.bfloat16, + eps=args.eps, + zero_centered_gamma=False, + normalization="RMSNorm", + tp_size=tp_size if args.comm_type != "none" else 1, + parallel_mode="column" if args.comm_type != "none" else None, + tp_group=torch.distributed.group.WORLD if args.comm_type != "none" else None, + skip_layernorm=args.comm_type == "ubnext_add_rms", + sequence_parallel=args.comm_type == "ub", + ub_overlap_ag=args.comm_type == "ub", + ub_name="fc1" if args.comm_type == "ub" else None, + ) + + if args.comm_type == "ubnext_add_rms": + proj.layer_norm_weight = fc1.layer_norm_weight.data + # Create CUDA stream + stream = torch.cuda.Stream() + # Check for environment variable to override pool size + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + for logbatch in range(int(math.log2(args.batch_size)) + 1): + batch = 2**logbatch + if args.comm_type == "ub": # and batch < tp_size: + batch = args.batch_size # tp_size + # Create input tensor + torch.cuda.synchronize() + torch.distributed.barrier(group=torch.distributed.group.WORLD) + residual = 
torch.randn( + batch // tp_size if args.comm_type == "ub" else batch, + args.hidden_size, + device=device, + dtype=torch.bfloat16, + ) + inp = torch.randn(batch, args.hidden_size // tp_size, device=device, dtype=torch.bfloat16) + + # Warm-up run + if not args.comm_type.startswith("ubnext_add"): + out_proj = proj(inp) + out_proj.add_(residual) + out = fc1(out_proj) + else: + out = fc1( + proj(inp, residual=residual) + ) # this also allocates distributed internal residual + + torch.cuda.synchronize() + if args.cuda_graph: + with torch.cuda.stream(stream): + # Create CUDA Graph + graph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(graph): + if not args.comm_type.startswith("ubnext_add"): + out_proj = proj(inp) + out_proj.add_(residual) + out = fc1(out_proj) + else: + out = fc1(proj(inp)) + + # Warm-up the graph + for _ in range(5): + graph.replay() + + torch.cuda.synchronize() + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + graph.replay() + else: + if not args.comm_type.startswith("ubnext_add"): + out_proj = proj(inp) + out_proj.add_(residual) + out = fc1(out_proj) + else: + out = fc1(proj(inp)) + + torch.cuda.synchronize() + end_time = time.time() + elapsed = end_time - start_time + + # Calculate and print elapsed time (only on rank 0) + if myrank == 0: + print(f"Batch{batch},{(elapsed/ args.iterations) * 1e6:.4f}") + if args.cuda_graph: + # needed or NCCL would hang + del graph + + # Cleanup + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + # Generate a unique run ID for distributed initialization + os.environ["RUN_ID"] = str(uuid.uuid4()) + main() diff --git a/tests/pytorch/distributed/test_linear_comms.py b/tests/pytorch/distributed/test_linear_comms.py new file mode 100644 index 0000000000..17668bafc9 --- /dev/null +++ b/tests/pytorch/distributed/test_linear_comms.py @@ -0,0 +1,370 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
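+
+# Micro-benchmark and optional correctness check (--validate) for tensor-parallel
+# Linear communication backends: plain NCCL allreduce ("nccl"), symmetric-memory
+# allreduce ("sym", including the ubnext kernels), and userbuffers overlap ("ub").
+# Intended to run under a launcher (e.g. torchrun) that sets RANK/LOCAL_RANK/WORLD_SIZE.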
+ +import torch +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +import torch.distributed._symmetric_memory as symm_mem +import time +import argparse +import os +import uuid +import math + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser( + description=( + "Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism" + ) + ) + parser.add_argument("--in_features", type=int, default=8192, help="Input feature size") + parser.add_argument("--out_features", type=int, default=8192, help="Output feature size") + parser.add_argument("--batch_size", type=int, default=2048, help="Batch size") + parser.add_argument( + "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)" + ) + parser.add_argument("--validate", action="store_true", help="Validate allreduce ubnext") + parser.add_argument("--comm_type", type=str, default="sym", help="Comm type: nccl,sym,ub") + parser.add_argument( + "--sym_type", + type=str, + default="multimem_all_reduce", + help="sym type: one_shot, two_shot, multimem_all_reduce, ubnext", + ) + parser.add_argument("--iterations", type=int, default=1000, help="Number of iterations") + parser.add_argument( + "--tp_size", + type=int, + default=None, + help="Tensor parallelism size (defaults to number of GPUs)", + ) + parser.add_argument("--eps", type=float, default=1e-5, help="Epsilon") + parser.add_argument("--rmsnorm", action="store_true", help="Use RMSNorm") + args = parser.parse_args() + + # Check CUDA availability and get device count + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Test requires NVIDIA GPUs.") + + num_devices = torch.cuda.device_count() + if num_devices == 0: + raise RuntimeError("No CUDA devices found.") + + # Set tensor parallelism size + tp_size = ( + args.tp_size if args.tp_size is not None else int(os.environ.get("WORLD_SIZE", num_devices)) + ) + + # Initialize distributed environment for each GPU + myrank = int(os.environ.get("RANK", 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + num_nodes = world_size // local_size + if num_nodes > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." 
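+    # Illustrative validation run (the torchrun command is an assumption, not part
+    # of this test):
+    #   torchrun --nproc_per_node=<num_gpus> tests/pytorch/distributed/test_linear_comms.py \
+    #       --comm_type sym --sym_type multimem_all_reduce --validate --rmsnorm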
+ # Set device + device = torch.device(f"cuda:{local_rank}") + # Initialize torch.distributed for tensor parallelism + # Only set defaults if not already set by torchrun + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + torch.cuda.set_device(device) + + torch.distributed.init_process_group( + backend="nccl", world_size=tp_size, rank=myrank % tp_size, device_id=device + ) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + # Transformer Engine handles tensor parallelism internally when distributed is initialized + # Set environment variable for tensor parallelism size + os.environ["NVTE_TP_SIZE"] = str(tp_size) + + ub_cfgs = { + "proj_fprop": { + "method": "pipeline", + "num_splits": 1, + "is_reduce_scatter": True, + "num_sm": 32, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 4, + "set_sm_margin": False, + "fp8_buf": False, + "use_ce": False, + } + } + + # Initialize model with BF16 precision + + modelseq = te.Linear( + in_features=int(args.in_features / tp_size), + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + ) + + modelnorm = te.RMSNorm( + normalized_shape=int(args.out_features), + eps=args.eps, + device=device, + dtype=torch.bfloat16, + zero_centered_gamma=False, + ) + residual = ( + torch.randn(args.batch_size, args.out_features, device=device, dtype=torch.bfloat16) + if args.rmsnorm + else None + ) + + ln_weight = modelnorm.weight.data if args.rmsnorm else None + if ( + args.comm_type == "sym" and os.environ.get("NVTE_USE_UB_FOR_UBNEXT") + ) or args.comm_type == "ub": + te.module.base.initialize_ub( + [args.batch_size, args.out_features], + tp_size, + use_fp8=False, + dtype=torch.bfloat16, + bootstrap_backend="nccl", + ub_cfgs=ub_cfgs, + ) + + modelpar = None + + if args.comm_type == "sym" or args.comm_type == "nccl": + modelpar = te.Linear( + in_features=args.in_features, + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size=tp_size, + parallel_mode="row", + tp_group=torch.distributed.group.WORLD, + symmetric_ar_type=args.sym_type if args.comm_type == "sym" else None, + eps=args.eps, + ln_weight=ln_weight, + ) + + if args.comm_type == "ub": + modelpar = te.Linear( + in_features=args.in_features, + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size=tp_size, + parallel_mode="row", + tp_group=torch.distributed.group.WORLD, + sequence_parallel=True, + ub_overlap_rs=True, + ub_name="proj", + eps=args.eps, + ln_weight=ln_weight, + ) + + # Create CUDA stream + stream = torch.cuda.Stream() + # Check for environment variable to override pool size + + allocator = None + if args.comm_type == "sym" and args.validate: + pool_size = int(os.environ.get("NVTE_UB_SYMM_POOL_SIZE", 64)) * 1024 * 1024 + allocator = te.cpp_extensions.symm_allocator.SymmAllocator( + pool_size, torch.device(device), torch.distributed.group.WORLD + ) + + # Run tensor comparison tests only for symmetric communication + if args.comm_type == "sym" and args.validate: + + if args.rmsnorm: + torch.manual_seed(57) + torch.cuda.manual_seed(57) + residual = torch.randn(1, args.out_features, dtype=torch.bfloat16, device=device) + t = allocator.create_tensor( + ( + 1, + args.out_features, + ), + dtype=torch.bfloat16, + ) + # te.cpp_extensions.symm_allocator.ubsymm_free_residual(t) + t.fill_(myrank) + t_in = t.clone() + 
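+            # Reference path: NCCL allreduce + residual add + unfused RMSNorm, compared
+            # element-wise against the fused symmetric-memory allreduce_simple call below.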
torch.distributed.all_reduce(t_in) + t_in.add_(residual) + out1 = modelnorm(t_in) + out2 = allocator.allreduce_simple( + t, + hidden_size=args.out_features, + residual_in=residual, + residual_out=residual, + fuse_layernorm=True, + eps=args.eps, + gamma=modelnorm.weight.data, + ) + abs_diff = torch.abs(out1 - out2) + max_delta = torch.max(abs_diff).item() + num_different = torch.sum(out1 != out2).item() + print(f"FUSED RMSNorm Max delta: {max_delta}, num different: {num_different}") + if myrank == 0: + print(f"gamma: {modelnorm.weight.data}") + print(f"FUSED RMSNorm output: {out1}") + print(f"FUSED RMSNorm output: {out2}") + + # Test different tensor sizes from 64 to 1024*1024 elements + all_max_deltas = [] + all_num_different = [] + all_total_elements = [] + all_sizes = [] + + size = 64 + while size <= 1024 * 1024: + # Allocate tensors + t = allocator.create_tensor((size,), dtype=torch.bfloat16) + t.fill_(0) + t += torch.randn_like(t) # Add random noise to each element + tmain = t.clone() # Create a copy since allreduce operates in-place + torch.distributed.all_reduce(tmain) + tlamport = allocator.allreduce_lamport(t) + + # Compare the two tensors + abs_diff = torch.abs(tlamport - tmain) + max_delta = torch.max(abs_diff).item() + num_different = torch.sum(tlamport != tmain).item() + + # Store statistics + all_max_deltas.append(max_delta) + all_num_different.append(num_different) + all_total_elements.append(tlamport.numel()) + all_sizes.append(size) + + # Free tensor (memory returned to pool) + del t, tlamport, tmain, abs_diff + + # Double the size for next iteration + size *= 2 + + # Print summary statistics + if myrank == 0: + print("\n=== Tensor Comparison Summary ===") + total_elements_tested = sum(all_total_elements) + total_different_elements = sum(all_num_different) + overall_max_delta = max(all_max_deltas) + + print( + f"Tested sizes: {len(all_sizes)} different tensor sizes from {all_sizes[0]} to" + f" {all_sizes[-1]} elements" + ) + print(f"Total elements tested: {total_elements_tested}") + print(f"Total different elements: {total_different_elements}") + print( + "Overall error rate:" + f" {(total_different_elements / total_elements_tested) * 100:.6f}%" + ) + print(f"Maximum delta across all tests: {overall_max_delta}") + + if total_different_elements > 0 or overall_max_delta > 0: + print("\nPer-size breakdown:") + for i, size in enumerate(all_sizes): + error_rate = (all_num_different[i] / all_total_elements[i]) * 100 + print( + f" Size {size:7d}:" + f" {all_num_different[i]:6d}/{all_total_elements[i]:7d} different" + f" ({error_rate:6.3f}%), max_delta: {all_max_deltas[i]:.6f}" + ) + print("================================\n") + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + for logbatch in range(int(math.log2(args.batch_size)) + 1): + batch = 2**logbatch + if args.comm_type == "ub" and batch < tp_size: + batch = tp_size + # Create input tensor + inp = torch.randn( + batch, int(args.in_features / tp_size), device=device, dtype=torch.bfloat16 + ) + # Warm-up run + out = modelseq(inp) + modelnorm(out) + modelpar(inp, residual=residual) + torch.cuda.synchronize() + if args.cuda_graph: + with torch.cuda.stream(stream): + # Create CUDA Graph + gseq = torch.cuda.CUDAGraph() + gpar = torch.cuda.CUDAGraph() + with torch.cuda.graph(gseq): + output = modelseq(inp) + if args.rmsnorm: + output.add_(residual[:batch, : args.out_features]) + output = modelnorm(output) + with torch.cuda.graph(gpar): + output = modelpar(inp, residual=residual) + # 
Warm-up the graph + for _ in range(5): + gseq.replay() + gpar.replay() + torch.cuda.synchronize() + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + gseq.replay() + else: + modelseq(inp) + + torch.cuda.synchronize() + end_time = time.time() + seq_elapsed = end_time - start_time + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + gpar.replay() + else: + modelpar(inp) + + torch.cuda.synchronize() + end_time = time.time() + par_elapsed = end_time - start_time + nccl_elapsed = par_elapsed - seq_elapsed + # Calculate and print elapsed time (only on rank 0) + if myrank == 0: + print( + f"Batch{batch},{(seq_elapsed/ args.iterations) * 1e6:.4f}us,{(par_elapsed/ args.iterations) * 1e6:.4f} us,{(nccl_elapsed/ args.iterations) * 1e6:.4f}" + ) + if args.cuda_graph: + # needed or NCCL would hang + del gseq, gpar + + # Cleanup + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + # Generate a unique run ID for distributed initialization + os.environ["RUN_ID"] = str(uuid.uuid4()) + main() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a4915080e8..45cc445034 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -143,7 +143,8 @@ list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) + comm_gemm_overlap/comm_gemm_overlap.cpp + ubnext.cu) if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 56369db27f..5dbc0c54e9 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -488,7 +488,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons _ub_comm->cga_size = _cga_size; size_t m = transa ? A.size(0) : A.size(1); size_t k = transa ? A.size(1) : A.size(0); - size_t n = _ubuf.size(0); + size_t n = B.size(0); size_t m_chunk = m / _num_splits; const std::vector input_a_chunk_shape = (transa ? 
std::vector{m_chunk, k} : std::vector{k, m_chunk}); @@ -640,6 +640,15 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } +uintptr_t CommOverlapBase::init_ubnext() { + NVTE_CHECK_CUDA(cudaMemset(_ub_comm->mem_ptr[_ub_reg], 0, _ub_comm->mem_size[_ub_reg])); + NVTE_CHECK_CUDA(cudaMemcpy(_ub_comm->mem_ptr[_ub_reg], + (reinterpret_cast(_ub_comm->mem_ptr[0])) + + (_ub_reg * _ub_comm->nvsize * sizeof(void *)), + _ub_comm->nvsize * sizeof(void *), cudaMemcpyDeviceToDevice)); + return (uintptr_t)(_ub_comm->mc_ptr[_ub_reg]); +} + /*************************************************************************************************** * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index cffc411a0d..572f4c52cd 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -225,6 +225,9 @@ class CommOverlapBase : public CommOverlapCore { void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, cudaStream_t stream_main) override; + + // initialize ubnext buffer and return multicast pointer for allreduce + uintptr_t init_ubnext(); }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { diff --git a/transformer_engine/common/include/transformer_engine/ubnext.h b/transformer_engine/common/include/transformer_engine/ubnext.h new file mode 100644 index 0000000000..0e1b304a61 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ubnext.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
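+ *
+ * ubnext: symmetric-memory all-reduce kernels (2-shot multicast, unicast, and
+ * Lamport-synchronized variants) with optional residual add and fused RMSNorm.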
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UBNEXT_H_ +#define TRANSFORMER_ENGINE_UBNEXT_H_ + +#include "transformer_engine.h" + +namespace transformer_engine { + +#ifdef __cplusplus +extern "C" { +#endif + +void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, + void* mcptr_out, size_t bytes, void* residual_in, void* residual_out, + bool fuse_layernorm, void* gamma, float eps, const int hidden_size, + cudaStream_t stream); +void allreduce_2shot_mc_lamport(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, + void* mcptr_in, void* mcptr_out, void* clear_ptr, size_t bytes, + bool poisoned, void* residual_in, void* residual_out, + bool fuse_layernorm, void* gamma, float eps, const int hidden_size, + cudaStream_t stream); +void allreduce_2shot_uc(int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, + size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, + void* gamma, float eps, const int hidden_size, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 706c237ccc..68e5f8aef8 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -19,7 +19,8 @@ *transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapCore*; *nvshmem_wait_on_stream*; - *nvshmemi_init_thread* + *nvshmemi_init_thread*; + allreduce_*; }; local: *; }; diff --git a/transformer_engine/common/ubnext.cu b/transformer_engine/common/ubnext.cu new file mode 100644 index 0000000000..8aa5287f9e --- /dev/null +++ b/transformer_engine/common/ubnext.cu @@ -0,0 +1,636 @@ +#include +#include +#include + +#include "./common.h" + +#define TIMEOUT 2000000000ull +//#define UB_TIMEOUT_ENABLED 1 + +#define NVTE_UB_MAXTHREADS 1024 +#define NVTE_UB_MAX_SMS 128 +#define NVTE_UB_LAMPORT_INT 0xFFFAFFFA + +//REG0 flags in use +#define NVTE_UB_FLAG_NVLS2_LAMPORT_ID 0 +#define NVTE_UB_FLAG_NVLS2_LAMPORT_SM_SYNC 1 +#define NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR 2 +#define NVTE_UB_FLAG_NVLS2_ID 3 +#define NVTE_UB_FLAG_NVLS2_SM_SYNC 4 +#define NVTE_UB_FLAG_NVLS2_RS_BAR 5 +#define NVTE_UB_FLAG_NVLS2_AG_BAR 6 + +#define xhalf __nv_bfloat16 + +#define ATOMIC_MCINC(ptr) \ + asm volatile("multimem.red.add.u32 [%0], %1;" ::"l"(ptr), "r"(1) \ + : "memor" \ + "y"); +#define ATOMIC_UCINC(ptr) \ + asm volatile("red.global.add.u32 [%0], %1;" ::"l"(ptr), "r"(1) \ + : "memor" \ + "y"); +#define MULTIMEM_ST(val, ptr) \ + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), \ + "r"(val.y), "r"(val.z), "r"(val.w) \ + : "memory"); + +#define MULTIMEM_LD(val, ptr) \ + asm("multimem.ld_reduce.global.add.v4.bf16x2.acc::f32 {%0,%1,%2,%3}, [%4];" \ + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) \ + : "l"(ptr) \ + : "memory"); + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// Return true if producer > consumer, otherwise false while preventing integer overflow +// If we expect that producer will be 2B+ messages behind consumer +#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX)) + +#define FINAL_MASK 
0xffffffff +template +__inline__ __device__ T warpReduceSumV2(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T *val) { + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSumV2(val); + return (T)0.0f; +} + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_mc(const int RANKS, const int myrank, const int mylines, + int *uc_flagptr, int *mc_flagptr, uint4 *mc_ptr_in, + uint4 *mc_ptr_out, uint4 *residual_in, uint4 *residual_out, + xhalf *gamma, float eps, const int hidden_size, + bool fuse_layernorm) { + // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 + int reduce_id; + __shared__ float s_variance; + + if (threadIdx.x == 0) { + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR); + + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] + 1; + + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_RS_BAR]); + + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + + __syncthreads(); + + const int loop_step0 = blockDim.x; + const int loop_step = loop_step0 * UNROLL * gridDim.x; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL; + const int end_elem = max(start_elem, mylines); + //const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + //const int end_aligned = start_elem + aligned_elem; + + for (int line = start_elem; line < end_elem; line += loop_step) { + uint4 val[UNROLL]; + xhalf *x = reinterpret_cast(&val[0]); +#pragma unroll + for (int i = 0; i < UNROLL; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) + + if (residual_in != nullptr) { + for (int i = 0; i < UNROLL; i++) { + uint4 resval = residual_in[line + i * loop_step0]; + xhalf *y = reinterpret_cast(&resval); +#pragma unroll + for (int j = 0; j < 8; j++) x[i * 8 + j] += y[j]; + if (residual_out != nullptr) residual_out[line + i * loop_step0] = val[i]; + } + } + if (fuse_layernorm) { + float local_var_sum = 0.0f; + for (int j = 0; j < UNROLL * sizeof(int4) / sizeof(xhalf); j++) + local_var_sum += (float)(x[j]) * (float)(x[j]); + + float packed[1] = {local_var_sum}; + blockReduceSumV2(packed); + float variance = packed[0]; + + if (threadIdx.x == 0) { + variance = (variance / hidden_size); // Var[x] = E[x²] + s_variance = rsqrtf(variance + eps); + } + __syncthreads(); + } + + int i = 0; +#pragma unroll + for (int g = 0; g < UNROLL; g++) { + if (fuse_layernorm) { +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(xhalf); j++) { + x[i] = + (xhalf)((float)(x[i]) * s_variance * + (float) + gamma[(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(xhalf) + j]); + i++; + } + } + 
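+      // Store the reduced (and, when fuse_layernorm is set, RMSNorm-scaled) vector
+      // back through the multicast pointer so every rank's buffer is updated.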
MULTIMEM_ST(val[g], mc_ptr_out + (line + g * loop_step0)) + } + } + /* + for (int line = end_aligned; line < end_elem; line += loop_step0) { + uint4 val; + xhalf *x = reinterpret_cast(&val); + MULTIMEM_LD(val, mc_ptr_in + (line)) + + if(residual_in!=nullptr) { + uint4 resval = residual_in[line]; + xhalf *y = reinterpret_cast(&resval); + #pragma unroll + for (int j = 0; j < 8; j++) + x[j] += y[j]; + if(residual_out!=nullptr) + residual_out[line]=val; + } + + MULTIMEM_ST(val, mc_ptr_out + (line)) + } + */ + __syncthreads(); + if (threadIdx.x != 0) return; + + __threadfence(); + const int value_to_add = blockIdx.x == 0 ? NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_SM_SYNC, value_to_add); + + const int lastSM = + (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + if (!lastSM) return; + __threadfence_system(); + ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_AG_BAR); + uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_AG_BAR]); + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } +} // fp16 inplace reduce kernel (Hopper) MC + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_uc(const int myrank, const int numlines, + const int lineoffset_in, const int lineoffset_out, + int *uc_flagptr, void **commbuff, uint4 *residual_in, + uint4 *residual_out, xhalf *gamma, float eps, + const int hidden_size, bool fuse_layernorm) { + // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 + //NB! 
uc_flagptr is shifted by ranks*8 for easier flag offsets + // while lineoffset is relative to start of reg0 + __shared__ uint4 *userptr[RANKS]; + __shared__ int lastSM; + int reduce_id; + + if (threadIdx.x < RANKS) { + int *rem_flagptr = (reinterpret_cast(commbuff[threadIdx.x])); + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_UCINC(rem_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR + RANKS * 2); + + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] + 1; + + userptr[threadIdx.x] = (uint4 *)rem_flagptr; + } + + if (threadIdx.x == 0) { + volatile int *flag = uc_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR; + lastSM = 0; + const int expected = reduce_id * RANKS; +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; + line += blockDim.x * gridDim.x * RANKS) { + uint4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + val[i] = userptr[dest[i]][lineoffset_in + line]; + } + + uint4 sum = val[0]; + xhalf *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + xhalf *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + if (residual_in != nullptr) { + uint4 resval = residual_in[lineoffset_in + line]; + xhalf *y = reinterpret_cast(&resval); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += y[j]; + if (residual_out != nullptr) residual_out[lineoffset_in + line] = sum; + } + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + userptr[dest[i]][lineoffset_out + line] = sum; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + __threadfence(); + const int value_to_add = blockIdx.x == 0 ? 
NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_SM_SYNC, value_to_add); + lastSM = (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + if (lastSM) uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + } + if (threadIdx.x >= RANKS) return; + __syncthreads(); + if (!lastSM) return; + if (threadIdx.x == 0) __threadfence_system(); + __syncthreads(); + ATOMIC_UCINC((int *)(userptr[threadIdx.x]) + NVTE_UB_FLAG_NVLS2_AG_BAR + RANKS * 2); + if (threadIdx.x != 0) return; + volatile int *flag = uc_flagptr + NVTE_UB_FLAG_NVLS2_AG_BAR; + const int expected = reduce_id * RANKS; +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } +} // UC 2shot kernel (non-lamport) + +__global__ void memset_int(uint32_t *data, int n, uint32_t val) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + data[idx] = val; + } +} + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_mc_lamport(const int RANKS, const int myrank, + const int mylines, const int numlines, + int *uc_flagptr, int *mc_flagptr, uint4 *mc_ptr_in, + uint4 *mc_ptr_out, uint4 *uc_ptr_out, + uint4 *clear_ptr, uint4 *residual_in, + uint4 *residual_out, xhalf *gamma, float eps, + const int hidden_size, bool fuse_layernorm) { + // flags[0,1,2]: reduce_id, sm_sync-local, flag-barrier + // those go right after rank UC pointers, but its the CPU caller who should account for it + int reduce_id; + __shared__ float s_variance; + + if (threadIdx.x == 0) { + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR); + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_ID]; + const int value_to_add = blockIdx.x == 0 ? 
NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = + atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_LAMPORT_SM_SYNC, value_to_add); + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR]); + reduce_id++; + const int lastSM = + (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + + if (lastSM) uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + __syncthreads(); + + const int loop_step0 = blockDim.x; + const int loop_step = loop_step0 * UNROLL * gridDim.x; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL; + const int end_elem = max(start_elem, mylines); + + for (int line = start_elem; line < end_elem; line += loop_step) { + uint4 val[UNROLL]; + xhalf *x = reinterpret_cast(&val[0]); +#pragma unroll + for (int i = 0; i < UNROLL; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) + + if (residual_in != nullptr) { + for (int i = 0; i < UNROLL; i++) { + uint4 resval = residual_in[line + i * loop_step0]; + xhalf *y = reinterpret_cast(&resval); +#pragma unroll + for (int j = 0; j < 8; j++) x[i * 8 + j] += y[j]; + if (residual_out != nullptr) residual_out[line + i * loop_step0] = val[i]; + } + } + if (fuse_layernorm) { + float local_var_sum = 0.0f; + for (int j = 0; j < UNROLL * sizeof(int4) / sizeof(xhalf); j++) + local_var_sum += (float)(x[j]) * (float)(x[j]); + + float packed[1] = {local_var_sum}; + blockReduceSumV2(packed); + float variance = packed[0]; + + if (threadIdx.x == 0) { + variance = (variance / hidden_size); // Var[x] = E[x²] + s_variance = rsqrtf(variance + eps); + } + __syncthreads(); + } + + int i = 0; +#pragma unroll + for (int g = 0; g < UNROLL; g++) { + if (fuse_layernorm) { +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(xhalf); j++) { + x[i] = + (xhalf)((float)(x[i]) * s_variance * + (float) + gamma[(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(xhalf) + j]); + i++; + } + } + MULTIMEM_ST(val[g], mc_ptr_out + (line + g * loop_step0)) + } + } + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < numlines; + line += blockDim.x * gridDim.x) { +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (true) { + uint4 result; + + asm volatile("ld.volatile.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.w) + : "l"(&uc_ptr_out[line]) + : "memory"); + if (result.w != NVTE_UB_LAMPORT_INT) { + if (clear_ptr) clear_ptr[line].w = NVTE_UB_LAMPORT_INT; + break; + } +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("Lamport POLL:SM %d [%d]:expecting %d got (%d,%d,%d) %d\n", blockIdx.x, threadIdx.x, + NVTE_UB_LAMPORT_INT, result.x, result.y, result.z, result.w); + break; + } +#endif + } + } + +} // two-shot NVLS + lamport sync instead of last membar + +#define SETUP_LAUNCH_CONFIG(sms, threads, stream, cga_size, pdl_launch) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[3]; \ + attribute_ub[2].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[2].val.clusterDim.x = sms % cga_size == 0 ? 
cga_size : 1; \ + attribute_ub[2].val.clusterDim.y = 1; \ + attribute_ub[2].val.clusterDim.z = 1; \ + attribute_ub[1].id = cudaLaunchAttributeCooperative; \ + attribute_ub[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attribute_ub[0].val.programmaticStreamSerializationAllowed = pdl_launch; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 3; + +namespace transformer_engine { + +#define split_tokens(x) \ + const int elements = bytes / sizeof(half); \ + const int elements_per_thread = sizeof(uint4) / sizeof(half); \ + int nthreads = 1024, nlines = 4; \ + size_t total_bytes = bytes / ranks, start_bytes = myrank * total_bytes; \ + int sms = x; \ + if (hidden_size) { \ + assert(hidden_size <= 32768); \ + assert(elements % hidden_size == 0); \ + assert(hidden_size % elements_per_thread == 0); \ + int ntokens = elements / hidden_size; \ + int my_tokens = ntokens / ranks; \ + int extra_tokens = ntokens % ranks; \ + int first_token = myrank * my_tokens; \ + first_token += myrank < extra_tokens ? myrank : extra_tokens; \ + if (myrank < extra_tokens) my_tokens++; \ + start_bytes = first_token * hidden_size * sizeof(half); \ + total_bytes = my_tokens * hidden_size * sizeof(half); \ + nthreads = hidden_size / elements_per_thread; \ + nlines = 1; \ + while (nthreads > 1024) { \ + nlines++; \ + assert(nlines <= 4); \ + if ((hidden_size / elements_per_thread) % nlines == 0) \ + nthreads = ((hidden_size / elements_per_thread)) / nlines; \ + } \ + if (sms > my_tokens) sms = my_tokens; \ + if (sms == 0) sms = 1; \ + } \ + bool residual_in_global = residual_in != nullptr && residual_in != residual_out && \ + residual_out != nullptr; // out residual is always local + +extern "C" void allreduce_2shot_mc(int ranks, int myrank, void *uc0ptr, void *mc0ptr, + void *mcptr_in, void *mcptr_out, size_t bytes, void *residual_in, + void *residual_out, bool fuse_layernorm, void *gamma, float eps, + const int hidden_size, cudaStream_t stream) { + split_tokens(32); + + SETUP_LAUNCH_CONFIG(sms, nthreads, stream, 4, 1); + + int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4); + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in + start_bytes, + *arg7 = mcptr_out + start_bytes, + *arg8 = residual_in_global ? 
residual_in + start_bytes : residual_in, *arg9 = residual_out, + *arg10 = gamma; + float arg11 = eps; + int arg12 = hidden_size; + bool arg13 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, + (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, + (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12, + (void *)&arg13}; +#define call_mc_kernel(x, cond) \ + if (x == nlines || cond) { \ + CUDACHECK( \ + cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); \ + return; \ + } + call_mc_kernel(1, false); + call_mc_kernel(2, false); + call_mc_kernel(3, false); + call_mc_kernel(4, true); +} + +extern "C" void allreduce_2shot_uc(int ranks, int myrank, void *uc0ptr, void *ucptr_in, + void *ucptr_out, size_t bytes, void *residual_in, + void *residual_out, bool fuse_layernorm, void *gamma, float eps, + const int hidden_size, cudaStream_t stream) { + SETUP_LAUNCH_CONFIG(64, 1024, stream, 4, 1); + + int arg1 = myrank, arg2 = bytes / 16, arg3 = (int4 *)ucptr_in - (int4 *)uc0ptr, + arg4 = (int4 *)ucptr_out - (int4 *)uc0ptr; + void *arg5 = uc0ptr + (ranks * 8), **arg6 = (void **)uc0ptr, *arg7 = residual_in, + *arg8 = residual_out, *arg9 = gamma; + float arg10 = eps; + int arg11 = hidden_size; + bool arg12 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, + (void *)&arg5, (void *)&arg6, (void *)&arg7, (void *)&arg8, + (void *)&arg9, (void *)&arg10, (void *)&arg11, (void *)&arg12}; +#define call_uc_kernel(x) \ + if (x == ranks) \ + CUDACHECK( \ + cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_uc), kernelArgs)); + call_uc_kernel(2); + call_uc_kernel(4); + call_uc_kernel(8); +} + +extern "C" void allreduce_2shot_mc_lamport(int ranks, int myrank, void *uc0ptr, void *mc0ptr, + void *ucptr_out, void *mcptr_in, void *mcptr_out, + void *clear_ptr, size_t bytes, bool poisoned, + void *residual_in, void *residual_out, + bool fuse_layernorm, void *gamma, float eps, + const int hidden_size, cudaStream_t stream) { + if (!poisoned) { + //user tells us destination was not pre-poisoned, so we need to do it before calling allreduce + int threadsPerBlock = 512; + int blocks = (bytes / 4 + threadsPerBlock - 1) / threadsPerBlock; + memset_int<<>>((uint32_t *)ucptr_out, bytes / 4, + NVTE_UB_LAMPORT_INT); + } + split_tokens(64); + + SETUP_LAUNCH_CONFIG(64, nthreads, stream, 4, 1); + + int arg1 = ranks, arg2 = myrank, arg3 = total_bytes / sizeof(uint4), + arg3a = bytes / sizeof(uint4); + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in + start_bytes, + *arg7 = mcptr_out + start_bytes, *arg8 = ucptr_out, *arg9 = clear_ptr, + *arg10 = residual_in_global ? 
residual_in + start_bytes : residual_in, *arg11 = residual_out, + *arg12 = gamma; + float arg13 = eps; + int arg14 = hidden_size; + bool arg15 = fuse_layernorm; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg3a, + (void *)&arg4, (void *)&arg5, (void *)&arg6, (void *)&arg7, + (void *)&arg8, (void *)&arg9, (void *)&arg10, (void *)&arg11, + (void *)&arg12, (void *)&arg13, (void *)&arg14, (void *)&arg15}; + +#define call_mc_lamport_kernel(x, cond) \ + if (x == nlines || cond) { \ + CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc_lamport), \ + kernelArgs)); \ + return; \ + } + + call_mc_lamport_kernel(1, false); + call_mc_lamport_kernel(2, false); + call_mc_lamport_kernel(3, false); + call_mc_lamport_kernel(4, true); +} + +} // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index bce124e705..240a18e8a7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "cuda_runtime.h" @@ -117,6 +118,8 @@ std::shared_ptr, \ transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()) \ + .def("init_ubnext", &transformer_engine::CommOverlapBase::init_ubnext, \ py::call_guard()); \ py::class_, \ @@ -135,6 +138,45 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); + py::call_guard()); \ + m.def( \ + "allreduce_2shot_mc", \ + [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, void* mcptr_out, \ + size_t bytes, void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, \ + float eps, const int hidden_size) { \ + transformer_engine::allreduce_2shot_mc( \ + ranks, myrank, uc0ptr, mc0ptr, mcptr_in, mcptr_out, bytes, residual_in, residual_out, \ + fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ + py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("bytes"), py::arg("residual_in"), \ + py::arg("residual_out"), py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), \ + py::arg("hidden_size")); \ + m.def( \ + "allreduce_2shot_uc", \ + [](int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, size_t bytes, \ + void* residual_in, void* residual_out, bool fuse_layernorm, void* gamma, float eps, \ + const int hidden_size) { \ + transformer_engine::allreduce_2shot_uc( \ + ranks, myrank, uc0ptr, ucptr_in, ucptr_out, bytes, residual_in, residual_out, \ + fuse_layernorm, gamma, eps, hidden_size, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("ucptr_in"), \ + py::arg("ucptr_out"), py::arg("bytes"), py::arg("residual_in"), py::arg("residual_out"), \ + py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); \ + m.def( \ + "allreduce_2shot_mc_lamport", \ + [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, void* mcptr_in, \ + void* mcptr_out, void* clear_ptr, size_t bytes, bool poisoned, void* residual_in, \ + void* residual_out, bool fuse_layernorm, void* gamma, float eps, const int hidden_size) { \ + transformer_engine::allreduce_2shot_mc_lamport( \ + ranks, 
myrank, uc0ptr, mc0ptr, ucptr_out, mcptr_in, mcptr_out, clear_ptr, bytes, \ + poisoned, residual_in, residual_out, fuse_layernorm, gamma, eps, hidden_size, \ + at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ + py::arg("ucptr_out"), py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("clear_ptr"), \ + py::arg("bytes"), py::arg("poisoned"), py::arg("residual_in"), py::arg("residual_out"), \ + py::arg("fuse_layernorm"), py::arg("gamma"), py::arg("eps"), py::arg("hidden_size")); #endif diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 944d1849bf..07d150b0f0 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -7,3 +7,4 @@ from .fused_attn import * from .gemm import * +from .symm_allocator import * diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py new file mode 100644 index 0000000000..46253affea --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -0,0 +1,439 @@ +import torch +import os +import gc +import weakref +from typing import List, Tuple, Optional, Dict +from threading import Lock +import torch.distributed._symmetric_memory as symm_mem +from ctypes import pythonapi, c_void_p, py_object + + +def to_capsule(ptr): + # Set the return type to py_object to get a Python object (PyCapsule) + pythonapi.PyCapsule_New.restype = py_object + pythonapi.PyCapsule_New.argtypes = [c_void_p, c_void_p, c_void_p] + # Create capsule with a name (optional, can be None) and no destructor + capsule = pythonapi.PyCapsule_New(ptr, None, None) + return capsule + + +class SymmTensor(torch.Tensor): + """Custom tensor subclass that uses custom memory""" + + @staticmethod + def __new__( + cls, + pool: torch.Tensor, + offset: int, + shape: torch.Size, + dtype: torch.dtype, + allocator: "SymmAllocator", + ): + # Calculate number of elements and bytes + num_elements = torch.Size(shape).numel() + element_size = torch.tensor(0, dtype=dtype).element_size() + nbytes = element_size * num_elements + + # Validate pool + assert pool.dtype == torch.uint8, f"Expected uint8 pool, got {pool.dtype}" + assert ( + pool.numel() >= offset + nbytes + ), f"Pool too small: {pool.numel()} bytes, need {offset + nbytes}" + + # Slice the pool to get the required bytes + byte_slice = pool[offset : offset + nbytes] + + # Reinterpret the uint8 bytes as the target dtype + tensor = byte_slice.view(dtype=dtype) + tensor = tensor.view(*shape) + + # Initialize as a subclass of torch.Tensor + self = torch.Tensor._make_subclass(cls, tensor) + if not isinstance(allocator, SymmAllocator): + raise TypeError(f"Expected SymmAllocator, got {type(allocator)}") + self._allocator = allocator + self._ptr = tensor.data_ptr() + self._offset = offset + self._size = nbytes + return self + + def __del__(self): + """Custom deallocator to return memory to the pool.""" + if hasattr(self, "_allocator") and hasattr(self, "_ptr"): + self._allocator.free(self._ptr) + + +class SymmAllocator: + def __init__(self, size_bytes: int, device: torch.device, dist_group: torch.distributed.group): + """Initialize the allocator with a preallocated memory pool.""" + # Preallocate the memory pool using torch.empty + self.reg0_size = 1024 # NVL72*8 plus up to 112 flags + self.device = device + self.world_size = 
torch.distributed.get_world_size(dist_group) + self.myrank = torch.distributed.get_rank(dist_group) + self.dist_group = dist_group + + from ..module.base import get_ub + + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + self.ub_obj = get_ub("ubnext") + self.internal_pool = self.ub_obj.get_buffer(False).reshape(-1) + self.mc0_ptr = self.ub_obj.init_ubnext() + self.pool_size = self.internal_pool.numel() + else: + alignment = 2 * 1024 * 1024 # memory is allocated in 2MB pages anyways + self.pool_size = int((size_bytes + alignment - 1) / alignment) * alignment + # symm_mem.set_backend("NCCL") + self.internal_pool = symm_mem.empty(self.pool_size, dtype=torch.uint8, device=device) + self.hdl0 = symm_mem.rendezvous(self.internal_pool, dist_group) + self.mc0_ptr = self.hdl0.multicast_ptr + self.internal_pool.fill_(0) + self.internal_pool.view(torch.int64)[: self.world_size].copy_( + torch.tensor(self.hdl0.buffer_ptrs).view(torch.int64) + ) + # self.hdl0.barrier(channel=0) + # Synchronize all processes before proceeding + torch.distributed.barrier(group=dist_group) + + # Track the raw pointer to the pool + self.pool_ptr = self.internal_pool.data_ptr() + # Track allocated segments: (offset, size) + self.allocated: List[Tuple[int, int]] = [] + # Track free segments: (offset, size) + self.freelist: List[Tuple[int, int]] = [(self.reg0_size, self.pool_size - self.reg0_size)] + self.nextpoisoned = None + self.residual = None + self.residual_tokens = 0 + self.tensors = weakref.WeakSet() + self.lock = Lock() + + def allocate(self, nbytes: int) -> Tuple[Optional[int], Optional[torch.Tensor]]: + """Allocate nbytes from the pool, returning a pointer and pool reference.""" + with self.lock: + for i, (offset, size) in enumerate(self.freelist): + if size >= nbytes: + self.freelist.pop(i) + self.allocated.append((offset, nbytes)) + if size > nbytes: + self.freelist.append((offset + nbytes, size - nbytes)) + return self.pool_ptr + offset, self.internal_pool + return None, None + + # No suitable free segment found + raise MemoryError( + f"Preallocated pool exhausted: requested {nbytes} bytes, " + f"available segments: {self.freelist}" + ) + + def free(self, ptr: int): + """Free the memory at ptr, returning it to the pool.""" + with self.lock: + offset = ptr - self.pool_ptr + for i, (alloc_offset, size) in enumerate(self.allocated): + if alloc_offset == offset: + self.allocated.pop(i) + self.freelist.append((offset, size)) + self.freelist.sort(key=lambda x: x[0]) + self._merge_free_segments() + return + # Ignore invalid pointers silently + pass + + raise ValueError(f"Invalid pointer {ptr} not found in allocated segments") + + def _merge_free_segments(self): + """Merge adjacent free segments to reduce fragmentation.""" + if not self.freelist: + return + merged = [] + current_offset, current_size = self.freelist[0] + for offset, size in self.freelist[1:]: + if current_offset + current_size == offset: + # Adjacent segments, merge them + current_size += size + else: + # Non-adjacent, keep current and start new + merged.append((current_offset, current_size)) + current_offset, current_size = offset, size + merged.append((current_offset, current_size)) + self.freelist = merged + + def create_tensor( + self, shape: torch.Size, dtype: torch.dtype = torch.float32 + ) -> Optional[torch.Tensor]: + """Create a PooledTensor using memory from the pool.""" + nbytes = torch.tensor(0, dtype=dtype).element_size() * torch.Size(shape).numel() + ptr, pool = self.allocate(nbytes) + if ptr is None: + return None + offset = ptr - 
self.pool_ptr + tensor = SymmTensor(pool, offset, torch.Size(shape), dtype, self) + self.tensors.add(tensor) + return tensor + + def allreduce_uc( + self, + tensor_in: torch.Tensor, + hidden_size: int = 0, + residual_in: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + fuse_layernorm: bool = False, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ) -> torch.Tensor: + """Performs in-place allreduce on the given SymmTensor using best algo""" + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + + # tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + ucptr_in = tensor_in.data_ptr() + # mcptr_out = tensor_out.data_ptr() + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Import your pybind module if not imported + from transformer_engine_torch import allreduce_2shot_uc + + allreduce_2shot_uc( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(ucptr_in), + to_capsule(ucptr_in), # out + nbytes, + to_capsule(residual_in.data_ptr()) if residual_in is not None else None, + to_capsule(residual_out.data_ptr()) if residual_out is not None else None, + fuse_layernorm, + to_capsule(gamma.data_ptr()) if gamma is not None else None, + eps if eps is not None else 0.0, + hidden_size, + ) + return tensor_in + + def allreduce_simple( + self, + tensor_in: torch.Tensor, + hidden_size: int = 0, + residual_in: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + fuse_layernorm: bool = False, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ) -> torch.Tensor: + """Performs in-place allreduce on the given SymmTensor using best algo""" + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + + # tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + mcptr_in = self.mc0_ptr + (tensor_in.data_ptr() - self.internal_pool.data_ptr()) + # mcptr_out = self.hdl.multicast_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Import your pybind module if not imported + from transformer_engine_torch import allreduce_2shot_mc + + allreduce_2shot_mc( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(self.mc0_ptr), + to_capsule(mcptr_in), + to_capsule(mcptr_in), # out + nbytes, + to_capsule(residual_in.data_ptr()) if residual_in is not None else None, + to_capsule(residual_out.data_ptr()) if residual_out is not None else None, + fuse_layernorm, + to_capsule(gamma.data_ptr()) if gamma is not None else None, + eps if eps is not None else 0.0, + hidden_size, + ) + return tensor_in + + def allreduce_lamport( + self, + tensor_in: torch.Tensor, + hidden_size: int = 0, + residual_in: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + fuse_layernorm: bool = False, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ) -> torch.Tensor: + """ + Performs allreduce using 2-shot multicast Lamport variant: + - Takes `tensor_in` as input (SymmTensor). + - Allocates `tensor_out` of same shape and dtype. + - Runs `allreduce_2shot_mc_lamport` over them. + - Returns `tensor_out`. 
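+        - Falls back to `allreduce_uc` when no multicast pointer is available, and to
+          `allreduce_simple` when the pool cannot supply an output buffer.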
+ """ + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + if self.mc0_ptr is None or self.mc0_ptr == 0: + return self.allreduce_uc( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) + from transformer_engine_torch import allreduce_2shot_mc_lamport + + # Allocate output tensor of same shape/dtype + tensor_out = self.nextpoisoned + poisonedout = True + + if self.nextpoisoned is None or self.nextpoisoned.shape != tensor_in.shape: + if self.nextpoisoned is not None: + del self.nextpoisoned + self.nextpoisoned = None + tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + poisonedout = False + if tensor_out is None: + return self.allreduce_simple( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) + + # alllcate potential output for next allreduce (speculative) and poison it now + self.nextpoisoned = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + # Calculate mcptr_in and mcptr_out with offset relative to internal_pool + offset = tensor_in.data_ptr() - self.internal_pool.data_ptr() + mcptr_in = self.mc0_ptr + offset + mcptr_out = self.mc0_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + + # Use clear_ptr to clear output memory before reduction; here we use tensor_out + # clear_ptr = self.nextpoisoned.data_ptr() if self.nextpoisoned is not None else 0 + + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Call your pybind lamport allreduce + allreduce_2shot_mc_lamport( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(self.mc0_ptr), + to_capsule(tensor_out.data_ptr()), + to_capsule(mcptr_in), + to_capsule(mcptr_out), + to_capsule(self.nextpoisoned.data_ptr()) if self.nextpoisoned is not None else None, + nbytes, + poisonedout, + to_capsule(residual_in.data_ptr()) if residual_in is not None else None, + to_capsule(residual_out.data_ptr()) if residual_out is not None else None, + fuse_layernorm, + to_capsule(gamma.data_ptr()) if gamma is not None else None, + eps if eps is not None else 0.0, + hidden_size, + ) + + return tensor_out + + +_allocator_map: Dict[torch.distributed.group, Tuple[int, "SymmAllocator"]] = {} + + +def ubsymm_request_allocator( + dist_group: torch.distributed.group, + shape: Optional[torch.Size] = None, + dtype: torch.dtype = torch.bfloat16, +) -> None: + if shape is not None: + num_elements = torch.Size(shape).numel() + element_size = torch.tensor(0, dtype=dtype).element_size() + tensor_size = num_elements * element_size + else: + tensor_size = 0 + + if dist_group not in _allocator_map: + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + assert not _allocator_map, "Current UBNEXT-UB bypass supports only one process group." 
+ _allocator_map[dist_group] = (tensor_size, None) + else: + old_size, allocator = _allocator_map[dist_group] + assert allocator is None, "Second element of tuple must be None" + max_size = max(old_size, tensor_size) + _allocator_map[dist_group] = (max_size, None) + + +def ubsymm_get_sym_tensor( + shape: torch.Size, dtype: torch.dtype, dist_group: torch.distributed.group +) -> torch.Tensor: + if dtype != torch.bfloat16: + return None # Unsupported dtype, do fallback to nccl + if dist_group not in _allocator_map: + return None # No allocator requested earlier, do fallback to nccl + (max_size, allocator) = _allocator_map[dist_group] + if allocator is None: + new_max_size = int( + os.environ.get("NVTE_UB_SYMM_POOL_SIZE", ((6 * max_size + 1048575) / 1024 / 1024)) + ) + allocator = SymmAllocator( + new_max_size * 1024 * 1024, + torch.device(f"cuda:{torch.cuda.current_device()}"), + dist_group, + ) + _allocator_map[dist_group] = (new_max_size, allocator) + return allocator.create_tensor(shape, dtype) + + +def ubsymm_allreduce( + tensor_in: SymmTensor, + residual_global: Optional[torch.Tensor] = None, + gamma: Optional[torch.Tensor] = None, + eps: Optional[float] = None, +) -> SymmTensor: + """ + Performs allreduce on the given SymmTensor using best algo + Four modes: + standalone allreduce: no residual, no layernorm (residual_global passed by user, both eps and gamma=None) + first PROJ layer: layernorm fused, global residual in, internal residual out (residual_global passed by user, both eps and gamma not None) + this allocates internal residual if it wasnt allocated previously or token count is different from previous allreduce + middle layers: layernorm fused, internal residual in, internal residual out (residual_global=None, both eps and gamma not None) + Last FC2 layer: no layernorm, internal residual in, no residual out(layer output is actually the global residual) (residual_global=None, fboth eps and gamma=None) + this is different from standalone once there is no internal residual allocated + """ + fuse_layernorm = gamma is not None and eps is not None + internal_residual = tensor_in._allocator.residual + num_ranks = tensor_in._allocator.world_size + hidden_size = ( + tensor_in.shape[-1] + if fuse_layernorm or internal_residual is not None or residual_global is not None + else tensor_in.numel() // num_ranks + ) + num_tokens = tensor_in.numel() // hidden_size + myrank = tensor_in._allocator.myrank + if residual_global is not None and ( + internal_residual is None or tensor_in._allocator.residual_tokens != num_tokens + ): + my_tokens = num_tokens // num_ranks + extra_tokens = num_tokens % num_ranks + first_token = myrank * my_tokens + if myrank < extra_tokens: + my_tokens += 1 + first_token += myrank + else: + first_token += extra_tokens + if my_tokens == 0: + my_tokens = 1 # avoid empty residual + if tensor_in._allocator.residual is not None: + del tensor_in._allocator.residual + tensor_in._allocator.residual = torch.empty( + my_tokens * hidden_size, dtype=tensor_in.dtype, device=tensor_in.device + ) + tensor_in._allocator.residual_tokens = num_tokens + internal_residual = tensor_in._allocator.residual + + residual_in = residual_global if residual_global is not None else internal_residual + + residual_out = ( + internal_residual if fuse_layernorm else None + ) # without layernorm new full residual is output of allreduce + if tensor_in.numel() > 1048576: + return tensor_in._allocator.allreduce_simple( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) + 
else: + return tensor_in._allocator.allreduce_lamport( + tensor_in, hidden_size, residual_in, residual_out, fuse_layernorm, gamma, eps + ) + + +def ubsymm_free_residual(tensor_in: SymmTensor): + if tensor_in._allocator.residual is not None: + del tensor_in._allocator.residual + tensor_in._allocator.residual_tokens = 0 + tensor_in._allocator.residual = None diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 38947c5a9d..d0eed0e716 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -260,12 +260,12 @@ void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) // Userbuffers data void *dst_ptr; if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, + NVTE_CHECK(_ubufs[_tp_id].numel() >= input_size, "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); dst_ptr = _ubufs[_tp_id].dptr(); } else { - NVTE_CHECK(_ubuf.numel() == input_size, + NVTE_CHECK(_ubuf.numel() >= input_size, "Tried to copy an invalid tensor into a Userbuffers buffer ", "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); dst_ptr = _ubuf.dptr(); @@ -282,11 +282,11 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -204,20 +209,25 @@ def forward( ) # Apply normalization - nvtx_range_push(f"{nvtx_label}.norm") - ln_out, mu, rsigma = apply_normalization( - inputmat, - None, # ln_out - ln_weight, - ln_bias, - eps, - input_quantizer if with_quantized_norm else None, - inputmat.dtype, - normalization, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - nvtx_range_pop(f"{nvtx_label}.norm") + if skip_layernorm: + ln_out = inputmat + mu = None + rsigma = None + else: + nvtx_range_push(f"{nvtx_label}.norm") + ln_out, mu, rsigma = apply_normalization( + inputmat, + None, # ln_out + ln_weight, + ln_bias, + eps, + input_quantizer if with_quantized_norm else None, + inputmat.dtype, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + ) + nvtx_range_pop(f"{nvtx_label}.norm") # Store unquantized layer norm output if we need to return it ln_out_return = None @@ -335,7 +345,20 @@ def forward( out_shape[0] //= tp_world_size out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) - + symm_out = None + if ( + symmetric_ar_type is not None + and symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): + out_shape_list = list(tuple(inp.shape)) + out_shape_list[-1] = out_features + symm_out = ubsymm_get_sym_tensor( + torch.Size(out_shape_list), + activation_dtype, + tp_group, + ) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -352,6 +375,7 @@ def forward( ub=ub_obj, ub_type=ub_type, extra_output=reduce_scatter_out, + out=symm_out, ) nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ @@ -380,7 +404,17 @@ def forward( out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + if symm_out is not None: + out = ubsymm_allreduce(symm_out) + else: + fallback_symmetric = ( + 
"multimem_all_reduce" + if symmetric_ar_type.startswith("ubnext") + else symmetric_ar_type + ) + out, _ = symmetric_all_reduce( + out, tp_group, all_reduce_type=fallback_symmetric + ) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1046,6 +1080,7 @@ def wgrad_gemm( None, # module None, # skip_fp8_weight_update None, # symmetric_ar_type + None, # skip_layernorm ) @@ -1175,6 +1210,7 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, + skip_layernorm: bool = False, name: str = None, ) -> None: super().__init__() @@ -1268,7 +1304,19 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - + if ( + self.symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): + ubsymm_request_allocator( + self.tp_group, + ( + int(os.environ.get("NVTE_UB_MAXBATCH", 64)), + self.out_features, + ), + params_dtype, + ) self.eps = eps layer_norm_weight = torch.nn.Parameter( torch.empty(self.in_features, device=device, dtype=params_dtype) @@ -1287,7 +1335,7 @@ def __init__( ) else: self.layer_norm_bias = None - + self.skip_layernorm = skip_layernorm # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1599,6 +1647,7 @@ def forward( self, skip_fp8_weight_update, self.symmetric_ar_type, + self.skip_layernorm, debug, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cf7f58947b..598433d34c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,6 +7,7 @@ from functools import reduce from operator import mul as multiply_op import warnings +import os import torch @@ -53,6 +54,9 @@ ) from ..cpp_extensions import ( general_gemm, + ubsymm_request_allocator, + ubsymm_get_sym_tensor, + ubsymm_allreduce, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -118,6 +122,9 @@ def forward( symmetric_ar_type: str, save_original_input: bool = False, debug: Optional[bool] = False, + residual: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + ln_weight: Optional[torch.Tensor] = None, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -301,6 +308,24 @@ def forward( out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) + symm_out = None + if ( + symmetric_ar_type is not None + and symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): + out_shape_list = list(tuple(inp.shape)) + out_shape_list[-1] = out_features + symm_out = ubsymm_get_sym_tensor( + torch.Size(out_shape_list), + activation_dtype, + tp_group, + ) + assert symm_out is not None or symmetric_ar_type == "ubnext", ( + "No symmetric pool out of space fallback for fused ops, increase" + " NVTE_UB_SYMM_POOL_SIZE" + ) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -317,6 +342,7 @@ def forward( ub=ub_obj, ub_type=ub_type, extra_output=reduce_scatter_out, + out=symm_out, ) nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ @@ -344,7 +370,19 @@ def forward( out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + if symm_out is not 
None: + out = ubsymm_allreduce( + symm_out, residual_global=residual, gamma=ln_weight, eps=eps + ) + else: + fallback_symmetric = ( + "multimem_all_reduce" + if symmetric_ar_type.startswith("ubnext") + else symmetric_ar_type + ) + out, _ = symmetric_all_reduce( + out, tp_group, all_reduce_type=fallback_symmetric + ) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1112,6 +1150,8 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + eps: Optional[float] = None, + ln_weight: Optional[torch.Tensor] = None, ) -> None: super().__init__() @@ -1200,7 +1240,21 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - + if ( + self.symmetric_ar_type.startswith("ubnext") + and parallel_mode == "row" + and tp_size > 1 + ): + ubsymm_request_allocator( + self.tp_group, + ( + int(os.environ.get("NVTE_UB_MAXBATCH", 64)), + self.out_features, + ), + params_dtype, + ) + self.eps = eps + self.layer_norm_weight = ln_weight # in general expected to be filled with reference to layernorm_weight from next LayerNormLinear later # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1366,6 +1420,7 @@ def forward( is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, fp8_grad: Optional[bool] = False, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -1478,6 +1533,9 @@ def forward( self.symmetric_ar_type, self.save_original_input, debug, + residual, + self.eps, + self.layer_norm_weight, ) out = linear_fn(*args) if self.gemm_bias_unfused_add:
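
The linear-layer changes above wire ubsymm_request_allocator, ubsymm_get_sym_tensor, and ubsymm_allreduce together internally. The snippet below is a minimal standalone sketch of that same call sequence (the standalone-allreduce mode, no residual or fused RMSNorm), not part of the patch. It assumes a torchrun-style launch with LOCAL_RANK set, one GPU per rank with NVLink multicast support, torch >= 2.7 symmetric memory, that the helpers are re-exported from transformer_engine.pytorch.cpp_extensions as in the import block this patch adds to linear.py, and that the returned SymmTensor behaves like a regular torch.Tensor for copy_; the shape and the fallback branch are illustrative.

    import os
    import torch
    import torch.distributed as dist

    from transformer_engine.pytorch.cpp_extensions import (
        ubsymm_request_allocator,
        ubsymm_get_sym_tensor,
        ubsymm_allreduce,
    )

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    group = dist.group.WORLD

    # 1) Register the largest tensor shape that will ever be reduced on this group.
    #    All requests must happen before the first ubsymm_get_sym_tensor call.
    shape = torch.Size([64, 8192])
    ubsymm_request_allocator(group, shape, dtype=torch.bfloat16)

    # 2) Grab a tensor backed by the symmetric-memory pool (None means fall back).
    symm = ubsymm_get_sym_tensor(shape, torch.bfloat16, group)
    if symm is not None:
        symm.copy_(torch.randn(shape, dtype=torch.bfloat16, device=symm.device))
        # 3) Standalone mode: no residual and no fused RMSNorm (gamma/eps left as None).
        out = ubsymm_allreduce(symm)
    else:
        # Pool exhausted or unsupported dtype: use a regular NCCL allreduce instead.
        out = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
        dist.all_reduce(out, group=group)

    dist.destroy_process_group()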
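SymmAllocator manages its pool with a first-fit free list that coalesces adjacent segments on free. The following CPU-only toy (a hypothetical ToyPool that tracks offsets and sizes only, touching no real memory) models just that bookkeeping so the merge behaviour can be checked without GPUs, multicast, or torch; it is a simplified sketch, not the allocator itself.

    from typing import List, Optional, Tuple


    class ToyPool:
        def __init__(self, pool_size: int, reserved: int = 0):
            # Mirrors SymmAllocator's bookkeeping: allocated and free (offset, size) lists.
            self.allocated: List[Tuple[int, int]] = []
            self.freelist: List[Tuple[int, int]] = [(reserved, pool_size - reserved)]

        def allocate(self, nbytes: int) -> Optional[int]:
            for i, (off, size) in enumerate(self.freelist):
                if size >= nbytes:  # first fit
                    self.freelist.pop(i)
                    self.allocated.append((off, nbytes))
                    if size > nbytes:
                        self.freelist.append((off + nbytes, size - nbytes))
                    return off
            return None  # exhausted: caller falls back

        def free(self, off: int) -> None:
            for i, (aoff, size) in enumerate(self.allocated):
                if aoff == off:
                    self.allocated.pop(i)
                    self.freelist.append((off, size))
                    self.freelist.sort()
                    self._merge()
                    return
            raise ValueError(f"unknown offset {off}")

        def _merge(self) -> None:
            merged = [self.freelist[0]]
            for off, size in self.freelist[1:]:
                last_off, last_size = merged[-1]
                if last_off + last_size == off:  # adjacent: coalesce
                    merged[-1] = (last_off, last_size + size)
                else:
                    merged.append((off, size))
            self.freelist = merged


    pool = ToyPool(1 << 20, reserved=4096)
    a = pool.allocate(1000)
    b = pool.allocate(2000)
    pool.free(a)
    pool.free(b)  # coalesces back into a single free segment
    assert pool.freelist == [(4096, (1 << 20) - 4096)]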