Skip to content

Commit 65cbb82

Browse files
smessmerfacebook-github-bot
authored andcommitted
IValue can store Blob (pytorch#11414)
Summary: Pull Request resolved: pytorch#11414 caffe2::Blob can be stored in an IValue. This is a precondition for caffe2 to switch from Blob to IValue. Reviewed By: ezyang Differential Revision: D9731326 fbshipit-source-id: 462a39d2d9ab6f85b99b1670848c6976a3de417c
1 parent b7ebc00 commit 65cbb82

File tree

4 files changed

+36
-7
lines changed

4 files changed

+36
-7
lines changed

aten/src/ATen/core/blob.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <typeinfo>
77
#include <vector>
88

9+
#include <ATen/core/intrusive_ptr.h>
910
#include <ATen/core/typeid.h>
1011
#include <c10/macros/Macros.h>
1112

@@ -20,7 +21,7 @@ class Tensor;
2021
* properly when the blob is deallocated or re-allocated with a new type. A blob
2122
* could contain anything, although the most common case is to contain a Tensor.
2223
*/
23-
class CAFFE2_API Blob final {
24+
class CAFFE2_API Blob final : public c10::intrusive_ptr_target {
2425
public:
2526
using DestroyCall = void(void*);
2627

@@ -232,4 +233,8 @@ inline void swap(Blob& lhs, Blob& rhs) {
232233
lhs.swap(rhs);
233234
}
234235

236+
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
237+
return out << "Blob[" << v.TypeName() << "]";
238+
}
239+
235240
} // namespace caffe2

aten/src/ATen/core/ivalue.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include <ATen/core/ivalue.h>
22
#include <ATen/core/Formatting.h>
33

4-
#define TORCH_FORALL_TAGS(_) \
5-
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)
4+
#define TORCH_FORALL_TAGS(_) \
5+
_(None) \
6+
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
7+
_(TensorList) _(Blob)
68

79
namespace torch { namespace jit {
810

aten/src/ATen/core/ivalue.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <ATen/core/Tensor.h>
55
#include <ATen/core/TensorImpl.h>
66
#include <ATen/core/UndefinedTensorImpl.h>
7+
#include <ATen/core/blob.h>
78
#include <ATen/core/intrusive_ptr.h>
89

910
#include <type_traits>
@@ -64,8 +65,10 @@ using DoubleList = ConstantList<double>;
6465
// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
6566
// retain/release calls.
6667

67-
#define TORCH_FORALL_TAGS(_) \
68-
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)
68+
#define TORCH_FORALL_TAGS(_) \
69+
_(None) \
70+
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
71+
_(TensorList) _(Blob)
6972

7073
struct CAFFE2_API IValue final {
7174
IValue()
@@ -125,6 +128,25 @@ struct CAFFE2_API IValue final {
125128
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
126129
}
127130

131+
IValue(caffe2::Blob blob) : tag(Tag::Blob), is_intrusive_ptr(true) {
132+
// TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
133+
// and
134+
// store it as a Tensor instead.
135+
payload.as_intrusive_ptr =
136+
c10::make_intrusive<caffe2::Blob>(std::move(blob)).release();
137+
}
138+
bool isBlob() const {
139+
return Tag::Blob == tag;
140+
}
141+
caffe2::Blob& toBlob() & {
142+
AT_ASSERT(isBlob());
143+
return *static_cast<caffe2::Blob*>(payload.as_intrusive_ptr);
144+
}
145+
const caffe2::Blob& toBlob() const& {
146+
AT_ASSERT(isBlob());
147+
return *static_cast<caffe2::Blob*>(payload.as_intrusive_ptr);
148+
}
149+
128150
// Tuple
129151
IValue(c10::intrusive_ptr<Tuple> v);
130152
bool isTuple() const { return Tag::Tuple == tag; }

binaries/tutorial_blob.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ int main(int argc, char** argv) {
4747
LOG(INFO)
4848
<< "Is the blob type float? "
4949
<< myblob.IsType<float>();
50-
50+
5151
const int& myint_const = myblob.Get<int>();
5252
LOG(INFO)
5353
<< "The value of the int number stored in the blob is: "
@@ -80,7 +80,7 @@ int main(int argc, char** argv) {
8080

8181
std::string* pvec = new std::string();
8282
myblob.Reset(pvec); // no need to release pvec, myblob takes ownership.
83-
83+
8484
LOG(INFO) << "Is the blob now of type string? "
8585
<< myblob.IsType<std::string>();
8686

0 commit comments

Comments
 (0)