Skip to content

Commit c29c8fd

Browse files
authored
Improve PP->SS cost (#214)
Before we were only considering the PP->S(0)S(0) case, but the same logic applies to all shardings. We should generalize this to take the smallest cost into account, irrespective of mesh dims, but we should do this is a way that is efficient
1 parent cb3059e commit c29c8fd

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

autoparallel/collective_runtime_estimation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,8 @@ def redistribute_cost(
138138

139139
def estimate_strategy_comms_cost(src_spec, tgt_spec):
140140
order = list(range(src_spec.mesh.ndim))
141-
if src_spec.placements == (Partial(), Partial()) and tgt_spec.placements == (
142-
Shard(0),
143-
Shard(0),
141+
if src_spec.placements == (Partial(), Partial()) and all(
142+
p.is_shard() for p in tgt_spec.placements
144143
):
145144
order = [1, 0]
146145
comms_cost = redistribute_cost(src_spec, tgt_spec, order)

0 commit comments

Comments
 (0)