108 changes: 97 additions & 11 deletions autoparallel/asynctp.py
@@ -13,6 +13,7 @@

import torch
from torch._inductor import inductor_prims
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
from torch._inductor.pattern_matcher import (
MULTIPLE,
CallFunction,
@@ -23,6 +24,7 @@
PatternExpr,
PatternMatcherPass,
)
from torch._logging import trace_structured
from torch.utils._ordered_set import OrderedSet

import autoparallel.asynctp_ops # noqa: F401
Expand All @@ -34,7 +36,8 @@
_micro_pipeline_tp_ag_transpose_mm_enabled = True

# Check performance if overhead of decomposition outweighs pipeline wins
_micro_pipeline_tp_ag_mm_last_dim_enabled = False
_micro_pipeline_tp_ag_mm_last_dim_enabled = True
_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = True

_micro_pipeline_tp_mm_rs_last_dim_enabled = True
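Editor's note (not part of the diff): these are module-level switches. Assuming they gate the corresponding fusion paths, as their names and the splitcatuse guard added below suggest, a hypothetical way to restore the previous defaults at runtime is to flip them before the pass runs:

import autoparallel.asynctp as asynctp

# Hypothetical opt-out: disable the last-gather-dim all-gather + matmul fusion
# paths before micro_pipeline_tp_pass executes.
asynctp._micro_pipeline_tp_ag_mm_last_dim_enabled = False
asynctp._micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = False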

@@ -720,7 +723,7 @@ def _insert_fused_all_gather_transpose_matmul(
raise AssertionError(f"Unexpected matmul match type: {mm_type}")


def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
"""
Fuses the pattern

@@ -755,6 +758,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
all_gather.group_name,
)

log_strs.append(f"fuse_agmm {all_gather}")
if not is_symm_mem_enabled_for_group(group_name):
return

@@ -774,10 +778,24 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
for matmul in matmuls
if all_gather.res_node not in matmul.arg_ancestor_nodes
]
log_strs.append(f"fuse_agmm matmuls:{matmuls}")

if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
return

if (
_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled
and gather_dim == _get_tensor(shard_node).ndim - 1
and len(all_gather.res_node.users) > len(matmuls)
):
# The result of ag-split-cat is consumed by nodes other than the matmuls,
# so it has to be materialized, which can add overhead.
# TODO: find the stride conditions under which there is no overhead.
log_strs.append(
f"fuse_agmm lastdim ag-split-cat has {len(all_gather.res_node.users)} users, more than num matmuls"
)
return

# Fuse the all_gather_tensor with the eligible matmuls
graph = ag_node.graph
with graph.inserting_before(ag_node):
@@ -870,6 +888,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
for node in nodes_to_raise:
if order[node] > order[fused_node]:
fused_node.prepend(node)
log_strs.append("fuse_agmm DONE")


def _scatter_dim_after_reshape(
@@ -990,7 +1009,7 @@ def _insert_fused_matmul_reduce_scatter(
raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")


def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
"""
Fuses the pattern

@@ -1004,6 +1023,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:

Returns boolean indicating if fusion was successful or not.
"""
log_strs.append(f"fuse_mmrs {reduce_scatter}")
if (
not torch.distributed.is_available()
or not torch.distributed.is_nccl_available()
@@ -1032,6 +1052,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
)

if not is_symm_mem_enabled_for_group(group_name):
log_strs.append("fuse_mmrs not symm mem group")
return

if (
@@ -1048,16 +1069,19 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
log.warning(
"matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
)
log_strs.append("fuse_mmrs input.node.users != 1")
return

matmul = _find_producer_matmul(input_node)
if matmul is None:
log_strs.append("fuse_mmrs no matmul")
log.warning(
"no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
)
return

if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
log_strs.append("fuse_mmrs wait in matmul.arg_ancestors")
log.warning(
"reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
)
@@ -1123,6 +1147,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
if order[node] > order[fused_node]:
fused_node.prepend(node)

log_strs.append("fuse_mmrs DONE")
log.debug("successfully fused matmul reduce scatter")


@@ -1173,6 +1198,7 @@ def is_collective(node) -> bool:
return collective_to_overlappable_nodes


# TODO: Convert return type to set
def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
"""
Find all unexposed collectives in the graph.
@@ -1209,26 +1235,86 @@ def _is_compute_intensive(node: torch.fx.Node) -> bool:
return unexposed_collectives


def micro_pipeline_tp_pass(graph: torch.fx.Graph):
def _get_exposed_collectives(graph):
from torch._inductor.fx_passes.bucketing import is_wait_tensor

ret = OrderedSet()
node_to_idx = {n: i for i, n in enumerate(graph.nodes)}
for i, n in enumerate(graph.nodes):
if not is_wait_tensor(n):
continue

assert len(n.args) == 1
coll_n = n.args[0]
wait_n_idx = i
coll_n_idx = node_to_idx[coll_n]
if wait_n_idx - coll_n_idx <= 1:
ret.add(coll_n)

return ret


def micro_pipeline_tp_pass(
graph: torch.fx.Graph,
collective_info: Optional[dict[torch.fx.Node, CollectiveInfo]] = None,
):
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "asynctp_pre_graph",
"encoding": "string",
},
payload_fn=lambda: graph.owning_module.print_readable(
print_output=False, include_stride=True
),
)
all_gathers = find_all_gather_patterns(graph)
reduce_scatters = find_reduce_scatter_patterns(graph)

# When a collective can be hidden through either simple overlapping or
# micro-pipeline TP, we prefer simple overlapping to avoid the overhead
# associated with decomposition.
unexposed_collectives = _get_unexposed_collectives(graph)
all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
exposed_collectives = _get_exposed_collectives(graph)

log_strs = []
log_strs.append(f"\n all_gathers:{all_gathers}")
log_strs.append(f"\n reduce_scatters:{reduce_scatters}")

all_gathers = [x for x in all_gathers if x.ag_node in exposed_collectives]
reduce_scatters = [
x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives
x for x in reduce_scatters if x.reduce_scatter_node in exposed_collectives
]
if collective_info is not None:
for n, coll_info in collective_info.items():
log_strs.append(f"coll_info {n}: {coll_info}")
log_strs.append(f"\n UNUSED unexposed_collectives:{unexposed_collectives}")
log_strs.append(f"\n exposed_collectives:{exposed_collectives}")
log_strs.append(f"\n all_gathers_exposed:{all_gathers}")
log_strs.append(f"\n reduce_scatters_exposed:{reduce_scatters}")

if not all_gathers and not reduce_scatters:
log.warning(
"async TP found no matching all-gather/reduce-scatter patterns for fusion"
)

for reduce_scatter in reduce_scatters:
fuse_matmul_reduce_scatter(reduce_scatter)
fuse_matmul_reduce_scatter(reduce_scatter, log_strs)

for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather)
fuse_all_gather_matmul(all_gather, log_strs)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "asynctp_log",
"encoding": "string",
},
payload_fn=lambda: "\n".join(log_strs),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "asynctp_post_graph",
"encoding": "string",
},
payload_fn=lambda: graph.owning_module.print_readable(
print_output=False, include_stride=True
),
)
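Editor's note (not part of the diff): the heuristic in _get_exposed_collectives above treats a collective as exposed when its wait_tensor immediately follows it in graph order, meaning nothing is scheduled in between that could hide its latency. A minimal self-contained sketch of that rule, using hypothetical stand-in node objects rather than the real torch.fx types:

from collections import namedtuple

# Hypothetical stand-ins for fx nodes; "waits_on" points a wait node at its collective.
Node = namedtuple("Node", ["name", "kind", "waits_on"])

def exposed_collectives(nodes):
    idx = {n: i for i, n in enumerate(nodes)}  # position of each node in program order
    exposed = []
    for n in nodes:
        if n.kind != "wait":
            continue
        coll = n.waits_on
        # Nothing sits between the collective and its wait, so its latency
        # cannot be overlapped with other work: the collective is exposed.
        if idx[n] - idx[coll] <= 1:
            exposed.append(coll)
    return exposed

ag = Node("all_gather", "collective", None)
mm = Node("matmul", "compute", None)
wait = Node("wait_ag", "wait", ag)

assert exposed_collectives([ag, wait, mm]) == [ag]  # wait right after: exposed
assert exposed_collectives([ag, mm, wait]) == []    # compute in between: can be hidden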
13 changes: 9 additions & 4 deletions autoparallel/asynctp_ops.py
@@ -615,6 +615,7 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
assert gather_dim == A_shard.ndim - 1
group = c10d._resolve_process_group(group_name)
group_size = group.size()

B_shards = [B.chunk(group.size()) for B in Bs]

@@ -625,7 +626,6 @@ def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1)

A_out_leading_dims = list(A_shard.shape[:-1])
A_out_leading_dims[0] *= group.size()

def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
return t.view(*A_out_leading_dims, -1)
@@ -667,9 +667,14 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
group_name,
return_A,
)
ret_A = None
if return_A:
# This path is inefficient and will be filtered out at the pass stage.
# Added only for completeness.
A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
ret_A = unflatten_A_out(A_split_cat_out_flat)

A = unflatten_A_out(A_flat_out) if return_A else None
return A, [unflatten(output) for output in outputs]
return ret_A, [unflatten(output) for output in outputs]


@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
Expand All @@ -691,7 +696,7 @@ def _fused_all_gather_matmul_fallback(
A_mm = torch.cat(A_splits, dim=-1)
res = [torch.matmul(A_mm, B) for B in Bs]
if return_A:
return A, res
return A_mm, res
else:
return None, res

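Editor's note (not part of the diff): the return_A fix-up and the fallback's A_mm = torch.cat(A_splits, dim=-1) rely on the same layout identity: for a tensor sharded along its last dim, an all-gather along dim 0 followed by chunk(group_size) + cat(dim=-1) reconstructs the logical gathered tensor. A single-process sketch with hypothetical shapes and no process group:

import torch

world_size = 4

# Full logical A, column-sharded (last dim) across ranks.
A_full = torch.randn(8, 16)
A_shards = list(A_full.chunk(world_size, dim=-1))  # one [8, 4] shard per rank

# What a dim-0 all-gather of the shards yields: shards stacked along dim 0.
A_gathered_dim0 = torch.cat(A_shards, dim=0)  # [32, 4]

# Split back per rank and concatenate along the last dim to recover logical A.
A_logical = torch.cat(A_gathered_dim0.chunk(world_size, dim=0), dim=-1)  # [8, 16]

assert torch.equal(A_logical, A_full)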