forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhooks_for_testing.cpp
34 lines (27 loc) · 911 Bytes
/
hooks_for_testing.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <torch/csrc/jit/testing/hooks_for_testing.h>

#include <utility>

#include <torch/csrc/jit/api/module.h>
namespace torch {
namespace jit {
// File-local slot holding the module hook. It is read by
// didFinishEmitModule() and getEmitHooks() and overwritten by setEmitHooks().
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static ModuleHook emit_module_callback;

// Hands `module` to the currently installed module hook, if there is one.
//
// When the hook slot tests false (no hook installed) this is a no-op.
//
// `module` is received by value and is not touched again after the call, so
// it is moved into the hook invocation. This avoids an extra copy whenever
// the hook takes its argument by value.
void didFinishEmitModule(Module module) {
  if (emit_module_callback) {
    emit_module_callback(std::move(module));
  }
}
// File-local slot holding the function hook. It is read by
// didFinishEmitFunction() and getEmitHooks() and overwritten by
// setEmitHooks().
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static FunctionHook emit_function_callback;

// Hands `fn` to the currently installed function hook, if there is one.
//
// When the hook slot tests false (no hook installed) this is a no-op.
//
// `fn` is received by value and is not touched again after the call, so it
// is moved into the hook invocation. This avoids an extra copy whenever the
// hook takes its argument by value.
void didFinishEmitFunction(StrongFunctionPtr fn) {
  if (emit_function_callback) {
    emit_function_callback(std::move(fn));
  }
}
// Replaces both emit hooks in a single call.
//
// for_mod: new module hook; stored in the slot read by didFinishEmitModule().
// for_fn:  new function hook; stored in the slot read by
//          didFinishEmitFunction().
//
// Whatever was stored in either slot before the call is overwritten. No
// locking is performed here.
void setEmitHooks(ModuleHook for_mod, FunctionHook for_fn) {
  ModuleHook& module_slot = emit_module_callback;
  FunctionHook& function_slot = emit_function_callback;

  // The arguments are owned by this call, so their contents are moved into
  // the slots instead of being copied. The module hook is assigned first.
  module_slot = std::move(for_mod);
  function_slot = std::move(for_fn);
}
std::pair<ModuleHook, FunctionHook> getEmitHooks() {
return std::make_pair(emit_module_callback, emit_function_callback);
}
} // namespace jit
} // namespace torch