@@ -57,6 +57,10 @@ namespace at::native {

#include <ATen/TensorUtils.h>

+#include <c10/hip/HIPCachingAllocator.h>
+
+#include <rocrand/rocrand_xorwow.h>
+
#include <functional>
#include <iterator>
#include <sstream>
@@ -66,12 +70,35 @@ namespace at::native {
#include <stdint.h>
#include <unordered_map>

-namespace at { namespace native {
+namespace at::native {
+
+namespace {
+
+struct DropoutState {
+  DropoutState(size_t size) : size(size), data(NULL) {
+    data = c10::hip::HIPCachingAllocator::raw_alloc(size);
+  }
+  DropoutState(const DropoutState&) = delete;
+  DropoutState(DropoutState&&) = default;
+  DropoutState& operator=(DropoutState&&) = default;
+  ~DropoutState() {
+    if (data) {
+      c10::hip::HIPCachingAllocator::raw_delete(data);
+    }
+  }
+
+  size_t size;
+  void* data;
+};
+
+} // anonymous

// RNNDescriptor.
struct RNNDescriptorParams {
  int64_t hidden_size;
  int64_t num_layers;
+  double dropout_rate;
+  uint64_t dropout_seed;
  miopenRNNDirectionMode_t direction;
  miopenRNNMode_t rnn_mode;
  miopenDataType_t datatype;
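
The DropoutState helper added above is a small RAII wrapper over the HIP caching allocator's raw interface: the constructor acquires the buffer with raw_alloc and the destructor returns it with raw_delete. A minimal sketch of the same ownership pattern, with std::malloc/std::free standing in for the allocator calls so it builds without a ROCm toolchain, could look like this (RawBuffer is a hypothetical name, not part of the change):

#include <cstdlib>
#include <cstddef>

// Hypothetical stand-in for DropoutState: std::malloc/std::free replace
// c10::hip::HIPCachingAllocator::raw_alloc/raw_delete used in the diff.
struct RawBuffer {
  explicit RawBuffer(std::size_t size) : size(size), data(std::malloc(size)) {}
  RawBuffer(const RawBuffer&) = delete;             // single owner, like DropoutState
  RawBuffer& operator=(const RawBuffer&) = delete;
  ~RawBuffer() {
    if (data) {
      std::free(data);                              // release on destruction
    }
  }

  std::size_t size;
  void* data;
};

The sketch deliberately leaves out the defaulted move operations: a defaulted move of a raw pointer copies it without nulling the source, and the diff never exercises them anyway because dropout_states is only ever held through a thread-local std::unique_ptr.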
@@ -114,6 +141,12 @@ struct RNNDescriptorParams {
    }
  }

+  void set_dropout(double dropout_rate, uint64_t dropout_seed = 0) {
+    this->dropout_rate = dropout_rate;
+    // TODO: Implement seed setting for RNN dropout
+    this->dropout_seed = dropout_seed;
+  }
+
  void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
    this->set_mode(mode);
    this->hidden_size = hidden_size;
@@ -128,12 +161,18 @@ struct RNNDescriptorParams {
    rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
    return rnn_desc;
  }
+
+  RNNDescriptor descriptorWithDropout(DropoutDescriptor& dropout_desc) const {
+    RNNDescriptor rnn_desc;
+    rnn_desc.setWithDropout(dropout_desc, hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
+    return rnn_desc;
+  }
};

// TensorDescriptor list.
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
  std::vector<TensorDescriptor> descriptors(batch_sizes.size());
-  size_t i =0;
+  size_t i = 0;

  auto batch_tensor_size = tensor.sizes().vec();
  for (auto batch_size : batch_sizes) {
@@ -204,6 +243,8 @@ struct RNNParams {

struct RNNDescriptors {
  RNNDescriptor rnn_desc;
+  static thread_local DropoutDescriptor dropout_desc;
+  static thread_local std::unique_ptr<DropoutState> dropout_states;
  std::vector<TensorDescriptor> x_descs;
  std::vector<TensorDescriptor> y_descs;
  TensorDescriptor hx_desc;
@@ -212,7 +253,39 @@ struct RNNDescriptors {
  TensorDescriptor cy_desc;

  RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
-    rnn_desc = fn.rnn.descriptor();
+    if (fn.rnn.dropout_rate == 0.0) {
+      rnn_desc = fn.rnn.descriptor();
+    } else {
+      if (!dropout_states) {
+        size_t states_size_in_bytes = 0;
+        MIOPEN_CHECK(miopenDropoutGetStatesSize(handle, &states_size_in_bytes));
+        size_t states_size = states_size_in_bytes / sizeof(rocrand_state_xorwow);
+
+        dropout_states = std::make_unique<DropoutState>(states_size * sizeof(rocrand_state_xorwow));
+
+        dropout_desc.set(handle,
+                         fn.rnn.dropout_rate,
+                         dropout_states->data,
+                         dropout_states->size,
+                         fn.rnn.dropout_seed,
+                         false,
+                         false,
+                         miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
+      } else {
+        dropout_desc.restore(handle,
+                             fn.rnn.dropout_rate,
+                             dropout_states->data,
+                             dropout_states->size,
+                             fn.rnn.dropout_seed,
+                             // use_mask flag must be true in order to continue from a saved RNG state
+                             true,
+                             false,
+                             miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
+      }
+
+      rnn_desc = fn.rnn.descriptorWithDropout(dropout_desc);
+    }
+
    x_descs = fn.tensors.descriptors(x);
    y_descs = fn.tensors.descriptors(y);
    hx_desc.set(hx, 5);
@@ -239,6 +312,11 @@ struct RNNDescriptors {
  }
};

+// We need to store both the dropout descriptor and state thread locally to avoid multithreading issues
+thread_local DropoutDescriptor RNNDescriptors::dropout_desc{};
+// Each state is 0.75 MB so there is no problem in caching all of them for each thread
+thread_local std::unique_ptr<DropoutState> RNNDescriptors::dropout_states{nullptr};
+
Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
{
  if (mode < 2)
@@ -492,7 +570,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
  auto handle = getMiopenHandle();
  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
-
+  fn.rnn.set_dropout(fn_dropout);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
@@ -551,7 +629,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
  }

  return std::make_tuple(output, hy, cy, reserve, weight_buf);
-
}

std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
@@ -626,6 +703,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(

  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
+  fn.rnn.set_dropout(fn_dropout);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
@@ -720,6 +798,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(

  miopenRNNAlgo_t algo = miopenRNNdefault;
  fn.rnn.set_algo(algo);
+  fn.rnn.set_dropout(fn_dropout);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
@@ -909,6 +988,6 @@ REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen)

} // anonymous namespace
-}} // namespace native.
+} // namespace at::native

#endif
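
Taken together, the descriptor changes implement a simple per-thread caching scheme: the first RNNDescriptors built on a thread with a non-zero dropout rate allocates the rocRAND XORWOW state buffer and initializes it through dropout_desc.set, and every later construction on that thread reuses the cached buffer through dropout_desc.restore with use_mask set to true, so the RNG sequence continues rather than restarting. The sketch below models only that control flow in plain C++; FakeDropoutState and build_descriptor are hypothetical stand-ins, not MIOpen or ATen APIs.

#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical model of the per-thread dropout state cached by the diff.
struct FakeDropoutState {
  explicit FakeDropoutState(std::size_t n) : buffer(n) {}
  std::vector<unsigned char> buffer;
};

// Mirrors RNNDescriptors::dropout_states: one cached state per thread.
thread_local std::unique_ptr<FakeDropoutState> cached_state;

void build_descriptor(double dropout_rate) {
  if (dropout_rate == 0.0) {
    std::puts("dropout disabled: plain descriptor()");        // fn.rnn.descriptor() path
    return;
  }
  if (!cached_state) {
    cached_state = std::make_unique<FakeDropoutState>(1024);  // first use on this thread
    std::puts("allocate state, dropout_desc.set(...)");
  } else {
    std::puts("reuse state, dropout_desc.restore(..., use_mask=true)");
  }
}

int main() {
  build_descriptor(0.5);  // allocates and initializes the thread-local state
  build_descriptor(0.5);  // restores from the cached state
  build_descriptor(0.0);  // dropout disabled, no state touched
}

Because the state is thread-local, concurrent construction on different threads never shares a buffer, which is the multithreading concern the comment above RNNDescriptors::dropout_desc calls out.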