11 changes: 6 additions & 5 deletions autoparallel/asynctp.py
@@ -25,6 +25,7 @@
PatternMatcherPass,
)
from torch._logging import trace_structured
from torch.distributed.distributed_c10d import GroupName
from torch.utils._ordered_set import OrderedSet

import autoparallel.asynctp_ops # noqa: F401
@@ -89,7 +90,7 @@ class _AllGatherMatch:
ag_node: torch.fx.Node
res_node: torch.fx.Node
gather_dim: int
group_name: str
group_name: GroupName

def replace_with(self, new_node: torch.fx.Node) -> None:
self.res_node.replace_all_uses_with(new_node)
@@ -225,7 +226,7 @@ class _ReduceScatterMatch:
wait_tensor_node: torch.fx.Node
reduce_op: str
scatter_dim: int
group_name: str
group_name: GroupName

def replace_with(self, new_node: torch.fx.Node) -> None:
# Replace all uses of the result node (wait_tensor) with the fused node.
@@ -643,7 +644,7 @@ def _insert_fused_all_gather_matmul(
matmuls: list[_Matmul],
shard_node: torch.fx.Node,
gather_dim: int,
group_name: str,
group_name: GroupName,
) -> torch.fx.Node:
mm_types = OrderedSet(map(type, matmuls))
assert len(mm_types) == 1
@@ -704,7 +705,7 @@ def _insert_fused_all_gather_transpose_matmul(
matmuls: list[_Matmul],
shard_node: torch.fx.Node,
gather_dim: int,
group_name: str,
group_name: GroupName,
) -> torch.fx.Node:
mm_types = OrderedSet(map(type, matmuls))
assert len(mm_types) == 1
@@ -974,7 +975,7 @@ def _insert_fused_matmul_reduce_scatter(
matmul: _Matmul,
reduce_op: str,
orig_scatter_dim: int,
group_name: str,
group_name: GroupName,
scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern
output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern
) -> torch.fx.Node:
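The asynctp.py hunks above all make the same change: `group_name` fields and parameters previously annotated as `str` are now annotated with the `GroupName` type imported from `torch.distributed.distributed_c10d`. A minimal sketch of that pattern in isolation follows; the `ImportError` fallback to `str` is an assumption for torch builds that do not export `GroupName`, not something this PR adds.

from dataclasses import dataclass

import torch

try:
    from torch.distributed.distributed_c10d import GroupName
except ImportError:
    # Assumed fallback for builds without GroupName; the PR imports it directly.
    GroupName = str  # type: ignore[misc, assignment]


@dataclass
class _ExampleAllGatherMatch:
    # Trimmed-down analogue of _AllGatherMatch, showing only the annotation
    # change from `str` to `GroupName`; the other fields are omitted here.
    ag_node: torch.fx.Node
    gather_dim: int
    group_name: GroupName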
51 changes: 27 additions & 24 deletions autoparallel/asynctp_ops.py
@@ -20,6 +20,9 @@
from torch._C._distributed_c10d import _register_work, _SymmetricMemory
from torch.distributed._symmetric_memory import get_symm_mem_workspace, rendezvous

# Import GroupName for type checking
GroupName = c10d.GroupName

_is_test_mode: bool = False
_mocked_group_names: set[str] | None = None
_backend_streams: dict[int, torch.cuda.Stream] = {}
@@ -35,7 +38,7 @@ def _pipelined_multi_all_gather_and_consume(
shard: list[torch.Tensor],
shard_consumer: Callable[[list[torch.Tensor], int], None],
ag_out: list[torch.Tensor],
group_name: str,
group_name: GroupName,
ag_out_needed: bool = True,
) -> None:
"""
@@ -181,7 +184,7 @@ def _pipelined_all_gather_and_consume(
shard: torch.Tensor,
shard_consumer: Callable[[torch.Tensor, int], None],
ag_out: torch.Tensor,
group_name: str,
group_name: GroupName,
ag_out_needed: bool = True,
) -> None:
"""
@@ -209,7 +212,7 @@ def adapter(shard: list[torch.Tensor], rank: int) -> None:
def _pipelined_produce_and_all2all(
chunk_producer: Callable[[int, torch.Tensor], None],
output: torch.Tensor,
group_name: str,
group_name: GroupName,
out_chunk_dim=0,
) -> None:
"""
@@ -400,7 +403,7 @@ def _fused_all_gather_matmul_impl(
kwargs_list: list[dict[str, Any]],
out_dtypes: list[torch.dtype | None],
gather_dim: int,
group_name: str,
group_name: GroupName,
return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
if A_shard.dim() < 2:
@@ -534,7 +537,7 @@ def _pipelined_all_gather_and_consume_last_dim(
shard: torch.Tensor,
shard_consumer: Callable[[torch.Tensor, int], None],
ag_out: torch.Tensor,
group_name: str,
group_name: GroupName,
ag_out_needed: bool = True,
) -> None:
p2p_workspace_size_req = 0
@@ -607,7 +610,7 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
kwargs_list: list[dict[str, Any]],
out_dtypes: list[torch.dtype | None],
gather_dim: int,
group_name: str,
group_name: GroupName,
return_A: bool,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
assert gather_dim == A_shard.ndim - 1
@@ -674,7 +677,7 @@ def _fused_all_gather_matmul_fallback(
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
gather_dim: int,
group_name: str,
group_name: GroupName,
*,
return_A: bool = True,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -705,7 +708,7 @@ def _fused_all_gather_matmul(
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
gather_dim: int,
group_name: str,
group_name: GroupName,
*,
return_A: bool = True,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
@@ -756,7 +759,7 @@ def _should_use_fused_all_gather_matmul_native(
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
gather_dim: int,
group_name: str,
group_name: GroupName,
) -> bool:
group = c10d._resolve_process_group(group_name)
local_M = math.prod(A_shard.shape[:-1])
@@ -778,7 +781,7 @@
def _fused_all_gather_matmul_native(
A_shard: torch.Tensor,
B: torch.Tensor,
group_name: str,
group_name: GroupName,
) -> tuple[torch.Tensor, torch.Tensor]:
symm_mem = rendezvous(A_shard, group_name)
if symm_mem is None:
@@ -832,7 +835,7 @@ def _fused_all_gather_matmul_native(
def _should_use_multimem_all_gather_matmul(
A_shard: torch.Tensor,
gather_dim: int,
group_name: str,
group_name: GroupName,
return_A: bool,
) -> bool:
group = c10d._resolve_process_group(group_name)
@@ -858,7 +861,7 @@ def _should_use_multimem_all_gather_matmul(
def _multimem_all_gather_matmul(
A_shard: torch.Tensor,
Bs: list[torch.Tensor],
group_name: str,
group_name: GroupName,
) -> list[torch.Tensor]:
group = c10d._resolve_process_group(group_name)
A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:]))
@@ -877,7 +880,7 @@ def _fused_all_gather_scaled_matmul_fallback(
A_scale: torch.Tensor,
B_scales: list[torch.Tensor],
gather_dim: int,
group_name: str,
group_name: GroupName,
biases: list[torch.Tensor | None],
result_scales: list[torch.Tensor | None],
out_dtypes: list[torch.dtype | None],
@@ -951,7 +954,7 @@ def _fused_all_gather_scaled_matmul(
A_scale: torch.Tensor,
B_scales: list[torch.Tensor],
gather_dim: int,
group_name: str,
group_name: GroupName,
biases: list[torch.Tensor | None],
result_scales: list[torch.Tensor | None],
out_dtypes: list[torch.dtype | None],
@@ -1057,7 +1060,7 @@ def _fused_matmul_reduce_scatter(
B: torch.Tensor,
reduce_op: str,
scatter_dim: int,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
"""
Perform the following logic with micro-pipelined computation and
@@ -1093,7 +1096,7 @@ def _fused_matmul_reduce_scatter_fallback(
B: torch.Tensor,
reduce_op: str,
scatter_dim: int,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
res = funcol.wait_tensor(res)
@@ -1108,7 +1111,7 @@ def _fused_matmul_reduce_scatter_impl(
out_dtype: torch.dtype | None,
reduce_op: str,
scatter_dim: int,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
if A.dim() < 2:
raise ValueError("A_shard must be a matrix")
@@ -1194,7 +1197,7 @@ def _fused_scaled_matmul_reduce_scatter(
reduce_op: str,
orig_scatter_dim: int,
scatter_dim_after_maybe_reshape: int,
group_name: str,
group_name: GroupName,
output_shape: list[int],
bias: torch.Tensor | None = None,
result_scale: torch.Tensor | None = None,
@@ -1248,7 +1251,7 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
reduce_op: str,
orig_scatter_dim: int,
scatter_dim_after_maybe_reshape: int,
group_name: str,
group_name: GroupName,
output_shape: list[int],
bias: torch.Tensor | None = None,
result_scale: torch.Tensor | None = None,
@@ -1300,7 +1303,7 @@ def _fused_scaled_matmul_reduce_scatter_impl(
reduce_op: str,
orig_scatter_dim: int,
scatter_dim_after_maybe_reshape: int,
group_name: str,
group_name: GroupName,
output_shape: list[int],
) -> torch.Tensor:
if A.dim() < 2:
@@ -1524,7 +1527,7 @@ def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
@torch.library.impl(lib, "_low_contention_all_gather", "Meta")
def _low_contention_all_gather_meta(
tensor: torch.Tensor,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
group_size = c10d._get_group_size_by_name(group_name)
return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
@@ -1533,7 +1536,7 @@ def _low_contention_all_gather_meta(
@torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
def _low_contention_all_gather(
tensor: torch.Tensor,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
"""
Performs all-gather with symmetric memory in a low-contention fashion.
@@ -1582,7 +1585,7 @@ def _low_contention_all_gather(
def _low_contention_reduce_scatter_meta(
tensor: torch.Tensor,
reduce_op: str,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
group_size = c10d._get_group_size_by_name(group_name)
return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
@@ -1665,7 +1668,7 @@ def _low_contention_reduce_scatter_with_workspace(
def _low_contention_reduce_scatter(
tensor: torch.Tensor,
reduce_op: str,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
"""
Performs reduce-scatter with symmetric memory in a low-contention fashion.
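asynctp_ops.py re-exposes the same type locally via `GroupName = c10d.GroupName` and continues to treat group names as lookup keys, resolving them with `c10d._resolve_process_group` exactly as the hunks above show. A short sketch of that lookup, assuming a process group has already been initialized and registered under the given name (setup not shown and not part of this PR):

import torch.distributed.distributed_c10d as c10d


def _group_size_for(group_name) -> int:
    # A GroupName value is still consumed as a registered process-group name,
    # so the existing name-based c10d lookups keep working unchanged.
    group = c10d._resolve_process_group(group_name)
    return group.size()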
11 changes: 7 additions & 4 deletions autoparallel/collectives.py
@@ -9,6 +9,9 @@
import torch.distributed.distributed_c10d as c10d
from torch.distributed._tensor.experimental import local_map as _local_map

# Import GroupName for type checking
GroupName = c10d.GroupName

_local_map_device_mesh = None


@@ -51,7 +54,7 @@ def axis_index(axis_name):
def _all_gather_tensor(
x: torch.Tensor,
gather_dim: int,
group_name: str,
group_name: GroupName,
) -> torch.Tensor:
x = x.contiguous()
group_size = c10d._get_group_size_by_name(group_name)
@@ -67,7 +70,7 @@ def _all_gather_tensor(


def _reduce_scatter_tensor(
self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: str
self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: GroupName
):
group_size = c10d._get_group_size_by_name(group_name)

@@ -88,7 +91,7 @@ def _reduce_scatter_tensor(
return res


def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: str):
def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: GroupName):
tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
res = torch.ops._c10d_functional.wait_tensor(tensor)
return res
@@ -98,7 +101,7 @@ def _all_to_all(
self: torch.Tensor,
output_split_sizes: Optional[list[int]],
input_split_sizes: Optional[list[int]],
group_name: str,
group_name: GroupName,
):
group_size = c10d._get_group_size_by_name(group_name)
if output_split_sizes is None or input_split_sizes is None:
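For context, a hypothetical call site for the `_all_gather_tensor` wrapper changed above; the tensor and group name are placeholders and assume `torch.distributed` has been initialized elsewhere, which this PR does not do:

import torch

from autoparallel.collectives import _all_gather_tensor


def gather_shard(x: torch.Tensor, group_name) -> torch.Tensor:
    # Concatenates each rank's shard of `x` along dim 0 across the named group.
    return _all_gather_tensor(x, gather_dim=0, group_name=group_name)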
3 changes: 1 addition & 2 deletions tests/test_dtensor.py
@@ -16,7 +16,6 @@
OpSpec,
OpStrategy,
OutputSharding,
OutputSpecType,
RuntimeSchemaInfo,
TupleStrategy,
)
@@ -241,7 +240,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
# for ops that return multiple tensors and the output_specs is not
# a tuple, we use a tuple of that single output spec as the new
# output_specs
output_specs: OutputSpecType = output_strategy.output_specs
output_specs = output_strategy.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = tuple(
[