
Commit e3acd4b

Remove stack decomposition and add stack rule (#271)
Using the stack decomposition leads to fewer sharding options. As an example, two S(0) tensors can be successfully stacked together at dimension 0, which should lead to an S(1) output sharding. If we instead keep the decomposition from https://github.com/pytorch/pytorch/blob/ded9bcd61a059bf723e6e84689552962b480ea77/torch/_refs/__init__.py#L4116, which first concatenates at the stack dim and then applies a view, we can't obtain the same sharding option. This is because stack has a stricter set of requirements than cat, and going through the decomposition makes us miss them. Once I removed the decomposition, I found that the propagation rules for stack weren't correctly implemented, so I had to re-implement them. I'm following a much simpler pattern for the propagation rules: enumerate all possible sharding options on a single mesh dimension, then expand to the full mesh afterwards. This makes the implementation much simpler, and I believe it is in line with what @wconstab is doing in his refactoring of the propagation rules.
1 parent c3fd25b commit e3acd4b
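The S(0) → S(1) behavior described above can be checked with plain `torch.stack` on local shards. A small sketch; the shapes and values here are illustrative, not from the commit:

```python
import torch

x = torch.arange(6).reshape(3, 2)
y = torch.arange(6, 12).reshape(3, 2)

# Replicated (full) result of stacking at dim 0: shape (2, 3, 2).
full = torch.stack([x, y], dim=0)

# If each input is sharded along its dim 0 (S(0)), stacking the local
# shards yields the matching shard of the output along dim 1 (S(1)),
# since stack inserts a new dimension in front of the old dim 0:
local = torch.stack([x[:2], y[:2]], dim=0)
assert torch.equal(local, full[:, :2])
```

A cat-then-view decomposition cannot express this mapping directly, which is why removing the decomposition recovers the extra sharding option.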

File tree

2 files changed, +47 −0 lines changed

autoparallel/api.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -63,6 +63,7 @@ def _get_decomp_table():
     decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
     decomp_table.pop(torch.ops.aten._softmax_backward_data.default)
     decomp_table.pop(torch.ops.aten._softmax.default)
+    decomp_table.pop(torch.ops.aten.stack.default)

     # decompose addmm to allow for TP on mm
     decomp_table.pop(torch.ops.aten.addmm.default)
```

autoparallel/propagation_rules.py

Lines changed: 46 additions & 0 deletions

```diff
@@ -843,3 +843,49 @@ def scatter_strategy(mesh, op_schema: OpSchema):
     return expand_to_full_mesh_op_strategy(
         mesh, op_schema, single_mesh_dim_strategies, input_index=1
     )
+
+
+@register_opschema_rule(torch.ops.aten.stack.default)
+def stack_strategy(mesh, op_schema: OpSchema):
+    from torch.distributed.tensor._ops._tensor_ops import (
+        PlacementList,
+        TupleStrategy,
+        cast,
+        expand_to_full_mesh_op_strategy,
+        normalize_dim,
+    )
+
+    input_tuple_strategy = op_schema.args_schema[0]
+    assert isinstance(input_tuple_strategy, TupleStrategy)
+
+    num_input_tensor = len(input_tuple_strategy.children)
+    first_input_strategy = input_tuple_strategy.children[0]
+    assert isinstance(first_input_strategy, OpStrategy)
+    common_input_ndim = first_input_strategy.ndim
+
+    dim = cast(int, op_schema.args_schema[1]) if len(op_schema.args_schema) > 1 else 0
+    # normalize the dim to be within the common input ndim
+    dim = normalize_dim(dim, common_input_ndim)
+
+    possible_input_strategies: PlacementList = [Replicate()] + [  # type: ignore[assignment]
+        Shard(i) for i in range(common_input_ndim)
+    ]
+    possible_output_strategies: PlacementList = (
+        [Replicate()]  # type: ignore[assignment]
+        + [Shard(i) for i in range(dim)]
+        + [Shard(i + 1) for i in range(dim, common_input_ndim)]
+    )
+
+    single_mesh_dim_strategies = []
+    for input_strategy, output_strategy in zip(
+        possible_input_strategies, possible_output_strategies
+    ):
+        strategy: PlacementList = [output_strategy] + [
+            input_strategy
+        ] * num_input_tensor
+        single_mesh_dim_strategies.append(strategy)
+
+    s = expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+    return s
```
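The enumeration pattern this rule follows can be illustrated without any DTensor types. The helper below is hypothetical (not part of the commit), with strings standing in for `Replicate`/`Shard` placements; like the rule above, it emits one single-mesh-dim strategy row per input placement, output placement first, with input shard dims at or beyond the stack `dim` shifted up by one in the output:

```python
def enumerate_stack_strategies(ndim: int, dim: int, num_inputs: int):
    # "R" stands in for Replicate, f"S({i})" for Shard(i).
    inputs = ["R"] + [f"S({i})" for i in range(ndim)]
    # Stacking inserts a new dimension at `dim`, so input shard dims at or
    # beyond `dim` map to dim + 1 in the output.
    outputs = ["R"] + [f"S({i})" for i in range(dim)] + [
        f"S({i + 1})" for i in range(dim, ndim)
    ]
    # One row per placement: [output, input, input, ...], as in the rule above.
    return [[out] + [inp] * num_inputs for inp, out in zip(inputs, outputs)]

# Two 2-D inputs stacked at dim 0: R -> R, S(0) -> S(1), S(1) -> S(2).
print(enumerate_stack_strategies(2, 0, 2))
```

Expanding these per-mesh-dim rows over the full mesh is then left to `expand_to_full_mesh_op_strategy`, which is what keeps the rule itself short.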
