// model_compatibility.cpp
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // to be removed after switching to a simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/import.h> // to be removed after switching to a simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
c10::IValue readArchive(
    const std::string& archive_name,
    PyTorchStreamReader& stream_reader) {
  c10::optional<at::Device> device;
  std::shared_ptr<CompilationUnit> compilation_unit =
      std::make_shared<CompilationUnit>();

  // TODO (T90180710): Simplify type_resolver and obj_loader when getting
  // bytecode version from model
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    return typeResolverMobile(qn, compilation_unit);
  };

  std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
      std::make_shared<mobile::CompilationUnit>();
  auto obj_loader = [&](at::StrongTypePtr type, IValue input) {
    return objLoaderMobile(type, input, mobile_compilation_unit);
  };
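  // Tensors referenced from the bytecode archive may be stored in the
  // constants/ archive rather than under bytecode/; in that case the tensor
  // prefix has to be redirected there.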
  bool bytecode_tensor_in_constants_archive =
      (archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader));
  auto ivalues = torch::jit::readArchiveAndTensors(
      archive_name,
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/
      bytecode_tensor_in_constants_archive ? "constants/" : "",
      type_resolver,
      obj_loader,
      device,
      stream_reader);
  return ivalues;
}

std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
  return readArchive("bytecode", reader).toTuple()->elements();
}
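
// The "bytecode" archive unpickles to a tuple whose first element is the
// bytecode version (an int) and whose remaining elements are per-method
// tuples; the parsing in _get_model_bytecode_version and
// _get_model_ops_and_info below relies on that layout.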

/********************** Bytecode **********************/

// Forward declare
int64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues);

int64_t _get_model_bytecode_version(std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_model_bytecode_version(std::move(rai));
}

int64_t _get_model_bytecode_version(const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_model_bytecode_version(std::move(rai));
}

int64_t _get_model_bytecode_version(std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_WARN(
        "The input model might not have been generated by _save_for_mobile()");
    return -1;
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_model_bytecode_version(bytecode_values);
}

int64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues) {
  if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
    int64_t model_version = bytecode_ivalues[0].toInt();
    return model_version;
  }
  TORCH_WARN("Failed to get bytecode version.");
  return -1;
}
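
// A minimal usage sketch (the model path "model.ptl" is hypothetical):
//
//   int64_t version = torch::jit::_get_model_bytecode_version("model.ptl");
//   if (version == -1) {
//     // not a valid mobile model, or the version could not be read
//   }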

/********************** Operators and Info **********************/

// Forward declare
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues);

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_model_ops_and_info(std::move(rai));
}

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_model_ops_and_info(std::move(rai));
}

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_WARN("Failed to open zip file for model ops.");
    return std::unordered_map<std::string, OperatorInfo>{};
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_model_ops_and_info(bytecode_values);
}

/* Retrieves the root (top-level) operators of a model and their corresponding
 * compatibility info. Root operators can call other operators within them
 * (traced ops), and a root op can call many different traced ops depending on
 * which internal code paths it takes. Traced ops are not returned by this
 * function: they are abstracted into the runtime as an implementation detail
 * (and can themselves call further operators), which makes retrieving them
 * difficult, and their value from this API is negligible since they differ
 * across the runtime versions a model may run on. Because of this, there is a
 * false positive this API cannot prevent in a compatibility use case: all the
 * root ops of a model may be present in a target runtime while some of the
 * traced ops are missing, which still prevents the model from running.
 */
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues) {
  // Signed so the -1 failure sentinel returned by
  // _get_model_bytecode_version compares correctly.
  constexpr int64_t min_version_with_schema = 6;
  if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
    TORCH_WARN(
        "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it.");
  }
  std::unordered_map<std::string, OperatorInfo> result;
  if (bytecode_ivalues.empty()) {
    TORCH_WARN("Failed to get model ops and info.");
    return result;
  }
  // loop over all the functions in the bytecode
  for (size_t i = 1; i < bytecode_ivalues.size(); i++) {
    // descend to the operators list
    auto method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
    auto operators_tuple = method_tuple.at(1).toTuple()->elements()[1];
    auto operators = operators_tuple.toTuple()->elements()[1];
    for (const auto& op_tuple : operators.toTuple()->elements()) {
      const auto& op = op_tuple.toTuple()->elements();
      // grab the operator name, including the overload if present
      std::string op_name = op.at(0).toStringRef();
      std::string op_overload_name = op.at(1).toStringRef();
      if (!op_overload_name.empty()) {
        op_name.append(".");
        op_name.append(op_overload_name);
      }
      // grab the schema size if present
      if (op.size() > 2) {
        result.emplace(op_name, OperatorInfo{static_cast<int>(op.at(2).toInt())});
      } else { // no schema information; use the default
        result.emplace(op_name, OperatorInfo{});
      }
    }
  }
  return result;
}
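
// A minimal compatibility-check sketch; "model.ptl" and the runtime op set
// "runtime_ops" (e.g. a std::unordered_set<std::string>) are hypothetical:
//
//   auto model_ops = torch::jit::_get_model_ops_and_info("model.ptl");
//   for (const auto& entry : model_ops) {
//     if (runtime_ops.count(entry.first) == 0) {
//       TORCH_WARN("Missing root operator: ", entry.first);
//     }
//   }
//
// As noted in the comment above, this can still produce a false positive:
// traced ops are not visible to this check.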
} // namespace jit
} // namespace torch