Skip to content

Commit 9f4bcdf

Browse files
jerryzh168facebook-github-bot
authored andcommitted
caffe2::DeviceType -> at::DeviceType (pytorch#11254)
Summary: Pull Request resolved: pytorch#11254 Previously we use DeviceType in caffe2.proto directly, but it's an `enum` and have implicit conversion to int, which does not have type safety, e.g. we have to explicitly check for a device type is valid in event.h: ``` template <int d> struct EventCreateFunctionRegisterer { explicit EventCreateFunctionRegisterer(EventCreateFunction f) { static_assert(d < MaxDeviceTypes, ""); Event::event_creator_[d] = f; } }; ``` at::DeviceType is an `enum class`, and it does not have implicit conversion to int, and provides better type safety guarantees. In this diff we have done the following refactor(taking CPU as an example): 1. caffe2::DeviceType → caffe2::DeviceTypeProto 2. caffe2::CPU → caffe2::PROTO_CPU 3. caffe2::DeviceType = at::DeviceType 4. caffe2::CPU = at::DeviceType::CPU codemod -d caffe2/caffe2 --extensions h,cc,cpp 'device_type\(\), ' 'device_type(), PROTO_' + some manual changes In short, after this diff, in c++, caffe2::CPU refers to the at::DeviceType::CPU and the old proto caffe2::CPU will be caffe2::PROTO_CPU. In python side, we have a temporary workaround that alias `caffe2_pb2.CPU = caffe2_pb2.PROOT_CPU` to make the change easier to review and this will be removed later. Reviewed By: ezyang Differential Revision: D9545704 fbshipit-source-id: 461a28a4ca74e616d3ee183a607078a717fd38a7
1 parent ac9f0a6 commit 9f4bcdf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+382
-272
lines changed

binaries/benchmark_helper.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ bool backendCudaSet(const string& backend) {
6868
void setDeviceType(caffe2::NetDef* net_def, caffe2::DeviceType& run_dev) {
6969
for (int j = 0; j < net_def->op_size(); j++) {
7070
caffe2::OperatorDef* op = net_def->mutable_op(j);
71-
op->mutable_device_option()->set_device_type(run_dev);
71+
op->mutable_device_option()->set_device_type(caffe2::TypeToProto(run_dev));
7272
}
7373
}
7474

binaries/core_overhead_benchmark_gpu.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ static void BM_OperatorCreationCPU(benchmark::State& state) {
167167
OperatorDef def;
168168
Workspace ws;
169169
def.set_type("DummyEmpty");
170-
def.mutable_device_option()->set_device_type(CPU);
170+
def.mutable_device_option()->set_device_type(PROTO_CPU);
171171
while (state.KeepRunning()) {
172172
op = CreateOperator(def, &ws);
173173
}
@@ -180,7 +180,7 @@ static void BM_OperatorCreationCUDA(benchmark::State& state) {
180180
OperatorDef def;
181181
Workspace ws;
182182
def.set_type("DummyEmpty");
183-
def.mutable_device_option()->set_device_type(CUDA);
183+
def.mutable_device_option()->set_device_type(PROTO_CUDA);
184184
while (state.KeepRunning()) {
185185
op = CreateOperator(def, &ws);
186186
}

binaries/print_registered_core_operators.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ int main(int argc, char** argv) {
5252
for (const auto& pair : *caffe2::gDeviceTypeRegistry()) {
5353
std::cout << "Device type " << pair.first
5454
#ifndef CAFFE2_USE_LITE_PROTO
55-
<< " (" << caffe2::DeviceType_Name(
56-
static_cast<caffe2::DeviceType>(pair.first))
55+
<< " ("
56+
<< at::DeviceTypeName(static_cast<caffe2::DeviceType>(pair.first))
5757
<< ")"
5858
#endif
5959
<< std::endl;

caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace {
1919
static void AddConstInput(const std::vector<int>& shape, const float value,
2020
const string& name, Workspace* ws) {
2121
DeviceOption option;
22-
option.set_device_type(CUDA);
22+
option.set_device_type(PROTO_CUDA);
2323
CUDAContext context(option);
2424
Blob* blob = ws->CreateBlob(name);
2525
auto* tensor = blob->GetMutableTensor(CUDA);
@@ -43,7 +43,7 @@ TEST(NervanaFullyConnectedTest, Test) {
4343
def.add_input("W");
4444
def.add_input("B");
4545
def.add_output("Y");
46-
def.mutable_device_option()->set_device_type(CUDA);
46+
def.mutable_device_option()->set_device_type(PROTO_CUDA);
4747
def.set_engine("NERVANA");
4848
AddConstInput(std::vector<int>{5, 10}, 1., "X", &ws);
4949
AddConstInput(std::vector<int>{6, 10}, 1., "W", &ws);

caffe2/contrib/opencl/context.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class OpenCLContext final {
3636
public:
3737
explicit OpenCLContext();
3838
explicit OpenCLContext(const DeviceOption& option) {
39-
DCHECK_EQ(option.device_type(), OPENCL);
39+
DCHECK_EQ(option.device_type(), PROTO_OPENCL);
4040
OpenCLContext();
4141
}
4242
~OpenCLContext() {}

caffe2/core/blob_gpu_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
193193
EXPECT_EQ(tensor_proto.float_data(i), i);
194194
}
195195
EXPECT_TRUE(tensor_proto.has_device_detail());
196-
EXPECT_EQ(tensor_proto.device_detail().device_type(), CUDA);
196+
EXPECT_EQ(tensor_proto.device_detail().device_type(), PROTO_CUDA);
197197
EXPECT_EQ(tensor_proto.device_detail().cuda_gpu_id(), gpu_id);
198198
// Test if the restored blob is still of the same device.
199199
blob.Reset();

caffe2/core/blob_test.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
883883

884884
{
885885
DeviceOption option;
886-
option.set_device_type(CPU);
886+
option.set_device_type(PROTO_CPU);
887887
Argument db_type_arg = MakeArgument<string>("db_type", "vector_db");
888888
Argument absolute_path_arg = MakeArgument<bool>("absolute_path", true);
889889
Argument db_source_arg = MakeArgument<string>("db", db_source);
@@ -996,7 +996,7 @@ TEST(ContentChunks, Serialization) {
996996

997997
{
998998
DeviceOption option;
999-
option.set_device_type(CPU);
999+
option.set_device_type(PROTO_CPU);
10001000
Argument db_type_arg = MakeArgument<string>("db_type", "vector_db");
10011001
Argument absolute_path_arg = MakeArgument<bool>("absolute_path", true);
10021002
Argument db_source_arg = MakeArgument<string>("db", db_source);

caffe2/core/context.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class CAFFE2_API CPUContext final : public BaseContext {
4848
: random_seed_(
4949
option.has_random_seed() ? option.random_seed()
5050
: RandomNumberSeed()) {
51-
CAFFE_ENFORCE_EQ(option.device_type(), CPU);
51+
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU);
5252
}
5353

5454
~CPUContext() noexcept override {}

caffe2/core/context_base.cc

+8-8
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33
namespace caffe2 {
44

55
// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h
6-
std::array<BaseStaticContext*, COMPILE_TIME_MAX_DEVICE_TYPES>&
7-
GetStaticContexts() {
8-
static std::array<BaseStaticContext*, COMPILE_TIME_MAX_DEVICE_TYPES>
9-
static_contexts;
6+
StaticContextMap& GetStaticContexts() {
7+
static StaticContextMap static_contexts;
108
return static_contexts;
119
}
1210

13-
void set_static_context(int d, BaseStaticContext* ptr) {
11+
void set_static_context(DeviceType t, BaseStaticContext* ptr) {
1412
auto& static_contexts = GetStaticContexts();
15-
static_contexts[d] = ptr;
13+
static_contexts[t] = ptr;
1614
}
1715

18-
BaseStaticContext* get_static_context(int d) {
19-
return GetStaticContexts()[d];
16+
BaseStaticContext* get_static_context(DeviceType t) {
17+
auto* ptr = GetStaticContexts()[t];
18+
CAFFE_ENFORCE(ptr, "StaticContext is not registered yet.");
19+
return ptr;
2020
}
2121

2222
} // namespace caffe2

caffe2/core/context_base.h

+9-10
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class CAFFE2_API BaseStaticContext {
3737
* current context and the a data pointer
3838
*/
3939
virtual void ExtractDeviceOption(DeviceOption* device, const void* /*data*/) {
40-
device->set_device_type(GetDeviceType());
40+
device->set_device_type(TypeToProto(GetDeviceType()));
4141
}
4242
};
4343

@@ -169,22 +169,21 @@ class CAFFE2_API BaseContext {
169169
}
170170
};
171171

172-
CAFFE2_API std::array<BaseStaticContext*, COMPILE_TIME_MAX_DEVICE_TYPES>&
173-
GetStaticContexts();
174-
CAFFE2_API void set_static_context(int d, BaseStaticContext* ptr);
175-
CAFFE2_API BaseStaticContext* get_static_context(int d);
172+
using StaticContextMap = CaffeMap<DeviceType, BaseStaticContext*>;
173+
CAFFE2_API StaticContextMap& GetStaticContexts();
174+
CAFFE2_API void set_static_context(DeviceType t, BaseStaticContext* ptr);
175+
CAFFE2_API BaseStaticContext* get_static_context(DeviceType t);
176176

177-
template <int d>
177+
template <DeviceType t>
178178
struct StaticContextFunctionRegisterer {
179179
explicit StaticContextFunctionRegisterer(BaseStaticContext* ptr) {
180-
static_assert(d < COMPILE_TIME_MAX_DEVICE_TYPES, "");
181-
set_static_context(d, ptr);
180+
set_static_context(t, ptr);
182181
}
183182
};
184183

185-
#define REGISTER_STATIC_CONTEXT(d, f) \
184+
#define REGISTER_STATIC_CONTEXT(t, f) \
186185
namespace { \
187-
static StaticContextFunctionRegisterer<d> g_static_context_##d(f); \
186+
static StaticContextFunctionRegisterer<t> g_static_context_##d(f); \
188187
}
189188

190189
} // namespace caffe2

caffe2/core/context_gpu.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ CUDAContext::CUDAContext(const DeviceOption& option)
257257
option.has_random_seed() ? option.random_seed()
258258
: RandomNumberSeed()) {
259259
static Caffe2CudaInitializerHelper g_cuda_initializer_;
260-
DCHECK_EQ(option.device_type(), CUDA);
260+
DCHECK_EQ(option.device_type(), PROTO_CUDA);
261261
}
262262

263263
// shared mutex to lock out alloc / free during NCCL launches

caffe2/core/context_gpu.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ class CAFFE2_API CUDAStaticContext final : public BaseStaticContext {
403403
}
404404

405405
void ExtractDeviceOption(DeviceOption* device, const void* data) override {
406-
device->set_device_type(GetDeviceType());
406+
device->set_device_type(TypeToProto(GetDeviceType()));
407407
device->set_cuda_gpu_id(GetGPUIDForPointer(data));
408408
}
409409

0 commit comments

Comments
 (0)