From 2a88ce7466c5ef228bf76d089a8791e19371e640 Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Thu, 2 Oct 2025 09:12:33 -0700
Subject: [PATCH] [asynctp] Fix shape mismatch for agmm lastdim splitcat use

stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/184, branch: IvanKobzarev/stack/5
---
 autoparallel/asynctp.py     | 108 ++++++++++++++++++++++++++++++++----
 autoparallel/asynctp_ops.py |  13 +++--
 2 files changed, 106 insertions(+), 15 deletions(-)

diff --git a/autoparallel/asynctp.py b/autoparallel/asynctp.py
index 9b708ccd..7cec3afd 100644
--- a/autoparallel/asynctp.py
+++ b/autoparallel/asynctp.py
@@ -13,6 +13,7 @@
 import torch
 from torch._inductor import inductor_prims
+from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
 from torch._inductor.pattern_matcher import (
     MULTIPLE,
     CallFunction,
@@ -23,6 +24,7 @@
     PatternExpr,
     PatternMatcherPass,
 )
+from torch._logging import trace_structured
 from torch.utils._ordered_set import OrderedSet
 
 import autoparallel.asynctp_ops  # noqa: F401
@@ -34,7 +36,8 @@
 _micro_pipeline_tp_ag_transpose_mm_enabled = True
 
 # Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = False
+_micro_pipeline_tp_ag_mm_last_dim_enabled = True
+_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = True
 
 _micro_pipeline_tp_mm_rs_last_dim_enabled = True
 
@@ -720,7 +723,7 @@ def _insert_fused_all_gather_transpose_matmul(
         raise AssertionError(f"Unexpected matmul match type: {mm_type}")
 
 
-def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
+def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -755,6 +758,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         all_gather.group_name,
     )
 
+    log_strs.append(f"fuse_agmm {all_gather}")
     if not is_symm_mem_enabled_for_group(group_name):
         return
 
@@ -774,10 +778,24 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         for matmul in matmuls
         if all_gather.res_node not in matmul.arg_ancestor_nodes
     ]
+    log_strs.append(f"fuse_agmm matmuls:{matmuls}")
 
     if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
         return
 
+    if (
+        _micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled
+        and gather_dim == _get_tensor(shard_node).ndim - 1
+        and len(all_gather.res_node.users) > len(matmuls)
+    ):
+        # The result of ag-split-cat is used not only by the matmuls.
+        # It then has to be materialized, which can add overhead.
+        # TODO: find out the stride conditions under which there is no overhead.
+        log_strs.append(
+            f"fuse_agmm lastdim ag-split-cat {len(all_gather.res_node.users)} used more than num matmuls"
+        )
+        return
+
     # Fuse the all_gather_tensor with the eligible matmuls
     graph = ag_node.graph
     with graph.inserting_before(ag_node):
@@ -870,6 +888,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
     for node in nodes_to_raise:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
+    log_strs.append("fuse_agmm DONE")
 
 
 def _scatter_dim_after_reshape(
@@ -990,7 +1009,7 @@ def _insert_fused_matmul_reduce_scatter(
         raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
 
 
-def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -1004,6 +1023,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
 
     Returns boolean indicating if fusion was successful or not.
""" + log_strs.append(f"fuse_mmrs {reduce_scatter}") if ( not torch.distributed.is_available() or not torch.distributed.is_nccl_available() @@ -1032,6 +1052,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: ) if not is_symm_mem_enabled_for_group(group_name): + log_strs.append("fuse_mmrs not symm mem group") return if ( @@ -1048,16 +1069,19 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: log.warning( "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion." ) + log_strs.append("fuse_mmrs input.node.users != 1") return matmul = _find_producer_matmul(input_node) if matmul is None: + log_strs.append("fuse_mmrs no matmul") log.warning( "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion" ) return if rs_wait_tensor_node in matmul.arg_ancestor_nodes: + log_strs.append("fuse_mmrs wait in matmul.arg_ancestors") log.warning( "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion" ) @@ -1123,6 +1147,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: if order[node] > order[fused_node]: fused_node.prepend(node) + log_strs.append("fuse_mmrs DONE") log.debug("successfully fused matmul reduce scatter") @@ -1173,6 +1198,7 @@ def is_collective(node) -> bool: return collective_to_overlappable_nodes +# TODO: Convert return type to set def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]: """ Find all unexposed collectives in the graph. @@ -1209,18 +1235,60 @@ def _is_compute_intensive(node: torch.fx.Node) -> bool: return unexposed_collectives -def micro_pipeline_tp_pass(graph: torch.fx.Graph): +def _get_exposed_collectives(graph): + from torch._inductor.fx_passes.bucketing import is_wait_tensor + + ret = OrderedSet() + node_to_idx = {n: i for i, n in enumerate(graph.nodes)} + for i, n in enumerate(graph.nodes): + if not is_wait_tensor(n): + continue + + assert len(n.args) == 1 + coll_n = n.args[0] + wait_n_idx = i + coll_n_idx = node_to_idx[coll_n] + if wait_n_idx - coll_n_idx <= 1: + ret.add(coll_n) + + return ret + + +def micro_pipeline_tp_pass( + graph: torch.fx.Graph, + collective_info: Optional[dict[torch.fx.Node, CollectiveInfo]] = None, +): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "asynctp_pre_graph", + "encoding": "string", + }, + payload_fn=lambda: graph.owning_module.print_readable( + print_output=False, include_stride=True + ), + ) all_gathers = find_all_gather_patterns(graph) reduce_scatters = find_reduce_scatter_patterns(graph) - # When a collective can be hidden through either simple overlapping or - # micro-pipeline TP, we prefer simple overlapping to avoid the overhead - # associated with decomposition. 
     unexposed_collectives = _get_unexposed_collectives(graph)
-    all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
+    exposed_collectives = _get_exposed_collectives(graph)
+
+    log_strs = []
+    log_strs.append(f"\n all_gathers:{all_gathers}")
+    log_strs.append(f"\n reduce_scatters:{reduce_scatters}")
+
+    all_gathers = [x for x in all_gathers if x.ag_node in exposed_collectives]
     reduce_scatters = [
-        x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives
+        x for x in reduce_scatters if x.reduce_scatter_node in exposed_collectives
     ]
+    if collective_info is not None:
+        for n, coll_info in collective_info.items():
+            log_strs.append(f"coll_info {n}: {coll_info}")
+    log_strs.append(f"\n UNUSED unexposed_collectives:{unexposed_collectives}")
+    log_strs.append(f"\n exposed_collectives:{exposed_collectives}")
+    log_strs.append(f"\n all_gathers_exposed:{all_gathers}")
+    log_strs.append(f"\n reduce_scatters_exposed:{reduce_scatters}")
 
     if not all_gathers and not reduce_scatters:
         log.warning(
@@ -1228,7 +1296,25 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
         )
 
     for reduce_scatter in reduce_scatters:
-        fuse_matmul_reduce_scatter(reduce_scatter)
+        fuse_matmul_reduce_scatter(reduce_scatter, log_strs)
 
     for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather)
+        fuse_all_gather_matmul(all_gather, log_strs)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_log",
+            "encoding": "string",
+        },
+        payload_fn=lambda: "\n".join(log_strs),
+    )
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_post_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(
+            print_output=False, include_stride=True
+        ),
+    )
diff --git a/autoparallel/asynctp_ops.py b/autoparallel/asynctp_ops.py
index a67d85e1..6d9db73c 100644
--- a/autoparallel/asynctp_ops.py
+++ b/autoparallel/asynctp_ops.py
@@ -615,6 +615,7 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     assert gather_dim == A_shard.ndim - 1
     group = c10d._resolve_process_group(group_name)
+    group_size = group.size()
 
     B_shards = [B.chunk(group.size()) for B in Bs]
 
@@ -625,7 +626,6 @@ def unflatten(t: torch.Tensor) -> torch.Tensor:
         return t.view(*leading_dims, -1)
 
     A_out_leading_dims = list(A_shard.shape[:-1])
-    A_out_leading_dims[0] *= group.size()
 
     def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
         return t.view(*A_out_leading_dims, -1)
@@ -667,9 +667,14 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
         group_name,
         return_A,
     )
+    ret_A = None
+    if return_A:
+        # This path is inefficient and will be filtered out at the pass stage.
+        # Added only for completeness.
+        A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
+        ret_A = unflatten_A_out(A_split_cat_out_flat)
 
-    A = unflatten_A_out(A_flat_out) if return_A else None
-    return A, [unflatten(output) for output in outputs]
+    return ret_A, [unflatten(output) for output in outputs]
 
 
 @torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
@@ -691,7 +696,7 @@ def _fused_all_gather_matmul_fallback(
         A_mm = torch.cat(A_splits, dim=-1)
         res = [torch.matmul(A_mm, B) for B in Bs]
         if return_A:
-            return A, res
+            return A_mm, res
         else:
             return None, res
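
Note: below is a minimal, standalone sketch of the shape fix for the last-gather-dim all-gather + matmul path. It does not use a process group or the fused op; the shard shapes and the names A_shards / A_flat_out are made up for illustration. It only shows why chunking the stacked all-gather output along dim 0 and concatenating the chunks along dim -1 reproduces the tensor the caller expects when gather_dim is the last dimension, which is what the new torch.cat(A_flat_out.chunk(group_size), dim=-1) and the return A_mm change in the fallback rely on.

    import torch

    group_size = 4
    # One shard per rank; shapes are illustrative only.
    A_shards = [torch.randn(8, 16) for _ in range(group_size)]

    # Layout produced by all_gather_into_tensor (the layout of A_flat_out
    # inside the fused op): shards stacked along the leading dim.
    A_flat_out = torch.cat(A_shards, dim=0)        # [8 * group_size, 16]

    # Layout the caller expects when gather_dim is the last dim: shards
    # concatenated along dim -1, keeping the leading dims of each shard.
    expected = torch.cat(A_shards, dim=-1)         # [8, 16 * group_size]

    # The split-cat used by the fix converts between the two layouts.
    reconstructed = torch.cat(A_flat_out.chunk(group_size), dim=-1)

    assert reconstructed.shape == expected.shape == (8, 16 * group_size)
    assert torch.equal(reconstructed, expected)

In the patch itself the result is additionally viewed back through unflatten_A_out, so A keeps the leading dims of A_shard (hence dropping the old A_out_leading_dims[0] *= group.size() line); the sketch keeps everything 2-D for brevity.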