@@ -55,7 +55,7 @@ class CuDNNDropoutOp final : public Operator<CUDAContext> {
55
55
cudnnTensorDescriptor_t data_desc_;
56
56
cudnnDropoutDescriptor_t dropout_desc_;
57
57
58
- vector< int64_t > cudnn_input_dims_;
58
+ at::IntList cudnn_input_dims_;
59
59
60
60
float ratio_;
61
61
bool is_test_;
@@ -113,7 +113,7 @@ class CuDNNDropoutGradientOp final : public Operator<CUDAContext> {
113
113
cudnnTensorDescriptor_t data_desc_;
114
114
cudnnDropoutDescriptor_t dropout_desc_;
115
115
116
- vector< int64_t > cudnn_input_dims_;
116
+ at::IntList cudnn_input_dims_;
117
117
118
118
Blob* scratch_blob_;
119
119
@@ -146,12 +146,11 @@ bool CuDNNDropoutOp::DoRunWithType() {
146
146
}
147
147
return true ;
148
148
} else {
149
- auto * mask = Output (1 );
150
149
// Reshape tensor descriptors if necessary
151
- if (X.sizes () != cudnn_input_dims_ && !is_test_ ) {
150
+ if (X.sizes () != cudnn_input_dims_) {
152
151
CAFFE_ENFORCE (scratch_blob_);
153
152
Tensor* states = BlobGetMutableTensor (scratch_blob_, CUDA);
154
- cudnn_input_dims_ = X.sizes (). vec () ;
153
+ cudnn_input_dims_ = X.sizes ();
155
154
CUDNN_ENFORCE (cudnnSetTensor4dDescriptor (
156
155
data_desc_,
157
156
GetCudnnTensorFormat (StorageOrder::NCHW),
@@ -165,7 +164,6 @@ bool CuDNNDropoutOp::DoRunWithType() {
165
164
CUDNN_ENFORCE (cudnnDropoutGetReserveSpaceSize (
166
165
data_desc_, &reserve_space_size_in_bytes_));
167
166
168
- mask->Resize (reserve_space_size_in_bytes_);
169
167
states->Resize (states_size_in_bytes_);
170
168
171
169
if (!states_initialized_) {
@@ -187,6 +185,10 @@ bool CuDNNDropoutOp::DoRunWithType() {
187
185
states_initialized_ = true ;
188
186
}
189
187
}
188
+ auto * mask = Output (
189
+ 1 ,
190
+ {static_cast <int64_t >(reserve_space_size_in_bytes_)},
191
+ at::dtype<uint8_t >());
190
192
CUDNN_ENFORCE (cudnnDropoutForward (
191
193
cudnn_wrapper_.inline_cudnn_handle (),
192
194
dropout_desc_,
@@ -244,7 +246,7 @@ bool CuDNNDropoutGradientOp::DoRunWithType() {
244
246
}
245
247
246
248
if (dY.sizes () != cudnn_input_dims_) {
247
- cudnn_input_dims_ = dY.sizes (). vec () ;
249
+ cudnn_input_dims_ = dY.sizes ();
248
250
CUDNN_ENFORCE (cudnnSetTensor4dDescriptor (
249
251
data_desc_,
250
252
GetCudnnTensorFormat (StorageOrder::NCHW),
0 commit comments