Commit dd64fb7

Prototype to run AutoParallel PP with Local Tensor
ghstack-source-id: 1fd148b
Pull Request resolved: #252
1 parent f609a73 commit dd64fb7

7 files changed (+1069, -151 lines)

autoparallel/_testing/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.nn as nn
from autoparallel.graph_pp_runner import GraphPipelineStage
from torch._C._distributed_c10d import FakeWork, PythonCallbackWork
from torch.distributed import DeviceMesh
from torch.distributed._local_tensor import (
    local_tensor_mode,
    LocalIntNode,
    LocalRunnerMode,
    LocalTensor,
    LocalTensorMode,
    maybe_disable_local_tensor_mode,
)
from torch.distributed._local_tensor._c10d import local_p2p_op
from torch.distributed.pipelining.stage import InputInfo, PipelineStage
from torch.distributed.tensor import DTensor
from torch.export._unlift import _assign_attr
from torch.export.unflatten import _AttrKind


_pp_groups: list[list[int]] = []

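# Build a LocalTensorMode that simulates, in this single process, every global
# rank in the dp/ep sub-mesh belonging to the given pp rank.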
def create_local_tensor_mode(dp_ep_mesh: DeviceMesh, pp_rank: int) -> LocalTensorMode:
    dp_ep_full_mesh = dp_ep_mesh._layout.remap_to_tensor(dp_ep_mesh._rank_map)
    dp_ep_ranks = dp_ep_full_mesh[pp_rank].flatten().tolist()
    print(f"Creating local tensor mode for ranks {dp_ep_ranks}")
    return LocalTensorMode(frozenset(dp_ep_ranks))

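# Cache the global-rank membership of every pp group in module-level _pp_groups
# so that get_pp_peer() can translate group-local peers into global ranks.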
def cache_pp_groups(pp_mesh: DeviceMesh) -> list[list[int]]:
    pp_full_mesh = pp_mesh._layout.remap_to_tensor(pp_mesh._rank_map)
    pp_groups = []
    for i in range(pp_full_mesh.size(dim=0)):
        pp_group = pp_full_mesh[i].tolist()
        pp_groups.append(pp_group)
    global _pp_groups
    _pp_groups = pp_groups
    return pp_groups

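# Fold several Work handles into one: waiting on the returned Work waits on all
# of the underlying works.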
def combine_works(works: list[dist.Work], ctx: str | None = None) -> dist.Work:
    def _wait_all(timeout) -> bool:
        for w in works:
            w.wait()
        return True

    return PythonCallbackWork(_wait_all)

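# For a group-local (rank, peer) pair, map each cached pp group's global rank at
# index `self` to its global rank at index `peer`, and wrap the mapping in a
# SymInt backed by LocalIntNode so the peer value can differ per simulated rank.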
def get_pp_peer(self: int, peer: int) -> torch.SymInt:
    pp_ret = {}
    global _pp_groups
    for pp_group in _pp_groups:
        global_rank = pp_group[self]
        global_peer = pp_group[peer]
        pp_ret[global_rank] = global_peer
    return torch.SymInt(LocalIntNode(pp_ret))

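# Rewrite batched isend/irecv ops so each one fans out to all simulated ranks
# via local_p2p_op, and reorder sends ahead of recvs since recvs block here.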
def expand_p2p_ops(
    ops: list[dist.P2POp], pp_rank: int, ctx: str | None = None
) -> list[dist.P2POp]:
    # Ops were generated from the perspective of the pp group where rank 0 is present.

    def multi_isend(tensor, dst=None, group=None, tag=0, group_src=None):
        assert group_src is not None, "Expected group rank"
        peer = get_pp_peer(pp_rank, group_src)
        if not isinstance(tensor, LocalTensor):
            tensor = maybe_make_tensor_local(tensor)
        works = local_p2p_op(peer, tensor, dist.isend)
        return FakeWork()

    def multi_irecv(tensor, src=None, group=None, tag=0, group_src=None):
        assert group_src is not None, "Expected group rank"
        peer = get_pp_peer(pp_rank, group_src)
        assert isinstance(tensor, LocalTensor), "Expected LocalTensor"
        works = local_p2p_op(peer, tensor, dist.irecv)
        return combine_works(works)

    send_ops = []
    recv_ops = []
    for p2p_op in ops:
        op = p2p_op.op
        if op is dist.isend:
            p2p_op.op = multi_isend
            send_ops.append(p2p_op)
        elif op is dist.irecv:
            p2p_op.op = multi_irecv
            recv_ops.append(p2p_op)
        else:
            raise AssertionError(f"Unexpected op {op}")

    # Execute send ops first and then recv because the latter are blocking
    return send_ops + recv_ops

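# Pipeline stage that post-processes the P2P ops produced by GraphPipelineStage
# so that they can run under LocalTensor simulation.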
class LocalGraphPipelineStage(GraphPipelineStage):
    def log_name(self) -> str:
        return (
            f"PP rank {self.group_rank} Stage {self.stage_index} of {self.num_stages}"
        )

    def _get_recv_ops(self, recv_infos: tuple[InputInfo, ...]) -> list[dist.P2POp]:
        ops = super()._get_recv_ops(recv_infos)
        ops = expand_p2p_ops(ops, self.group_rank, self.log_name() + " _get_recv_ops")
        return ops

    def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
        ops = super().get_fwd_send_ops(fwd_chunk_id)
        ops = expand_p2p_ops(
            ops, self.group_rank, self.log_name() + " get_fwd_send_ops"
        )
        return ops

    def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]:
        ops = super().get_bwd_send_ops(bwd_chunk_id)
        ops = expand_p2p_ops(
            ops, self.group_rank, self.log_name() + " get_bwd_send_ops"
        )
        return ops

    def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]:
        ops = super()._get_init_p2p_neighbors_ops()
        ops = expand_p2p_ops(
            ops, self.group_rank, self.log_name() + " _get_init_p2p_neighbors_ops"
        )
        return ops

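# Return an enabled LocalTensorMode: prefer the one passed in, otherwise fall
# back to the ambient mode; return None if neither is enabled.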
def local_tensor_mode_if_enabled(
    ltm: LocalTensorMode | None = None,
) -> LocalTensorMode | None:

    for _ in range(2):
        if ltm is not None and not ltm._disable:
            return ltm
        ltm = local_tensor_mode()

    return None

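# Convert a plain tensor into a LocalTensor with an independent copy per
# simulated rank (DTensors are converted in place through _local_tensor).
# No-op when local tensor mode is not enabled.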
def maybe_make_tensor_local(
    tensor: torch.Tensor,
    ltm: LocalTensorMode | None = None,
) -> torch.Tensor:
    ltm = local_tensor_mode_if_enabled(ltm)
    if ltm is None:
        return tensor

    if isinstance(tensor, LocalTensor):
        return tensor

    if isinstance(tensor, DTensor):
        tensor._local_tensor = maybe_make_tensor_local(tensor._local_tensor, ltm)
        return tensor

    local_tensor = ltm.rank_map(lambda r: tensor.clone().detach())
    local_tensor.requires_grad = tensor.requires_grad
    return local_tensor

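# Rewrite a module's parameters and buffers in place as LocalTensors so the
# module can run under local tensor mode.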
def maybe_make_module_local(
    module: nn.Module,
    ltm: LocalTensorMode | None = None,
) -> None:
    ltm = local_tensor_mode_if_enabled(ltm)
    if ltm is None:
        return
    print(f"maybe_make_module_local {ltm.ranks}")

    for k, v in module.named_parameters():
        _assign_attr(
            nn.Parameter(
                data=maybe_make_tensor_local(v.data, ltm),
                requires_grad=v.requires_grad,
            ),
            module,
            k,
            attr_kind=_AttrKind.PARAMETER,
        )

    for k, v in module.named_buffers():
        _assign_attr(
            maybe_make_tensor_local(v, ltm), module, k, attr_kind=_AttrKind.BUFFER
        )
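The helpers above are building blocks; the sketch below (not part of this commit) shows one way they might be wired together. The names world_mesh, pp_rank, and stage_module are hypothetical, and constructing the DeviceMesh and the pipeline schedule is outside this diff.

# Rough usage sketch -- assumptions are marked; nothing here comes from the commit itself.
pp_mesh = world_mesh["pp"]  # pipeline dim of a pre-built DeviceMesh (hypothetical)
dp_mesh = world_mesh["dp"]  # data-parallel dim of the same mesh (hypothetical)

# Remember the global-rank membership of every pp group so get_pp_peer() can
# resolve group-local peers to global ranks.
cache_pp_groups(pp_mesh)

# Simulate all ranks of this pp rank's dp sub-mesh inside one process.
ltm = create_local_tensor_mode(dp_mesh, pp_rank)

# Assuming LocalTensorMode is usable as a context manager (as torch dispatch
# modes generally are), activate it and localize the stage's state.
with ltm:
    maybe_make_module_local(stage_module, ltm)
    # A LocalGraphPipelineStage built from this module rewrites its P2P ops via
    # expand_p2p_ops(), so a pipeline schedule can then run under LocalRunnerMode.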

autoparallel/api.py

Lines changed: 5 additions & 5 deletions
@@ -7,11 +7,13 @@
 import functools
 import itertools
 import warnings
-from contextlib import ExitStack, contextmanager
+from contextlib import contextmanager, ExitStack
 from types import MethodType
 from typing import Any, Callable, Optional, Union
 
 import torch
+
+from autoparallel._passes.graph_partition import partition_joint_with_descriptors
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
@@ -29,8 +31,6 @@
 from torch.export.unflatten import _AttrKind
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
-from autoparallel._passes.graph_partition import partition_joint_with_descriptors
-
 from .activation_checkpointing import ac_joint_pass
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -44,9 +44,9 @@
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
 from .utils import (
-    NumericsLogger,
     _get_device_from_mesh,
     debug_boxed_nop_preserve_node_meta,
+    NumericsLogger,
 )
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
@@ -120,7 +120,7 @@ def _move_to_fake(module, k, device, parameter=True):
 # can patch the verification logic.
 @contextmanager
 def monkey_patch_export_verifier():
-    from torch._export.verifier import SpecViolationError, Verifier, final
+    from torch._export.verifier import final, SpecViolationError, Verifier
 
     prior = Verifier._check_graph_module
 
autoparallel/utils.py

Lines changed: 5 additions & 5 deletions
@@ -8,6 +8,9 @@
 
 import torch
 import torch.utils._pytree as pytree
+
+from autoparallel.propagation_rules import generate_dummy_redistribute_costs
+from torch.distributed._local_tensor import LocalTensor
 from torch.distributed._tensor.placement_types import Placement, TensorMeta
 from torch.distributed.device_mesh import _get_device_handle
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -22,14 +25,12 @@
 from torch.distributed.tensor.placement_types import Replicate
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from autoparallel.propagation_rules import generate_dummy_redistribute_costs
-
 from .dtensor_util import get_op_strategy, with_implicit_strategies
 from .propagation_rules import (
-    TENSOR_FACTORY_OPS,
     _op_partial_rules,
     _op_rules,
     remove_invalid_configs,
+    TENSOR_FACTORY_OPS,
 )
 
 
@@ -466,7 +467,7 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_lo
                     if name not in real_params:
                         continue
                     param = real_params[name]
-                    param_logs.append(f"{name=} hash={hash_tensor(param)}")
+                    param_logs.append(f"rank={name=} hash={hash_tensor(param)}")
                 with open(path, "a") as f:
                     f.write("\n".join(param_logs) + "\n")
             torch.distributed.barrier()
@@ -490,7 +491,6 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_lo
 
     def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log):
         path = self.dir / "diff.log"
-
         for i in range(num_world_stages):
             if should_log and i in stage_mods:
                 grad_logs = []
