@@ -25,6 +25,7 @@
 from torch._dynamo.testing import collect_results
 from torch._dynamo.utils import same
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
+from torch.compiler import set_enable_guard_collectives
 from torch.distributed._functional_collectives import _maybe_wrap_tensor
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import (
@@ -61,6 +62,15 @@ def init_weights(m):
         m.bias.data.fill_(0.01)


+@contextmanager
+def enable_guard_collectives():
+    old = set_enable_guard_collectives(True)
+    try:
+        yield
+    finally:
+        set_enable_guard_collectives(old)
+
+
 class ToyModel(nn.Module):
     def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
         super().__init__()
@@ -1141,6 +1151,31 @@ def f(x):
             for r in res[1:]:
                 self.assertEqual(res[0], r)

+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @enable_guard_collectives()
+    def test_guard_collective(self):
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            torch._dynamo.utils.clear_compilation_metrics()
+
+            @torch.compile()
+            def f(x):
+                return x.sum()
+
+            x = torch.randn(10, device=self.rank)
+            f(x)
+
+            if self.rank == 0:
+                x = torch.randn(10, device=self.rank)
+            else:
+                x = torch.randn(12, device=self.rank)  # recompile on one rank
+            f(x)
+
+            metrics = torch._dynamo.utils.get_compilation_metrics()
+            res = [None] * self.world_size
+            torch.distributed.all_gather_object(res, len(metrics))
+            for r in res[1:]:
+                self.assertEqual(res[0], r)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_get_pg_attr(self):
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
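Note on the helper above (not part of the diff): functions wrapped with contextlib.contextmanager return objects that also act as decorators, which is why `@enable_guard_collectives()` can be stacked on `test_guard_collective` and still restores the previous setting once the test returns. A minimal, self-contained sketch of that pattern, using illustrative `set_flag`/`enable_flag` stand-ins for `torch.compiler.set_enable_guard_collectives` and the helper in the diff:

from contextlib import contextmanager

_flag = False  # stand-in for the global guard-collectives setting

def set_flag(value):
    # Illustrative stand-in for set_enable_guard_collectives:
    # install the new value and return the previous one.
    global _flag
    old, _flag = _flag, value
    return old

@contextmanager
def enable_flag():
    # Mirrors enable_guard_collectives() from the diff above.
    old = set_flag(True)
    try:
        yield
    finally:
        set_flag(old)

@enable_flag()          # decorator form, as used on test_guard_collective
def run():
    assert _flag is True

run()
assert _flag is False   # previous value restored after the decorated call

with enable_flag():     # context-manager form also works
    assert _flag is True
assert _flag is False

As for the test body itself, the intent appears to be that only rank 0 keeps the original shape while the other ranks see a size-12 input and hit a guard failure; with guard collectives enabled, every rank should end up with the same number of compilations, which is what the all_gather_object of len(metrics) asserts.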