From 74d415c1c02b9463214fb46db060c0efbfa5a0e4 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 2025 18:43:08 +0000 Subject: [PATCH] removing op_name in the forward pass by adding OpNameContext --- src/nanotron/models/llama.py | 49 ++++++++----------- .../parallel/tensor_parallel/domino.py | 31 ++++++++++++ .../parallel/tensor_parallel/functional.py | 5 +- src/nanotron/parallel/tensor_parallel/nn.py | 5 +- tests/test_domino.py | 42 +++++++++++++++- 5 files changed, 97 insertions(+), 35 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ef182c7b..0fb2af4d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -39,6 +39,7 @@ BWD_MLP_OP_NAME, FWD_ATTN_OP_NAME, FWD_MLP_OP_NAME, + OpNameContext, ) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -250,11 +251,9 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward( - self, hidden_states: torch.Tensor, op_name: Optional[str] = None - ): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, op_name=op_name) - hidden_states = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) + def forward(self, hidden_states: torch.Tensor): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states) + hidden_states = self.down_proj(self.split_silu_mul(merged_states)) return {"hidden_states": hidden_states} @@ -449,9 +448,6 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] - # NOTE: because we dynamically determine which input split - # of domino at runtime, so we need to pass in the op_name - op_name: Optional[str] = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -460,7 +456,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, op_name=op_name + hidden_states ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -706,7 +702,7 @@ def forward( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) # output, work = self.o_proj(attention_output, op_name=op_name) - output = self.o_proj(attention_output, op_name=op_name) + output = self.o_proj(attention_output) return {"hidden_states": output, "sequence_mask": sequence_mask} @@ -801,16 +797,17 @@ def _core_forward( self.stream_manager, ) - attn_output0 = self.attn( - hidden_states=hidden_states0, - sequence_mask=sequence_mask0, - op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0), - ) - attn_output1 = self.attn( - hidden_states=hidden_states1, - sequence_mask=sequence_mask1, - op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1), - ) + with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)): + attn_output0 = self.attn( + hidden_states=hidden_states0, + sequence_mask=sequence_mask0, + ) + + with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)): + attn_output1 = self.attn( + hidden_states=hidden_states1, + sequence_mask=sequence_mask1, + ) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): @@ -827,10 +824,8 @@ def _core_forward( self.stream_manager, ) - mlp_output0 = self.mlp( - hidden_states=hidden_states0, - op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0), - ) + with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 0)): + mlp_output0 = 
self.mlp(hidden_states=hidden_states0) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): @@ -842,10 +837,8 @@ def _core_forward( residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - mlp_output1 = self.mlp( - hidden_states=hidden_states1, - op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1), - ) + with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 1)): + mlp_output1 = self.mlp(hidden_states=hidden_states1) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 9bfe79d6..d7a98f04 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,10 +1,14 @@ import re +import threading +from typing import Optional FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}" BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}" +_operation_context = threading.local() + def is_domino_async_comm(x: str) -> bool: """ @@ -20,3 +24,30 @@ def is_domino_async_comm(x: str) -> bool: regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex not_async = bool(regex.match(x)) return not not_async + + +class OpNameContext: + """ + A context manager to set the name of a module operation + """ + + def __init__(self, op_name: str): + self.op_name = op_name + self.previous_op_name = None + + def __enter__(self): + if not hasattr(_operation_context, "current_op_name"): + _operation_context.current_op_name = None + self.previous_op_name = _operation_context.current_op_name + _operation_context.current_op_name = self.op_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _operation_context.current_op_name = self.previous_op_name + + +def get_op_name() -> Optional[str]: + """ + Get the name of the current operation. 
+ """ + return getattr(_operation_context, "current_op_name", None) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 57ca7446..03357160 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -25,6 +25,7 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) +from nanotron.parallel.tensor_parallel.domino import get_op_name from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -437,13 +438,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, - op_name: Optional[str] = None, stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + op_name = get_op_name() input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: @@ -591,7 +592,6 @@ def row_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, - op_name: Optional[str] = None, stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: @@ -600,6 +600,7 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + op_name = get_op_name() out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 53f1f930..f92fe0ee 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -91,7 +91,6 @@ def __init__( def forward( self, x: torch.Tensor, - op_name: Optional[str] = None, ) -> torch.Tensor: return column_linear( input=x, @@ -101,7 +100,6 @@ def forward( tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, - op_name=op_name, stream_manager=self.stream_manager, ) @@ -169,7 +167,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -177,7 +175,6 @@ def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tenso group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - op_name=op_name, stream_manager=self.stream_manager, ) diff --git a/tests/test_domino.py b/tests/test_domino.py index b9f93a24..e3182de4 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -9,7 +9,7 @@ from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext from nanotron.parallel.comm import CudaStreamManager -from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name, is_domino_async_comm @pytest.mark.parametrize( @@ -72,3 
+72,43 @@ def _test_domino_model( assert isinstance(outputs["loss"], torch.Tensor) assert stream_manager.comm_bucket.is_all_completed() is True + + +### OpNameContext tests ### + + +def test_op_name_context_reentry(): + assert get_op_name() is None + context = OpNameContext("reusable_op") + + with context: + assert get_op_name() == "reusable_op" + + assert get_op_name() is None + + with context: + assert get_op_name() == "reusable_op" + + assert get_op_name() is None + + +def test_deeply_nested_contexts(): + with OpNameContext("level1"): + assert get_op_name() == "level1" + + with OpNameContext("level2"): + assert get_op_name() == "level2" + + assert get_op_name() == "level1" + + +def test_multiple_sequential_contexts(): + assert get_op_name() is None + + with OpNameContext("first_op"): + assert get_op_name() == "first_op" + + with OpNameContext("second_op"): + assert get_op_name() == "second_op" + + assert get_op_name() is None
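
Usage sketch for reviewers: a minimal, illustrative example of the new context-manager
API added in src/nanotron/parallel/tensor_parallel/domino.py. All names below come from
this patch; the layer index and micro-batch id are made-up values used only for
illustration.

    from nanotron.parallel.tensor_parallel.domino import (
        FWD_ATTN_OP_NAME,
        OpNameContext,
        get_op_name,
    )

    # Call sites no longer thread op_name through forward(); they wrap the call in
    # OpNameContext, and column_linear/row_linear pick the name up via get_op_name().
    layer_idx, micro_batch = 0, 1  # hypothetical values
    assert get_op_name() is None
    with OpNameContext(FWD_ATTN_OP_NAME.format(layer_idx, micro_batch)):
        # e.g. attn_output = self.attn(hidden_states=..., sequence_mask=...)
        assert get_op_name() == "fwd.layer_attn_0_batch_1"
    assert get_op_name() is None

    # OpNameContext restores the previous name on exit, so contexts can be nested or
    # reused sequentially, as exercised by the new tests in tests/test_domino.py.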