From 2c0c2322ede3f7a120f5148848ecf1cba74b26ee Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Tue, 9 Dec 2025 13:35:31 -0800
Subject: [PATCH] Fix lint

stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/277, branch: xmfan/stack/28
---
 autoparallel/asynctp.py     | 11 ++++----
 autoparallel/asynctp_ops.py | 51 ++++++++++++++++++++-----------------
 autoparallel/collectives.py | 11 +++++---
 tests/test_dtensor.py       |  3 +--
 4 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/autoparallel/asynctp.py b/autoparallel/asynctp.py
index 712634b9..5ae8d153 100644
--- a/autoparallel/asynctp.py
+++ b/autoparallel/asynctp.py
@@ -25,6 +25,7 @@
     PatternMatcherPass,
 )
 from torch._logging import trace_structured
+from torch.distributed.distributed_c10d import GroupName
 from torch.utils._ordered_set import OrderedSet
 
 import autoparallel.asynctp_ops  # noqa: F401
@@ -89,7 +90,7 @@ class _AllGatherMatch:
     ag_node: torch.fx.Node
     res_node: torch.fx.Node
     gather_dim: int
-    group_name: str
+    group_name: GroupName
 
     def replace_with(self, new_node: torch.fx.Node) -> None:
         self.res_node.replace_all_uses_with(new_node)
@@ -225,7 +226,7 @@ class _ReduceScatterMatch:
     wait_tensor_node: torch.fx.Node
     reduce_op: str
     scatter_dim: int
-    group_name: str
+    group_name: GroupName
 
     def replace_with(self, new_node: torch.fx.Node) -> None:
         # Replace all uses of the result node (wait_tensor) with the fused node.
@@ -643,7 +644,7 @@ def _insert_fused_all_gather_matmul(
     matmuls: list[_Matmul],
     shard_node: torch.fx.Node,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.fx.Node:
     mm_types = OrderedSet(map(type, matmuls))
     assert len(mm_types) == 1
@@ -704,7 +705,7 @@ def _insert_fused_all_gather_transpose_matmul(
     matmuls: list[_Matmul],
     shard_node: torch.fx.Node,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.fx.Node:
     mm_types = OrderedSet(map(type, matmuls))
     assert len(mm_types) == 1
@@ -974,7 +975,7 @@ def _insert_fused_matmul_reduce_scatter(
     matmul: _Matmul,
     reduce_op: str,
     orig_scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
     scatter_dim_after_reshape: int,  # only used for reshape -> scaled_mm -> reshape pattern
     output_shape: list[int],  # only used for reshape -> scaled_mm -> reshape pattern
 ) -> torch.fx.Node:
diff --git a/autoparallel/asynctp_ops.py b/autoparallel/asynctp_ops.py
index 58770e63..ed2cfa8c 100644
--- a/autoparallel/asynctp_ops.py
+++ b/autoparallel/asynctp_ops.py
@@ -20,6 +20,9 @@
 from torch._C._distributed_c10d import _register_work, _SymmetricMemory
 from torch.distributed._symmetric_memory import get_symm_mem_workspace, rendezvous
 
+# Import GroupName for type checking
+GroupName = c10d.GroupName
+
 _is_test_mode: bool = False
 _mocked_group_names: set[str] | None = None
 _backend_streams: dict[int, torch.cuda.Stream] = {}
@@ -35,7 +38,7 @@ def _pipelined_multi_all_gather_and_consume(
     shard: list[torch.Tensor],
     shard_consumer: Callable[[list[torch.Tensor], int], None],
     ag_out: list[torch.Tensor],
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     """
@@ -181,7 +184,7 @@ def _pipelined_all_gather_and_consume(
     shard: torch.Tensor,
     shard_consumer: Callable[[torch.Tensor, int], None],
     ag_out: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     """
@@ -209,7 +212,7 @@ def adapter(shard: list[torch.Tensor], rank: int) -> None:
 def _pipelined_produce_and_all2all(
     chunk_producer: Callable[[int, torch.Tensor], None],
     output: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     out_chunk_dim=0,
 ) -> None:
     """
@@ -400,7 +403,7 @@ def _fused_all_gather_matmul_impl(
     kwargs_list: list[dict[str, Any]],
     out_dtypes: list[torch.dtype | None],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     if A_shard.dim() < 2:
@@ -534,7 +537,7 @@ def _pipelined_all_gather_and_consume_last_dim(
     shard: torch.Tensor,
     shard_consumer: Callable[[torch.Tensor, int], None],
     ag_out: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     p2p_workspace_size_req = 0
@@ -607,7 +610,7 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     kwargs_list: list[dict[str, Any]],
     out_dtypes: list[torch.dtype | None],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     assert gather_dim == A_shard.ndim - 1
@@ -674,7 +677,7 @@ def _fused_all_gather_matmul_fallback(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     *,
     return_A: bool = True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -705,7 +708,7 @@ def _fused_all_gather_matmul(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     *,
     return_A: bool = True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -756,7 +759,7 @@ def _should_use_fused_all_gather_matmul_native(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> bool:
     group = c10d._resolve_process_group(group_name)
     local_M = math.prod(A_shard.shape[:-1])
@@ -778,7 +781,7 @@ def _should_use_fused_all_gather_matmul_native(
 def _fused_all_gather_matmul_native(
     A_shard: torch.Tensor,
     B: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     symm_mem = rendezvous(A_shard, group_name)
     if symm_mem is None:
@@ -832,7 +835,7 @@ def _fused_all_gather_matmul_native(
 def _should_use_multimem_all_gather_matmul(
     A_shard: torch.Tensor,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> bool:
     group = c10d._resolve_process_group(group_name)
@@ -858,7 +861,7 @@ def _should_use_multimem_all_gather_matmul(
 def _multimem_all_gather_matmul(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
-    group_name: str,
+    group_name: GroupName,
 ) -> list[torch.Tensor]:
     group = c10d._resolve_process_group(group_name)
     A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:]))
@@ -877,7 +880,7 @@ def _fused_all_gather_scaled_matmul_fallback(
     A_scale: torch.Tensor,
     B_scales: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     biases: list[torch.Tensor | None],
     result_scales: list[torch.Tensor | None],
     out_dtypes: list[torch.dtype | None],
@@ -951,7 +954,7 @@ def _fused_all_gather_scaled_matmul(
     A_scale: torch.Tensor,
     B_scales: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     biases: list[torch.Tensor | None],
     result_scales: list[torch.Tensor | None],
     out_dtypes: list[torch.dtype | None],
@@ -1057,7 +1060,7 @@ def _fused_matmul_reduce_scatter(
     B: torch.Tensor,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Perform the following logic with micro-pipelined computation and
@@ -1093,7 +1096,7 @@ def _fused_matmul_reduce_scatter_fallback(
     B: torch.Tensor,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
     res = funcol.wait_tensor(res)
@@ -1108,7 +1111,7 @@ def _fused_matmul_reduce_scatter_impl(
     out_dtype: torch.dtype | None,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     if A.dim() < 2:
         raise ValueError("A_shard must be a matrix")
@@ -1194,7 +1197,7 @@ def _fused_scaled_matmul_reduce_scatter(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
     bias: torch.Tensor | None = None,
     result_scale: torch.Tensor | None = None,
@@ -1248,7 +1251,7 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
     bias: torch.Tensor | None = None,
     result_scale: torch.Tensor | None = None,
@@ -1300,7 +1303,7 @@ def _fused_scaled_matmul_reduce_scatter_impl(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
 ) -> torch.Tensor:
     if A.dim() < 2:
@@ -1524,7 +1527,7 @@ def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
 @torch.library.impl(lib, "_low_contention_all_gather", "Meta")
 def _low_contention_all_gather_meta(
     tensor: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
@@ -1533,7 +1536,7 @@ def _low_contention_all_gather_meta(
 @torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
 def _low_contention_all_gather(
     tensor: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Performs all-gather with symmetric memory in a low-contention fashion.
@@ -1582,7 +1585,7 @@ def _low_contention_all_gather(
 def _low_contention_reduce_scatter_meta(
     tensor: torch.Tensor,
     reduce_op: str,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
@@ -1665,7 +1668,7 @@ def _low_contention_reduce_scatter_with_workspace(
 def _low_contention_reduce_scatter(
     tensor: torch.Tensor,
     reduce_op: str,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Performs reduce-scatter with symmetric memory in a low-contention fashion.
diff --git a/autoparallel/collectives.py b/autoparallel/collectives.py
index 1600f36b..fa7092a4 100644
--- a/autoparallel/collectives.py
+++ b/autoparallel/collectives.py
@@ -9,6 +9,9 @@
 import torch.distributed.distributed_c10d as c10d
 from torch.distributed._tensor.experimental import local_map as _local_map
 
+# Import GroupName for type checking
+GroupName = c10d.GroupName
+
 _local_map_device_mesh = None
 
 
@@ -51,7 +54,7 @@ def axis_index(axis_name):
 def _all_gather_tensor(
     x: torch.Tensor,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     x = x.contiguous()
     group_size = c10d._get_group_size_by_name(group_name)
@@ -67,7 +70,7 @@ def _all_gather_tensor(
 
 
 def _reduce_scatter_tensor(
-    self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: str
+    self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: GroupName
 ):
     group_size = c10d._get_group_size_by_name(group_name)
 
@@ -88,7 +91,7 @@ def _reduce_scatter_tensor(
     return res
 
 
-def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: str):
+def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: GroupName):
     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
     res = torch.ops._c10d_functional.wait_tensor(tensor)
     return res
@@ -98,7 +101,7 @@ def _all_to_all(
     self: torch.Tensor,
     output_split_sizes: Optional[list[int]],
     input_split_sizes: Optional[list[int]],
-    group_name: str,
+    group_name: GroupName,
 ):
     group_size = c10d._get_group_size_by_name(group_name)
     if output_split_sizes is None or input_split_sizes is None:
diff --git a/tests/test_dtensor.py b/tests/test_dtensor.py
index 22ce387f..0f035f9e 100644
--- a/tests/test_dtensor.py
+++ b/tests/test_dtensor.py
@@ -16,7 +16,6 @@
     OpSpec,
     OpStrategy,
     OutputSharding,
-    OutputSpecType,
     RuntimeSchemaInfo,
     TupleStrategy,
 )
@@ -241,7 +240,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin
         # for ops that return multiple tensors and the output_specs is not
         # a tuple, we use a tuple of that single output spec as the new
         # output_specs
-        output_specs: OutputSpecType = output_strategy.output_specs
+        output_specs = output_strategy.output_specs
         if isinstance(output_specs, DTensorSpec):
             output_specs = tuple(
                 [