Commit aa4d15e

Fix tensor_meta in LayerNorm prop rules (#269)
- The LayerNorm prop rule assumes the output tensor_meta is the same as the input tensor_meta. This assumption is invalid when the input to LayerNorm is a view: in that case the output tensor_meta is contiguous.
- The same bug exists in the backward rule.

Testing: added a unit test of permute -> LayerNorm that exposed the bug.
1 parent e794cc2 commit aa4d15e
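
For context on the bug: LayerNorm returns a freshly allocated, contiguous tensor even when its input is a non-contiguous view, so the output strides can differ from the input strides although the shapes match. This is what invalidates reusing the input tensor_meta for the output. A minimal eager-mode illustration (sizes chosen arbitrarily, not part of the commit):

import torch
from torch import nn

x = torch.rand(2, 8, 4, 4).permute(0, 2, 3, 1)  # (N, H, W, C) view, non-contiguous
print(x.is_contiguous())                        # False
out = nn.LayerNorm(8)(x)                        # normalizes over the last dim (C)
print(out.is_contiguous())                      # True: the output is allocated contiguous
print(x.stride(), out.stride())                 # strides differ even though shapes match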

File tree

2 files changed: +131 -5 lines changed


autoparallel/propagation_rules.py

Lines changed: 19 additions & 5 deletions
@@ -501,9 +501,19 @@ def native_layer_norm_rule(mesh, op_schema):
         if is_valid:
             output_spec = strategy.output_specs
             input_spec = strategy.input_specs[0]
-            output_spec.tensor_meta = input_spec.tensor_meta
-            assert output_spec.tensor_meta is not None
             mesh = strategy.mesh
+
+            # Create output tensor_meta with same shape as input but contiguous strides
+            # (LayerNorm forward returns contiguous tensor even if input was non-contiguous)
+            output_tensor_meta = _gen_tensor_meta(
+                input_spec.tensor_meta.shape, input_spec.tensor_meta.dtype
+            )
+            output_spec = DTensorSpec(
+                mesh=mesh,
+                placements=output_spec.placements,
+                tensor_meta=output_tensor_meta,
+            )
+
             # the output spec is the same as input spec
             shape = input_spec.tensor_meta.shape[:axis] + (1,) * len(normalized_size)
             mean_std_tgt_spec = DTensorSpec(
@@ -565,15 +575,19 @@ def native_layer_norm_backward_rule(mesh, op_schema):
                 break
         if is_valid:
             mesh = strategy.mesh
+            # Create grad_input tensor_meta with same shape as input but contiguous strides
+            # (LayerNorm backward returns contiguous gradient even if input was non-contiguous)
+            grad_input_tensor_meta = _gen_tensor_meta(
+                input_spec.tensor_meta.shape, input_spec.tensor_meta.dtype
+            )
             grad_input_spec = DTensorSpec(
                 mesh=mesh,
                 placements=strategy.output_specs.placements,
-                tensor_meta=strategy.output_specs.tensor_meta,
+                tensor_meta=grad_input_tensor_meta,
             )
+            assert grad_input_spec.tensor_meta is not None
             weight_spec = strategy.input_specs[4]
             bias_spec = strategy.input_specs[5]
-            grad_input_spec.tensor_meta = input_spec.tensor_meta
-            assert grad_input_spec.tensor_meta is not None
             weight_tgt_spec = DTensorSpec(
                 mesh=mesh,
                 placements=weight_spec.placements,
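
Both rules now build the output/grad_input tensor_meta via _gen_tensor_meta from the input's shape and dtype, rather than reusing the (possibly non-contiguous) input tensor_meta. The helper's implementation is not part of this diff; as a rough, hypothetical sketch of what such a helper has to produce (a TensorMeta with default contiguous strides; the TensorMeta import path below is an assumption and varies across PyTorch versions):

import torch
# Assumed location of TensorMeta; older releases expose it elsewhere.
from torch.distributed.tensor._dtensor_spec import TensorMeta


def gen_contiguous_tensor_meta(shape, dtype):
    # A meta tensor materializes default (contiguous) strides for the given
    # shape without allocating real memory.
    t = torch.empty(shape, dtype=dtype, device="meta")
    return TensorMeta(shape=t.shape, stride=t.stride(), dtype=t.dtype)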

tests/test_propagation_rules.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from torch import nn
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.distributed.tensor.placement_types import Shard
+from torch.testing._internal.distributed.fake_pg import FakeStore
+
+from autoparallel.api import AutoParallel
+
+
+@pytest.fixture(scope="module", autouse=True)
+def init_pg():
+    world_size = 256
+    fake_store = FakeStore()
+    if torch.distributed.is_initialized():
+        return
+    torch.distributed.init_process_group(
+        "fake", store=fake_store, rank=0, world_size=world_size
+    )
+
+
+@pytest.fixture(scope="module")
+def device_mesh_1d():
+    world_size = torch.distributed.get_world_size()
+    mesh = torch.distributed.device_mesh.init_device_mesh(
+        "cuda", (world_size,), mesh_dim_names=("dp",)
+    )
+    return mesh
+
+
+def test_permute_layernorm_stride_handling(device_mesh_1d):
+    """Test that permute + layernorm handles non-contiguous to contiguous stride transitions.
+
+    This test reproduces the stride mismatch bug in ConvNeXt-style architectures where:
+    1. First permute creates a non-contiguous tensor (view) with stride (301056, 56, 1, 3136)
+    2. LayerNorm receives non-contiguous input but returns a contiguous tensor
+    3. Second permute creates another non-contiguous tensor (view)
+    """
+
+    class PermuteLayerNormNet(nn.Module):
+        """Network with permute -> LayerNorm -> permute."""
+
+        def __init__(self, channels):
+            super().__init__()
+            self.norm = nn.LayerNorm(channels, eps=1e-6)
+
+        def forward(self, x):
+            # (N, C, H, W) -> (N, H, W, C)
+            x = x.permute(0, 2, 3, 1)
+            # LayerNorm on last dim (C)
+            x = self.norm(x)
+            # (N, H, W, C) -> (N, C, H, W)
+            x = x.permute(0, 3, 1, 2)
+            return x
+
+    batch_size = 256
+    channels = 96
+    height = 56
+    width = 56
+
+    def input_fn():
+        return torch.rand(batch_size, channels, height, width, device="cuda")
+
+    # Create model on meta device
+    with torch.device("meta"):
+        model = PermuteLayerNormNet(channels=channels)
+
+    # Mixed precision policy
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.float32, reduce_dtype=torch.float32
+    )
+
+    # This should not raise an AssertionError about tensor_meta stride mismatch.
+    with AutoParallel(
+        model, input_fn, device_mesh_1d, mp_policy, compile=True
+    ) as autop:
+        x_sharding = (Shard(0),)
+        y_sharding = (Shard(0),)
+
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([y_sharding])
+
+        sharding_placement = autop.optimize_placement()
+
+        # Apply the optimized placement
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Initialize the parallel module
+    parallel_mod.to_empty(device="cuda")
+
+    for name, param in parallel_mod.named_parameters():
+        if "weight" in name:
+            torch.nn.init.ones_(param)
+        elif "bias" in name:
+            torch.nn.init.zeros_(param)
+
+    # Test forward pass execution works
+    local_batch_size = batch_size // torch.distributed.get_world_size()
+    x_test = torch.rand(local_batch_size, channels, height, width, device="cuda")
+    out = parallel_mod(x_test)
+
+    # Verify output shape (should match input after permute -> norm -> permute)
+    assert out.shape == (local_batch_size, channels, height, width)
+    # Output may be non-contiguous due to final permute (view operation)
+
+    # Verify forward execution produces correct output
+    assert out.abs().sum() > 0
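
Assuming a CUDA device and the repository's usual pytest setup, the new test should run on a single process (the fake process group removes the need for a multi-GPU launch), e.g.:

pytest tests/test_propagation_rules.py -k test_permute_layernorm_stride_handling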
