forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpass_manager.h
139 lines (120 loc) · 4.59 KB
/
pass_manager.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#pragma once
#include <torch/csrc/jit/ir/ir.h>
/* `getCustomPrePasses()` returns a vector of passes that will be executed
* after differentiation but before any fusion. This is the de-facto location
* for compiler backends to insert passes.
*
* `getCustomPostPasses()` returns a vector of passes that will be
* executed after differentiation and after fusion (if any). This is the
* location for fusion cleanup passes if they are needed.
*
* Static registration of a post pass can be done by creating a global
* `RegisterPostPass r(Pass)` variable in a compilation unit. This header
* declares no `RegisterPrePass` counterpart; pre passes are registered by
* calling `registerPrePass()`.
*
* pass_manager.h uses a Meyers singleton to store a vector of `Pass`es, which
* modify the IR graph in place.
*/
namespace torch {
namespace jit {
// A pass is a callable that receives a reference to a shared_ptr<Graph> and
// returns nothing; it modifies the Graph in place.
using GraphPass = std::function<void(std::shared_ptr<Graph>&)>;
// std::function objects cannot be compared for equality, so each registered
// pass is associated with an integer ID (its "name"). The ID is what callers
// use to refer to the pass later, e.g. when they want to deregister it.
using GraphPassNameType = unsigned int;
// Initial value of the pass ID counter; IDs start at 1.
// NOTE(review): this is a `static` variable defined in a header, so every
// translation unit that includes this file gets its own copy. Presumably only
// the copy seen by the implementation file is used to hand out IDs -- confirm.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static GraphPassNameType graphPassID = 1;
// A registry entry: the pass itself (`first`) paired with its ID (`second`).
using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>;
// Accessors for the currently registered post-/pre-passes. Each returns a
// mutable reference to a vector of (pass, ID) pairs -- the same element type
// as GraphPassEntry. Per the original note, the passes are stored in a static
// vector (the storage itself is defined outside this header).
TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
getCustomPostPasses();
TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
getCustomPrePasses();
// Register `p` as a post-/pre-pass. The returned GraphPassNameType is the
// handle that PassManager below stores and later hands to clearPostPass();
// presumably it is the ID assigned to the newly registered pass -- confirm
// against the implementation.
TORCH_API GraphPassNameType registerPostPass(GraphPass p);
TORCH_API GraphPassNameType registerPrePass(GraphPass p);
// Look up the pass whose ID ("name") is `p` and remove it from the registered
// post-/pre-passes.
TORCH_API void clearPostPass(GraphPassNameType p);
TORCH_API void clearPrePass(GraphPassNameType p);
// Remove all registered post-/pre-passes.
TORCH_API void clearAllPostPasses();
TORCH_API void clearAllPrePasses();
// LEGACY CALL
// Helper type whose constructor takes a GraphPass, so that a pass can be
// registered by defining a global variable of this type in a compilation
// unit. The constructor is only declared here; presumably it forwards to
// registerPostPass() -- confirm against the implementation.
struct TORCH_API RegisterPostPass {
RegisterPostPass(GraphPass p);
};
// Older spelling kept as an alias of RegisterPostPass.
using RegisterPass = RegisterPostPass;
/*
 * PassManager is a thin wrapper over registerPostPass()/clearPostPass(). It
 * registers at most one post pass through registerPass() and remembers the ID
 * returned for it, so that a later clearPass() can remove exactly that pass.
 *
 * The class is a template so that every instantiation gets its own set of
 * function-local statics (the "registered" flag and the stored pass ID).
 * Derive from it CRTP-style, passing your derived class as DerivedType. The
 * template parameter is not otherwise used; it exists only to keep the static
 * state from being shared between different derived types.
 */
template <typename DerivedType>
struct C10_EXPORT PassManager {
 private:
  // Pure virtual member whose only purpose is to make this class abstract.
  virtual void abstract() = 0;

 protected:
  /*
   * isRegistered() returns whether a pass is currently marked as registered.
   * isRegistered(true) toggles that flag and returns the new value.
   *
   * The flag lives in a function-local static so that derived classes do not
   * have to define or initialize any static data members themselves.
   */
  static bool isRegistered(bool flip_bit = false) {
    static bool registered = false;
    if (!flip_bit) {
      return registered;
    }
    registered = !registered;
    return registered;
  }

  /*
   * passID() returns the stored ID of the registered pass (0 until one has
   * been stored).
   * passID(id, true) stores `id` and returns it.
   *
   * As with isRegistered(), the value is kept in a function-local static.
   */
  static GraphPassNameType passID(
      GraphPassNameType PassID = 0,
      bool set = false) {
    static GraphPassNameType stored_id = 0;
    if (!set) {
      return stored_id;
    }
    stored_id = PassID;
    return stored_id;
  }

 public:
  // Registers `p` as a post pass unless this PassManager instantiation already
  // has one registered. Returns true if a pass was already registered (in
  // which case `p` is ignored), and false if `p` was registered by this call.
  static bool registerPass(GraphPass p) {
    if (isRegistered()) {
      return true;
    }
    // Nothing registered yet: register the pass, remember the ID it was given,
    // then toggle the flag from false to true.
    const GraphPassNameType new_id = registerPostPass(std::move(p));
    passID(new_id, /*set=*/true);
    isRegistered(/*flip_bit=*/true);
    return false;
  }

  // Removes the pass registered through registerPass(), if there is one, by
  // calling clearPostPass(passID()).
  static void clearPass() {
    if (!isRegistered()) {
      return;
    }
    clearPostPass(passID());
    // Toggle the flag back from true to false.
    isRegistered(/*flip_bit=*/true);
  }

  // clang-tidy requires a virtual destructor.
  virtual ~PassManager() = default;
};
} // namespace jit
} // namespace torch