
Commit b4a76e9

Sanket Jayant Purandare authored and xmfan committed
Adding split FSDP Collective Pass
1 parent 2f1452e commit b4a76e9


3 files changed, +81 -1 lines changed


autoparallel/_passes/graph_multiplex.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 import copy
 
 import torch
autoparallel/_passes/split_fsdp_collectives.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import dataclasses
+
+import torch
+import torch.utils._pytree as pytree
+from torch._functorch._aot_autograd.descriptors import AOTOutput
+from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
+
+
+@dataclasses.dataclass(frozen=True)
+class PrefetchOutput(AOTOutput):
+    pass
+
+
+def split_fsdp_prefetch(
+    gm: torch.fx.GraphModule,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = gm.graph
+    g_ins = g.find_nodes(op="placeholder")
+    prefetch_g_outs_map = {}
+
+    for g_in in g_ins:
+        n = g_in
+        while True:
+            if len(n.users) != 1:
+                break
+            user = next(iter(n.users))
+            if len(user.all_input_nodes) > 1:
+                break
+            n = user
+        prefetch_g_outs_map[g_in] = n
+
+    prefetch_g_outs = list(prefetch_g_outs_map.values())
+    prefetch_g_outs_descs: list[AOTOutput] = [
+        PrefetchOutput() for _ in range(len(prefetch_g_outs))
+    ]
+
+    prefetch_g = _extract_graph_with_inputs_outputs(
+        g,
+        g_ins,
+        prefetch_g_outs,
+        prefetch_g_outs_descs,
+    )
+
+    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
+    g_outs_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    )
+    main_g = _extract_graph_with_inputs_outputs(
+        g,
+        prefetch_g_outs,
+        g_outs,
+        g_outs_descs,
+    )
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
+    return prefetch_gm, main_gm
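
For intuition, here is a minimal, self-contained sketch of the chain-walk heuristic that split_fsdp_prefetch applies above. It is not part of this commit, and the Toy module and all names below are hypothetical: starting from each placeholder, the walk follows nodes that have exactly one user, stopping as soon as that user reads any other input. In an FSDP graph this peels the per-parameter collective chain (e.g. the all-gather feeding compute) into the prefetch subgraph.

import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x, w):
        w = w.relu()     # single user, single input: chain from w continues
        w = w.sigmoid()  # single user, single input: chain continues
        return x * w     # mul reads two inputs: both chains stop here


gm = torch.fx.symbolic_trace(Toy())
for p in gm.graph.find_nodes(op="placeholder"):
    n = p
    # Same walk as split_fsdp_prefetch: advance while the node has exactly
    # one user and that user consumes no other graph nodes.
    while len(n.users) == 1:
        user = next(iter(n.users))
        if len(user.all_input_nodes) > 1:
            break
        n = user
    print(f"{p.name}: prefetch chain ends at {n.name}")

# Expected output (node names may differ):
#   x: prefetch chain ends at x        (x directly feeds a multi-input node)
#   w: prefetch chain ends at sigmoid  (relu and sigmoid peeled into prefetch)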

examples/example_llama3.py

Lines changed: 15 additions & 1 deletion
@@ -12,6 +12,7 @@
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
+from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch
 from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
 from autoparallel.api import AutoParallel
 from autoparallel.auto_bucketing import (
@@ -257,9 +258,22 @@ def _pass(graph):
 if multiplex_graph:
     f_gm = autop.fw_module
     b_gm = autop.bw_module
-    multiplexed_gm = multiplex_fw_bw_graph(f_gm, b_gm)
+    print("Original Fwd Graph:")
     print(f_gm.graph)
+    print("Original Bwd Graph:")
     print(b_gm.graph)
+    prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm)
+    print("Main Fwd Graph:")
+    print(main_f_gm.graph)
+    print("Prefetch Fwd Graph:")
+    print(prefetch_f_gm.graph)
+    prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm)
+    print("Main Bwd Graph:")
+    print(main_b_gm.graph)
+    print("Prefetch Bwd Graph:")
+    print(prefetch_b_gm.graph)
+    multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm)
+    print("Multiplexed Graph:")
     print(multiplexed_gm.graph)
 
 # run weight init on our sharded DTensor params
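
One property worth noting, as an observation rather than anything this commit asserts: the prefetch graph consumes the original placeholders, and the main graph consumes exactly the prefetch graph's outputs, so the two extracted modules compose back to the original computation. A hedged sketch, assuming a hypothetical inputs tuple matching the forward graph's placeholders and that each extracted module returns a tuple of its listed outputs:

# Hypothetical sanity check, not in this commit: the split graphs compose.
prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm)
prefetched = prefetch_f_gm(*inputs)  # runs the peeled per-input chains
main_outs = main_f_gm(*prefetched)   # runs the remaining compute
# main_outs should match f_gm(*inputs) up to output container structure.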
