removing op_name in the forward pass by adding OpNameContext
xrsrke committed Mar 7, 2025
1 parent d3d8c10 commit 74d415c
Showing 5 changed files with 97 additions and 35 deletions.
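In short: call sites no longer pass op_name through every forward(); they wrap the call in an OpNameContext, and the tensor-parallel layers recover the name themselves via get_op_name(). A minimal sketch of the new pattern (imports and format strings come from the diff below; the index values 0 and 0 are purely illustrative):

from nanotron.parallel.tensor_parallel.domino import (
    FWD_ATTN_OP_NAME,
    OpNameContext,
    get_op_name,
)

# The caller names the operation for the duration of the block...
with OpNameContext(FWD_ATTN_OP_NAME.format(0, 0)):
    # ...and anything invoked inside the block (e.g. column_linear / row_linear)
    # can look the name up instead of receiving it as an argument.
    assert get_op_name() == "fwd.layer_attn_0_batch_0"

# Once the block exits, the previous name (here: none) is restored.
assert get_op_name() is None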
49 changes: 21 additions & 28 deletions src/nanotron/models/llama.py
@@ -39,6 +39,7 @@
     BWD_MLP_OP_NAME,
     FWD_ATTN_OP_NAME,
     FWD_MLP_OP_NAME,
+    OpNameContext,
 )
 from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
 from nanotron.parallel.tensor_parallel.nn import (
@@ -250,11 +251,9 @@ def __init__(
         )
         self.split_silu_mul = GLUActivation(config.hidden_act)

-    def forward(
-        self, hidden_states: torch.Tensor, op_name: Optional[str] = None
-    ):  # [seq_length, batch_size, hidden_dim]
-        merged_states = self.gate_up_proj(hidden_states, op_name=op_name)
-        hidden_states = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name)
+    def forward(self, hidden_states: torch.Tensor):  # [seq_length, batch_size, hidden_dim]
+        merged_states = self.gate_up_proj(hidden_states)
+        hidden_states = self.down_proj(self.split_silu_mul(merged_states))
         return {"hidden_states": hidden_states}


@@ -449,9 +448,6 @@ def forward(
         self,
         hidden_states,  # [seq_length, batch_size, hidden_size]
         sequence_mask,  # [batch_size, seq_length]
-        # NOTE: because we dynamically determine which input split
-        # of domino at runtime, so we need to pass in the op_name
-        op_name: Optional[str] = None,
     ):
         from flash_attn import bert_padding
         from flash_attn.flash_attn_interface import (
@@ -460,7 +456,7 @@
         )

         qkv_states = self.qkv_proj(
-            hidden_states, op_name=op_name
+            hidden_states
         )  # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
         q_length, batch_size, _ = qkv_states.shape

@@ -706,7 +702,7 @@ def forward(
             attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
         )
         # output, work = self.o_proj(attention_output, op_name=op_name)
-        output = self.o_proj(attention_output, op_name=op_name)
+        output = self.o_proj(attention_output)

         return {"hidden_states": output, "sequence_mask": sequence_mask}

@@ -801,16 +797,17 @@ def _core_forward(
             self.stream_manager,
         )

-        attn_output0 = self.attn(
-            hidden_states=hidden_states0,
-            sequence_mask=sequence_mask0,
-            op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0),
-        )
-        attn_output1 = self.attn(
-            hidden_states=hidden_states1,
-            sequence_mask=sequence_mask1,
-            op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1),
-        )
+        with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)):
+            attn_output0 = self.attn(
+                hidden_states=hidden_states0,
+                sequence_mask=sequence_mask0,
+            )
+
+        with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)):
+            attn_output1 = self.attn(
+                hidden_states=hidden_states1,
+                sequence_mask=sequence_mask1,
+            )

         comm_stream.wait_stream(torch.cuda.default_stream())
         with torch.cuda.stream(comm_stream):
@@ -827,10 +824,8 @@ def _core_forward(
             self.stream_manager,
         )

-        mlp_output0 = self.mlp(
-            hidden_states=hidden_states0,
-            op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0),
-        )
+        with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 0)):
+            mlp_output0 = self.mlp(hidden_states=hidden_states0)

         comm_stream.wait_stream(torch.cuda.default_stream())
         with torch.cuda.stream(comm_stream):
@@ -842,10 +837,8 @@ def _core_forward(
         residual1 = hidden_states1
         hidden_states1 = self.post_attention_layernorm(hidden_states1)

-        mlp_output1 = self.mlp(
-            hidden_states=hidden_states1,
-            op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1),
-        )
+        with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 1)):
+            mlp_output1 = self.mlp(hidden_states=hidden_states1)

         comm_stream.wait_stream(torch.cuda.default_stream())
         with torch.cuda.stream(comm_stream):
31 changes: 31 additions & 0 deletions src/nanotron/parallel/tensor_parallel/domino.py
@@ -1,10 +1,14 @@
 import re
+import threading
+from typing import Optional

 FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}"
 FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}"
 BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"
 BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}"

+_operation_context = threading.local()
+

 def is_domino_async_comm(x: str) -> bool:
     """
@@ -20,3 +24,30 @@ def is_domino_async_comm(x: str) -> bool:
     regex = re.compile("^(" + "|".join(patterns) + ")$")  # Combine patterns into a single regex
     not_async = bool(regex.match(x))
     return not not_async
+
+
+class OpNameContext:
+    """
+    A context manager to set the name of a module operation
+    """
+
+    def __init__(self, op_name: str):
+        self.op_name = op_name
+        self.previous_op_name = None
+
+    def __enter__(self):
+        if not hasattr(_operation_context, "current_op_name"):
+            _operation_context.current_op_name = None
+        self.previous_op_name = _operation_context.current_op_name
+        _operation_context.current_op_name = self.op_name
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _operation_context.current_op_name = self.previous_op_name
+
+
+def get_op_name() -> Optional[str]:
+    """
+    Get the name of the current operation.
+    """
+    return getattr(_operation_context, "current_op_name", None)
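Since _operation_context is a threading.local, the op name set by OpNameContext is per-thread state: a name set on one thread is not visible from another. A small sketch illustrating that property (this snippet is not part of the commit; it only assumes the module path added above):

import threading

from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name

seen_in_worker = []

def worker():
    # The worker thread gets fresh thread-local storage, so no op name is set here.
    seen_in_worker.append(get_op_name())

with OpNameContext("fwd.layer_mlp_0_batch_1"):
    t = threading.Thread(target=worker)
    t.start()
    t.join()

assert seen_in_worker == [None]  # the worker never saw the caller's op name
assert get_op_name() is None     # and the main thread's name is cleared on exit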
5 changes: 3 additions & 2 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -25,6 +25,7 @@
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
+from nanotron.parallel.tensor_parallel.domino import get_op_name
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
 from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1

@@ -437,13 +438,13 @@ def column_linear(
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
     tp_recompute_allgather: bool = True,
-    op_name: Optional[str] = None,
     stream_manager: Optional[CudaStreamManager] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)

     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
+        op_name = get_op_name()
         input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
@@ -591,7 +592,6 @@ def row_linear(
     group: dist.ProcessGroup,
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
-    op_name: Optional[str] = None,
     stream_manager: Optional[CudaStreamManager] = None,
 ):
     if async_communication:
@@ -600,6 +600,7 @@
     out = F.linear(input, weight, bias)

     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
+        op_name = get_op_name()
         out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         out = differentiable_reduce_scatter_sum(out, group=group)
5 changes: 1 addition & 4 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -91,7 +91,6 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,
-        op_name: Optional[str] = None,
     ) -> torch.Tensor:
         return column_linear(
             input=x,
@@ -101,7 +100,6 @@ def forward(
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
-            op_name=op_name,
             stream_manager=self.stream_manager,
         )

@@ -169,15 +167,14 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
             )
             setattr(self, name, new_param)

-    def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return row_linear(
             input=x,
             weight=self.weight,
             bias=self.bias,
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
-            op_name=op_name,
             stream_manager=self.stream_manager,
         )

42 changes: 41 additions & 1 deletion tests/test_domino.py
@@ -9,7 +9,7 @@
 from nanotron.models.llama import DominoLlamaDecoderLayer
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.comm import CudaStreamManager
-from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm
+from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name, is_domino_async_comm


 @pytest.mark.parametrize(
@@ -72,3 +72,43 @@ def _test_domino_model(

     assert isinstance(outputs["loss"], torch.Tensor)
     assert stream_manager.comm_bucket.is_all_completed() is True
+
+
+### OpNameContext tests ###
+
+
+def test_op_name_context_reentry():
+    assert get_op_name() is None
+    context = OpNameContext("reusable_op")
+
+    with context:
+        assert get_op_name() == "reusable_op"
+
+    assert get_op_name() is None
+
+    with context:
+        assert get_op_name() == "reusable_op"
+
+    assert get_op_name() is None
+
+
+def test_deeply_nested_contexts():
+    with OpNameContext("level1"):
+        assert get_op_name() == "level1"
+
+        with OpNameContext("level2"):
+            assert get_op_name() == "level2"
+
+        assert get_op_name() == "level1"
+
+
+def test_multiple_sequential_contexts():
+    assert get_op_name() is None
+
+    with OpNameContext("first_op"):
+        assert get_op_name() == "first_op"
+
+    with OpNameContext("second_op"):
+        assert get_op_name() == "second_op"
+
+    assert get_op_name() is None
