
Commit 17eb649

ezyang authored and pytorchmergebot committed
Implement guard collectives (optimized version) (pytorch#156562)
This is a remix of pytorch#155558. Instead of mediating the guard collective via a config option, in this one it's done via a `set_stance`-like API. The motivation is that checking for the config value on entry to torch.compile is apparently quite expensive, according to functorch_maml_omniglot. So this makes it a bit cheaper.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#156562
Approved by: https://github.com/Microve
1 parent 7377291 commit 17eb649

File tree

10 files changed: +170 −4 lines changed

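For orientation, here is a minimal usage sketch of the API this commit adds. Only `torch.compiler.set_enable_guard_collectives` comes from this change; the launch setup, model function, and tensor shapes are illustrative assumptions.

    # Hypothetical SPMD launch script; every rank must execute the same code,
    # since guard collectives (like compiler collectives) assume SPMD programs.
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")  # assumes torchrun-style env vars
    rank = dist.get_rank()

    # Install the guard-complete hook; returns the previous setting.
    prev = torch.compiler.set_enable_guard_collectives(True)

    @torch.compile
    def step(x):
        return x.sum()

    # After the first compile, every entry into the compiled region issues a
    # collective so that a guard miss on any one rank forces recompilation on all.
    out = step(torch.randn(10, device=f"cuda:{rank}"))

    torch.compiler.set_enable_guard_collectives(prev)  # restore the prior setting

As the docstring further down notes, this only makes sense for SPMD programs, and a guard collective is only issued once a frame already has compiled code to guard on.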

docs/source/torch.compiler_api.md

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
     list_backends
     disable
     set_stance
+    set_enable_guard_collectives
     cudagraph_mark_step_begin
     is_compiling
     is_dynamo_compiling

test/distributed/test_dynamo_distributed.py

Lines changed: 35 additions & 0 deletions
@@ -25,6 +25,7 @@
 from torch._dynamo.testing import collect_results
 from torch._dynamo.utils import same
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
+from torch.compiler import set_enable_guard_collectives
 from torch.distributed._functional_collectives import _maybe_wrap_tensor
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import (
@@ -61,6 +62,15 @@ def init_weights(m):
     m.bias.data.fill_(0.01)


+@contextmanager
+def enable_guard_collectives():
+    old = set_enable_guard_collectives(True)
+    try:
+        yield
+    finally:
+        set_enable_guard_collectives(old)
+
+
 class ToyModel(nn.Module):
     def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
         super().__init__()
@@ -1141,6 +1151,31 @@ def f(x):
             for r in res[1:]:
                 self.assertEqual(res[0], r)

+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @enable_guard_collectives()
+    def test_guard_collective(self):
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            torch._dynamo.utils.clear_compilation_metrics()
+
+            @torch.compile()
+            def f(x):
+                return x.sum()
+
+            x = torch.randn(10, device=self.rank)
+            f(x)
+
+            if self.rank == 0:
+                x = torch.randn(10, device=self.rank)
+            else:
+                x = torch.randn(12, device=self.rank)  # recompile on one rank
+            f(x)
+
+            metrics = torch._dynamo.utils.get_compilation_metrics()
+            res = [None] * self.world_size
+            torch.distributed.all_gather_object(res, len(metrics))
+            for r in res[1:]:
+                self.assertEqual(res[0], r)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_get_pg_attr(self):
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):

torch/_C/_dynamo/eval_frame.pyi

Lines changed: 10 additions & 2 deletions
@@ -1,8 +1,13 @@
 import enum
 import types
-from typing import overload
+from typing import Optional, overload

-from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
+from torch._dynamo.types import (
+    DynamoCallback,
+    DynamoGuardCompleteHook,
+    DynamoGuardHook,
+    GuardFn,
+)

 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
 def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
@@ -13,6 +18,9 @@ def set_code_exec_strategy(
     code: types.CodeType, strategy: _FrameExecStrategy
 ) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
+def set_guard_complete_hook(
+    hook: Optional[DynamoGuardCompleteHook],
+) -> Optional[DynamoGuardCompleteHook]: ...
 def raise_sigtrap() -> None: ...

 class _CacheEntry:

torch/_dynamo/decorators.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
 from torch._C._dynamo.eval_frame import (  # noqa: F401
     reset_code,
     set_eval_frame,
+    set_guard_complete_hook,
     set_guard_error_hook,
     unsupported,
 )

torch/_dynamo/distributed.py

Lines changed: 13 additions & 0 deletions
@@ -22,6 +22,7 @@


 _COMPILE_PG: Optional[dist.ProcessGroup] = None
+_GUARD_PG: Optional[dist.ProcessGroup] = None


 def get_compile_pg() -> Optional[dist.ProcessGroup]:
@@ -39,3 +40,15 @@ def get_compile_pg() -> Optional[dist.ProcessGroup]:
         return _COMPILE_PG

     return None
+
+
+# NB: Unlike get_compile_pg, this is only called when guard collectives were
+# explicitly requested
+def get_guard_pg() -> Optional[dist.ProcessGroup]:
+    if dist.is_available() and dist.is_initialized():
+        global _GUARD_PG
+        if _GUARD_PG is None:
+            _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
+        return _GUARD_PG
+
+    return None
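
A small illustrative sketch of the behavior this helper gives callers (not part of the commit; the import path follows the file above and the printing is purely for demonstration):

    # Illustrative check: outside an initialized torch.distributed run,
    # get_guard_pg() returns None and the guard collective hook simply
    # falls through to the local guard result.
    from torch._dynamo import distributed as dynamo_distributed

    pg = dynamo_distributed.get_guard_pg()
    if pg is None:
        print("torch.distributed not initialized; guard collectives are a no-op")
    else:
        # In a distributed run, all ranks share one dedicated group tagged
        # "pt2_guard_pg", lazily created and separate from the compile PG.
        print(f"guard collective group size: {pg.size()}")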

torch/_dynamo/eval_frame.py

Lines changed: 34 additions & 1 deletion
@@ -58,6 +58,7 @@
     reset_code,
     set_code_exec_strategy,
     set_eval_frame,
+    set_guard_complete_hook,
     set_guard_error_hook,
     set_skip_guard_eval_unsafe,
     unsupported,
@@ -90,7 +91,7 @@
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

-from . import config, convert_frame, external_utils, trace_rules, utils
+from . import config, convert_frame, distributed, external_utils, trace_rules, utils
 from .backends.registry import CompilerFn, lookup_backend
 from .code_context import code_context
 from .exc import (
@@ -519,6 +520,38 @@ def _log_traced_frames():
     log.info(msg)


+def guard_collectives_hook(guard_eval_result):
+    import torch.distributed as dist
+    from torch._dynamo.utils import dynamo_timed
+
+    # guard_eval_result == True ==> cache hit
+    if pg := distributed.get_guard_pg():
+        with dynamo_timed(
+            "guard_collective", log_pt2_compile_event=True, log_waitcounter=True
+        ):
+            log.info("guard_collective %s", guard_eval_result)
+            torch._logging.trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "guard_collective",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: str(guard_eval_result),
+            )
+            # TODO: a bit awkward to time, this isn't inside of the dynamo compile region
+            all_results = [None] * pg.size()
+            dist.all_gather_object(all_results, guard_eval_result, group=pg)
+            # True = everyone hit, OK to run
+            # False = someone missed, force recompile everywhere
+            res = all(all_results)
+            log.info("guard_collective %s -> %s", guard_eval_result, res)
+            return res
+    return guard_eval_result
+
+
+_not_set = object()
+
+
 class _TorchDynamoContext:
     def __init__(
         self,
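
The consensus rule inside `guard_collectives_hook` can be stated in isolation: every rank contributes its local guard-evaluation result and the cached code runs only if all ranks hit. A stripped-down sketch (the `pg` argument stands in for `get_guard_pg()`, and the helper name is made up for illustration):

    # Standalone sketch of the consensus rule used by guard_collectives_hook.
    import torch.distributed as dist

    def all_ranks_hit(local_cache_hit: bool, pg) -> bool:
        # Gather every rank's guard-evaluation result, then AND them together:
        # True  -> everyone hit the cache, OK to run the cached code
        # False -> at least one rank missed, so every rank recompiles
        results = [None] * pg.size()
        dist.all_gather_object(results, local_cache_hit, group=pg)
        return all(results)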

torch/_dynamo/types.py

Lines changed: 7 additions & 0 deletions
@@ -114,6 +114,13 @@ def __call__(
     ) -> None: ...


+class DynamoGuardCompleteHook(Protocol):
+    def __call__(
+        self,
+        cache_hit: bool,
+    ) -> bool: ...
+
+
 class ProfilerStartHook(Protocol):
     def __call__(
         self,
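
For illustration, a hypothetical hook conforming to this protocol, installed through the new `set_guard_complete_hook` binding; the counting logic is made up, only the signature and the install call come from this commit:

    # Hypothetical hook conforming to DynamoGuardCompleteHook: it records
    # cache hits/misses and passes the local result through unchanged.
    from torch._C._dynamo.eval_frame import set_guard_complete_hook

    counts = {"hit": 0, "miss": 0}

    def counting_hook(cache_hit: bool) -> bool:
        counts["hit" if cache_hit else "miss"] += 1
        return cache_hit  # returning False would force recompilation

    previous = set_guard_complete_hook(counting_hook)  # returns the prior hook (or None)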

torch/compiler/__init__.py

Lines changed: 30 additions & 0 deletions
@@ -21,6 +21,7 @@
     "list_backends",
     "disable",
     "set_stance",
+    "set_enable_guard_collectives",
     "cudagraph_mark_step_begin",
     "wrap_numpy",
     "is_compiling",
@@ -330,6 +331,35 @@ def bar():
 set_stance._dynamo_forbidden = True  # type: ignore[attr-defined]


+def set_enable_guard_collectives(enabled: bool):
+    """
+    Enables use of collectives *during* guard evaluation to synchronize behavior
+    across ranks. This is expensive: we have to issue a collective every time
+    we enter a compiled code region, even if no rank actually would need to
+    compile. This can help prevent NCCL hangs by ensuring that we never have a
+    situation where one rank starts recompiling while other ranks don't compile;
+    it is especially useful in conjunction with enable_compiler_collectives
+    where such a situation would immediately cause a hang (as it is necessary
+    for all ranks to compile at the same time to run compiler collectives). Like
+    compiler collectives, you can only run this on SPMD programs; you will hang
+    otherwise. Note that a guard collective is only issued if there is any
+    compiled code to guard on; if this is the first time we encounter a frame or
+    the frame is skipped, we don't issue collectives.
+
+    Returns the previous setting of enabled.
+    """
+    from torch._C._dynamo.eval_frame import set_guard_complete_hook  # noqa: F401
+    from torch._dynamo.eval_frame import guard_collectives_hook
+
+    if enabled:
+        return set_guard_complete_hook(guard_collectives_hook) is not None
+    else:
+        return set_guard_complete_hook(None) is not None
+
+
+set_enable_guard_collectives._dynamo_forbidden = True  # type: ignore[attr-defined]
+
+
 def cudagraph_mark_step_begin():
     """
     Indicates that a new iteration of inference or training is about to begin.
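
Because the call returns the previous setting, scoping it to a region is straightforward; a sketch of the same save-and-restore pattern the test helper above uses:

    # Sketch of scoping guard collectives to a region, using the returned
    # previous setting (mirrors the enable_guard_collectives helper in the
    # test file above).
    from contextlib import contextmanager
    from torch.compiler import set_enable_guard_collectives

    @contextmanager
    def guard_collectives_enabled():
        old = set_enable_guard_collectives(True)
        try:
            yield
        finally:
            set_enable_guard_collectives(old)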

torch/csrc/dynamo/eval_frame.c

Lines changed: 18 additions & 0 deletions
@@ -11,6 +11,7 @@
 #include <torch/csrc/utils/python_compat.h>

 PyObject* guard_error_hook = NULL;
+PyObject* guard_complete_hook = NULL;

 typedef struct {
   int active_dynamo_threads;
@@ -626,6 +627,22 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
   Py_RETURN_NONE;
 }

+static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) {
+  PyObject* old_hook = guard_complete_hook;
+
+  if (obj == Py_None) {
+    obj = NULL;
+  }
+
+  guard_complete_hook = Py_XNewRef(obj);
+
+  if (old_hook == NULL) {
+    Py_RETURN_NONE;
+  } else {
+    return old_hook;
+  }
+}
+
 // Debugging function for GNU C only.
 // Used to set gdb breakpoints in hot CPython sites from Python.
 // Code example:
@@ -666,6 +683,7 @@ static PyMethodDef _methods[] = {
     {"unsupported", unsupported, METH_VARARGS, NULL},
     {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL},
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
+    {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL},
     {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL},
     {NULL, NULL, 0, NULL}};

torch/csrc/dynamo/eval_frame_cpp.cpp

Lines changed: 21 additions & 1 deletion
@@ -7,6 +7,10 @@
 #include <torch/csrc/dynamo/framelocals_mapping.h>
 #include <torch/csrc/utils/python_compat.h>

+extern "C" {
+extern PyObject* guard_complete_hook;
+}
+
 static constexpr const char* cache_lookup_profiler_str =
     "TorchDynamo Cache Lookup";

@@ -197,7 +201,23 @@ PyObject* dynamo__custom_eval_frame(
     // guard eval failed, keep propagating
     fail();
     return eval_result;
-  } else if (maybe_cached_code != Py_None) {
+  }
+
+  // NB: We only do guard collectives when there are any compiled code entries
+  // at all; this reduces overtriggering and we don't need to do guard
+  // collectives the very first time we've seen a frame
+  // TODO: We could also check if we had just created extra for the first
+  // time? Not too sure the best condition for extra->cache_entry_list
+  if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) {
+    py::handle guard_complete_hook_handle(guard_complete_hook);
+    // False means force compilation (someone cache missed)
+    py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None);
+    if (!py::cast<bool>(res)) {
+      maybe_cached_code = Py_None; // NB: non-owning
+    }
+  }
+
+  if (maybe_cached_code != Py_None) {
     cached_code = (PyCodeObject*)maybe_cached_code;
     // used cached version
     DEBUG_TRACE("cache hit %s", get_frame_name(frame));
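
Read in Python terms, the new branch does roughly the following (pseudocode mirroring the C++ above; the function and return strings are illustrative, not a real torch API):

    # Python-flavored sketch of the C++ dispatch decision above.
    def dispatch_after_guard_eval(maybe_cached_code, cache_entry_list, guard_complete_hook):
        # Only consult the hook once this frame has compiled entries at all;
        # the very first time we see a frame there is nothing to synchronize.
        if guard_complete_hook is not None and cache_entry_list:
            ok_to_run = guard_complete_hook(maybe_cached_code is not None)
            if not ok_to_run:
                maybe_cached_code = None  # force the recompile path everywhere
        if maybe_cached_code is not None:
            return "run cached code"
        return "recompile this frame"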
