Commit e3acd4b
authored
Remove stack decomposition and add stack rule (#271)
Using the stack decomposition leads to fewer sharding options. As an example, two S(0) tensors can be successfully stacked together at dimension 0, which should lead to a S(1) output sharding. If we instead keep the decomposition from https://github.com/pytorch/pytorch/blob/ded9bcd61a059bf723e6e84689552962b480ea77/torch/_refs/__init__.py#L4116, which first concatenates at the stack dim and then applies a view, we can't obtain the same sharding option. This is because stack has a stricter set of requirements as cat, which the decomposition makes us miss it. Once I removed the decomposition, I faced an issue that the propagation rules from stack aren't correctly implemented, so I had to re-implement it. I'm following a much simpler pattern for the propagation rules, which is to enumerate all possible sharding options and expand the mesh afterwards, which makes the implementation much simpler. This I believe is in-line with what @wconstab is doing for his refactoring of the propagation rules1 parent c3fd25b commit e3acd4b
2 files changed
+47
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| 66 | + | |
66 | 67 | | |
67 | 68 | | |
68 | 69 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
843 | 843 | | |
844 | 844 | | |
845 | 845 | | |
| 846 | + | |
| 847 | + | |
| 848 | + | |
| 849 | + | |
| 850 | + | |
| 851 | + | |
| 852 | + | |
| 853 | + | |
| 854 | + | |
| 855 | + | |
| 856 | + | |
| 857 | + | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
| 862 | + | |
| 863 | + | |
| 864 | + | |
| 865 | + | |
| 866 | + | |
| 867 | + | |
| 868 | + | |
| 869 | + | |
| 870 | + | |
| 871 | + | |
| 872 | + | |
| 873 | + | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
0 commit comments