
Commit 02cf33f

Fix lint (#277)
stack-info: PR: #277, branch: xmfan/stack/28
1 parent 3ab0ba3 commit 02cf33f

File tree

4 files changed: +41 lines, -35 lines


autoparallel/asynctp.py

Lines changed: 6 additions & 5 deletions
@@ -25,6 +25,7 @@
     PatternMatcherPass,
 )
 from torch._logging import trace_structured
+from torch.distributed.distributed_c10d import GroupName
 from torch.utils._ordered_set import OrderedSet

 import autoparallel.asynctp_ops  # noqa: F401
@@ -89,7 +90,7 @@ class _AllGatherMatch:
     ag_node: torch.fx.Node
     res_node: torch.fx.Node
     gather_dim: int
-    group_name: str
+    group_name: GroupName

     def replace_with(self, new_node: torch.fx.Node) -> None:
         self.res_node.replace_all_uses_with(new_node)
@@ -225,7 +226,7 @@ class _ReduceScatterMatch:
     wait_tensor_node: torch.fx.Node
     reduce_op: str
     scatter_dim: int
-    group_name: str
+    group_name: GroupName

     def replace_with(self, new_node: torch.fx.Node) -> None:
         # Replace all uses of the result node (wait_tensor) with the fused node.
@@ -643,7 +644,7 @@ def _insert_fused_all_gather_matmul(
     matmuls: list[_Matmul],
     shard_node: torch.fx.Node,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.fx.Node:
     mm_types = OrderedSet(map(type, matmuls))
     assert len(mm_types) == 1
@@ -704,7 +705,7 @@ def _insert_fused_all_gather_transpose_matmul(
     matmuls: list[_Matmul],
     shard_node: torch.fx.Node,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.fx.Node:
     mm_types = OrderedSet(map(type, matmuls))
     assert len(mm_types) == 1
@@ -974,7 +975,7 @@ def _insert_fused_matmul_reduce_scatter(
     matmul: _Matmul,
     reduce_op: str,
     orig_scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
     scatter_dim_after_reshape: int,  # only used for reshape -> scaled_mm -> reshape pattern
     output_shape: list[int],  # only used for reshape -> scaled_mm -> reshape pattern
 ) -> torch.fx.Node:
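
The substantive change in this file is annotation-only: every group_name field and parameter moves from str to the GroupName type now imported from torch.distributed.distributed_c10d. A minimal sketch of how such a value is produced and consumed (illustrative, not part of the commit; it assumes a PyTorch build that exposes GroupName and the ProcessGroup.group_name property):

# Illustrative sketch, not repo code. At runtime the value is presumably still
# the plain string name of a registered process group.
import torch.distributed.distributed_c10d as c10d
from torch.distributed.distributed_c10d import GroupName


def group_size_from_name(group_name: GroupName) -> int:
    # Same pattern as the helpers in asynctp_ops.py: resolve the group by name.
    return c10d._resolve_process_group(group_name).size()


# Usage (requires an initialized process group):
#   name: GroupName = torch.distributed.group.WORLD.group_name
#   group_size_from_name(name)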

autoparallel/asynctp_ops.py

Lines changed: 27 additions & 24 deletions
@@ -20,6 +20,9 @@
 from torch._C._distributed_c10d import _register_work, _SymmetricMemory
 from torch.distributed._symmetric_memory import get_symm_mem_workspace, rendezvous

+# Import GroupName for type checking
+GroupName = c10d.GroupName
+
 _is_test_mode: bool = False
 _mocked_group_names: set[str] | None = None
 _backend_streams: dict[int, torch.cuda.Stream] = {}
@@ -35,7 +38,7 @@ def _pipelined_multi_all_gather_and_consume(
     shard: list[torch.Tensor],
     shard_consumer: Callable[[list[torch.Tensor], int], None],
     ag_out: list[torch.Tensor],
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     """
@@ -181,7 +184,7 @@ def _pipelined_all_gather_and_consume(
     shard: torch.Tensor,
     shard_consumer: Callable[[torch.Tensor, int], None],
     ag_out: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     """
@@ -209,7 +212,7 @@ def adapter(shard: list[torch.Tensor], rank: int) -> None:
 def _pipelined_produce_and_all2all(
     chunk_producer: Callable[[int, torch.Tensor], None],
     output: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     out_chunk_dim=0,
 ) -> None:
     """
@@ -400,7 +403,7 @@ def _fused_all_gather_matmul_impl(
     kwargs_list: list[dict[str, Any]],
     out_dtypes: list[torch.dtype | None],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     if A_shard.dim() < 2:
@@ -534,7 +537,7 @@ def _pipelined_all_gather_and_consume_last_dim(
     shard: torch.Tensor,
     shard_consumer: Callable[[torch.Tensor, int], None],
     ag_out: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
     ag_out_needed: bool = True,
 ) -> None:
     p2p_workspace_size_req = 0
@@ -607,7 +610,7 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     kwargs_list: list[dict[str, Any]],
     out_dtypes: list[torch.dtype | None],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
     assert gather_dim == A_shard.ndim - 1
@@ -674,7 +677,7 @@ def _fused_all_gather_matmul_fallback(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     *,
     return_A: bool = True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -705,7 +708,7 @@ def _fused_all_gather_matmul(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     *,
     return_A: bool = True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
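
For orientation, the unfused computation that _fused_all_gather_matmul and its fallback accelerate can be written directly with functional collectives. A rough reference sketch (illustrative only; reference_all_gather_matmul is not a repo API, and the real fallback additionally handles return_A, dtypes, and layout details):

# Rough, unfused reference for the all-gather + matmul pattern above.
# Assumes torch.distributed._functional_collectives and an initialized group.
import torch
import torch.distributed._functional_collectives as funcol


def reference_all_gather_matmul(
    A_shard: torch.Tensor,
    Bs: list[torch.Tensor],
    gather_dim: int,
    group_name: str,  # GroupName in the patched signatures
) -> list[torch.Tensor]:
    A = funcol.all_gather_tensor(A_shard.contiguous(), gather_dim, group_name)
    A = funcol.wait_tensor(A)
    return [A @ B for B in Bs]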
@@ -756,7 +759,7 @@ def _should_use_fused_all_gather_matmul_native(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> bool:
     group = c10d._resolve_process_group(group_name)
     local_M = math.prod(A_shard.shape[:-1])
@@ -778,7 +781,7 @@ def _should_use_fused_all_gather_matmul_native(
 def _fused_all_gather_matmul_native(
     A_shard: torch.Tensor,
     B: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     symm_mem = rendezvous(A_shard, group_name)
     if symm_mem is None:
@@ -832,7 +835,7 @@ def _fused_all_gather_matmul_native(
 def _should_use_multimem_all_gather_matmul(
     A_shard: torch.Tensor,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     return_A: bool,
 ) -> bool:
     group = c10d._resolve_process_group(group_name)
@@ -858,7 +861,7 @@ def _should_use_multimem_all_gather_matmul(
 def _multimem_all_gather_matmul(
     A_shard: torch.Tensor,
     Bs: list[torch.Tensor],
-    group_name: str,
+    group_name: GroupName,
 ) -> list[torch.Tensor]:
     group = c10d._resolve_process_group(group_name)
     A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:]))
@@ -877,7 +880,7 @@ def _fused_all_gather_scaled_matmul_fallback(
     A_scale: torch.Tensor,
     B_scales: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     biases: list[torch.Tensor | None],
     result_scales: list[torch.Tensor | None],
     out_dtypes: list[torch.dtype | None],
@@ -951,7 +954,7 @@ def _fused_all_gather_scaled_matmul(
     A_scale: torch.Tensor,
     B_scales: list[torch.Tensor],
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
     biases: list[torch.Tensor | None],
     result_scales: list[torch.Tensor | None],
     out_dtypes: list[torch.dtype | None],
@@ -1057,7 +1060,7 @@ def _fused_matmul_reduce_scatter(
     B: torch.Tensor,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Perform the following logic with micro-pipelined computation and
@@ -1093,7 +1096,7 @@ def _fused_matmul_reduce_scatter_fallback(
     B: torch.Tensor,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
     res = funcol.wait_tensor(res)
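
The fallback above already spells out the unfused form (matmul, functional reduce_scatter_tensor, then wait_tensor). Condensed into a standalone sketch for reference (illustrative only; matmul_reduce_scatter_reference is not a repo API):

# Condensed reference for the matmul + reduce-scatter fallback shown above.
# Assumes torch.distributed._functional_collectives and an initialized group.
import torch
import torch.distributed._functional_collectives as funcol


def matmul_reduce_scatter_reference(
    A: torch.Tensor,
    B: torch.Tensor,
    reduce_op: str,
    scatter_dim: int,
    group_name: str,  # annotated as GroupName in the patched module
) -> torch.Tensor:
    res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
    return funcol.wait_tensor(res)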
@@ -1108,7 +1111,7 @@ def _fused_matmul_reduce_scatter_impl(
     out_dtype: torch.dtype | None,
     reduce_op: str,
     scatter_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     if A.dim() < 2:
         raise ValueError("A_shard must be a matrix")
@@ -1194,7 +1197,7 @@ def _fused_scaled_matmul_reduce_scatter(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
     bias: torch.Tensor | None = None,
     result_scale: torch.Tensor | None = None,
@@ -1248,7 +1251,7 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
     bias: torch.Tensor | None = None,
     result_scale: torch.Tensor | None = None,
@@ -1300,7 +1303,7 @@ def _fused_scaled_matmul_reduce_scatter_impl(
     reduce_op: str,
     orig_scatter_dim: int,
     scatter_dim_after_maybe_reshape: int,
-    group_name: str,
+    group_name: GroupName,
     output_shape: list[int],
 ) -> torch.Tensor:
     if A.dim() < 2:
@@ -1524,7 +1527,7 @@ def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
 @torch.library.impl(lib, "_low_contention_all_gather", "Meta")
 def _low_contention_all_gather_meta(
     tensor: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
@@ -1533,7 +1536,7 @@ def _low_contention_all_gather_meta(
 @torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
 def _low_contention_all_gather(
     tensor: torch.Tensor,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Performs all-gather with symmetric memory in a low-contention fashion.
@@ -1582,7 +1585,7 @@ def _low_contention_all_gather(
 def _low_contention_reduce_scatter_meta(
     tensor: torch.Tensor,
     reduce_op: str,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
@@ -1665,7 +1668,7 @@ def _low_contention_reduce_scatter_with_workspace(
 def _low_contention_reduce_scatter(
     tensor: torch.Tensor,
     reduce_op: str,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     """
     Performs reduce-scatter with symmetric memory in a low-contention fashion.

autoparallel/collectives.py

Lines changed: 7 additions & 4 deletions
@@ -9,6 +9,9 @@
 import torch.distributed.distributed_c10d as c10d
 from torch.distributed._tensor.experimental import local_map as _local_map

+# Import GroupName for type checking
+GroupName = c10d.GroupName
+
 _local_map_device_mesh = None


@@ -51,7 +54,7 @@ def axis_index(axis_name):
 def _all_gather_tensor(
     x: torch.Tensor,
     gather_dim: int,
-    group_name: str,
+    group_name: GroupName,
 ) -> torch.Tensor:
     x = x.contiguous()
     group_size = c10d._get_group_size_by_name(group_name)
@@ -67,7 +70,7 @@ def _all_gather_tensor(


 def _reduce_scatter_tensor(
-    self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: str
+    self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: GroupName
 ):
     group_size = c10d._get_group_size_by_name(group_name)

@@ -88,7 +91,7 @@ def _reduce_scatter_tensor(
     return res


-def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: str):
+def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: GroupName):
     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
     res = torch.ops._c10d_functional.wait_tensor(tensor)
     return res
@@ -98,7 +101,7 @@ def _all_to_all(
     self: torch.Tensor,
     output_split_sizes: Optional[list[int]],
     input_split_sizes: Optional[list[int]],
-    group_name: str,
+    group_name: GroupName,
 ):
     group_size = c10d._get_group_size_by_name(group_name)
     if output_split_sizes is None or input_split_sizes is None:
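
These wrappers are thin: each resolves the group by its registered name and dispatches to the corresponding _c10d_functional op followed by wait_tensor. A runnable single-rank sketch of the _all_reduce path (illustrative; it assumes the gloo backend is available and that a ProcessGroup exposes its registered name via the group_name property):

# Single-rank usage sketch for the _all_reduce wrapper above (illustrative).
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

group_name = dist.group.WORLD.group_name  # the string name the wrappers receive
x = torch.ones(4)
out = torch.ops._c10d_functional.all_reduce(x, "sum", group_name)
out = torch.ops._c10d_functional.wait_tensor(out)
print(out)  # tensor([1., 1., 1., 1.]) with a single rank

dist.destroy_process_group()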

tests/test_dtensor.py

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,6 @@
     OpSpec,
     OpStrategy,
     OutputSharding,
-    OutputSpecType,
     RuntimeSchemaInfo,
     TupleStrategy,
 )
@@ -241,7 +240,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
     # for ops that return multiple tensors and the output_specs is not
     # a tuple, we use a tuple of that single output spec as the new
     # output_specs
-    output_specs: OutputSpecType = output_strategy.output_specs
+    output_specs = output_strategy.output_specs
     if isinstance(output_specs, DTensorSpec):
         output_specs = tuple(
             [
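
Here the explicit OutputSpecType annotation is dropped (which also makes the import unused), presumably so the checker infers the variable's type from output_strategy.output_specs and accepts the later rebinding to a tuple. A generic, self-contained illustration of that inference pattern (not repo code):

# Generic illustration of relying on inference after dropping an explicit
# variable annotation: specs takes the inferred union type from the right-hand
# side and may be rebound to a tuple later in the function.
from typing import Union


def normalize(spec: Union[int, tuple[int, ...]]) -> tuple[int, ...]:
    specs = spec  # inferred as Union[int, tuple[int, ...]]
    if isinstance(specs, int):
        specs = (specs,)  # rebinding is fine without a pinned annotation
    return specs


assert normalize(3) == (3,)
assert normalize((1, 2)) == (1, 2)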
