@@ -57,6 +57,10 @@ namespace at::native {
5757
5858#include < ATen/TensorUtils.h>
5959
60+ #include < c10/hip/HIPCachingAllocator.h>
61+
62+ #include < rocrand/rocrand_xorwow.h>
63+
6064#include < functional>
6165#include < iterator>
6266#include < sstream>
@@ -66,12 +70,35 @@ namespace at::native {
6670#include < stdint.h>
6771#include < unordered_map>
6872
69- namespace at { namespace native {
73+ namespace at ::native {
74+
75+ namespace {
76+
77+ struct DropoutState {
78+ DropoutState (size_t size) : size(size), data(NULL ) {
79+ data = c10::hip::HIPCachingAllocator::raw_alloc (size);
80+ }
81+ DropoutState (const DropoutState&) = delete ;
82+ DropoutState (DropoutState&&) = default ;
83+ DropoutState& operator =(DropoutState&&) = default ;
84+ ~DropoutState () {
85+ if (data) {
86+ c10::hip::HIPCachingAllocator::raw_delete (data);
87+ }
88+ }
89+
90+ size_t size;
91+ void * data;
92+ };
93+
94+ } // anonymous
7095
7196// RNNDescriptor.
7297struct RNNDescriptorParams {
7398 int64_t hidden_size;
7499 int64_t num_layers;
100+ double dropout_rate;
101+ uint64_t dropout_seed;
75102 miopenRNNDirectionMode_t direction;
76103 miopenRNNMode_t rnn_mode;
77104 miopenDataType_t datatype;
@@ -114,6 +141,12 @@ struct RNNDescriptorParams {
114141 }
115142 }
116143
144+ void set_dropout (double dropout_rate, uint64_t dropout_seed = 0 ) {
145+ this ->dropout_rate = dropout_rate;
146+ // TODO: Implement seed setting for RNN dropout
147+ this ->dropout_seed = dropout_seed;
148+ }
149+
117150 void set (int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
118151 this ->set_mode (mode);
119152 this ->hidden_size = hidden_size;
@@ -128,12 +161,18 @@ struct RNNDescriptorParams {
128161 rnn_desc.set (hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
129162 return rnn_desc;
130163 }
164+
165+ RNNDescriptor descriptorWithDropout (DropoutDescriptor& dropout_desc) const {
166+ RNNDescriptor rnn_desc;
167+ rnn_desc.setWithDropout (dropout_desc, hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
168+ return rnn_desc;
169+ }
131170};
132171
133172// TensorDescriptor list.
134173std::vector<TensorDescriptor> rnn_descriptor_sequence (const Tensor& tensor, IntArrayRef batch_sizes) {
135174 std::vector<TensorDescriptor> descriptors (batch_sizes.size ());
136- size_t i =0 ;
175+ size_t i = 0 ;
137176
138177 auto batch_tensor_size = tensor.sizes ().vec ();
139178 for (auto batch_size : batch_sizes) {
@@ -204,6 +243,8 @@ struct RNNParams {
204243
205244struct RNNDescriptors {
206245 RNNDescriptor rnn_desc;
246+ static thread_local DropoutDescriptor dropout_desc;
247+ static thread_local std::unique_ptr<DropoutState> dropout_states;
207248 std::vector<TensorDescriptor> x_descs;
208249 std::vector<TensorDescriptor> y_descs;
209250 TensorDescriptor hx_desc;
@@ -212,7 +253,39 @@ struct RNNDescriptors {
212253 TensorDescriptor cy_desc;
213254
214255 RNNDescriptors (const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
215- rnn_desc = fn.rnn .descriptor ();
256+ if (fn.rnn .dropout_rate == 0.0 ) {
257+ rnn_desc = fn.rnn .descriptor ();
258+ } else {
259+ if (!dropout_states) {
260+ size_t states_size_in_bytes = 0 ;
261+ MIOPEN_CHECK (miopenDropoutGetStatesSize (handle, &states_size_in_bytes));
262+ size_t states_size = states_size_in_bytes / sizeof (rocrand_state_xorwow);
263+
264+ dropout_states = std::make_unique<DropoutState>(states_size * sizeof (rocrand_state_xorwow));
265+
266+ dropout_desc.set (handle,
267+ fn.rnn .dropout_rate ,
268+ dropout_states->data ,
269+ dropout_states->size ,
270+ fn.rnn .dropout_seed ,
271+ false ,
272+ false ,
273+ miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
274+ } else {
275+ dropout_desc.restore (handle,
276+ fn.rnn .dropout_rate ,
277+ dropout_states->data ,
278+ dropout_states->size ,
279+ fn.rnn .dropout_seed ,
280+ // use_mask flag must be true in order to continue from a saved RNG state
281+ true ,
282+ false ,
283+ miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
284+ }
285+
286+ rnn_desc = fn.rnn .descriptorWithDropout (dropout_desc);
287+ }
288+
216289 x_descs = fn.tensors .descriptors (x);
217290 y_descs = fn.tensors .descriptors (y);
218291 hx_desc.set (hx, 5 );
@@ -239,6 +312,11 @@ struct RNNDescriptors {
239312 }
240313};
241314
// Both the dropout descriptor and its RNG state buffer are cached per-thread
// (thread_local) to avoid multithreading issues: MIOpen mutates the state
// buffer during execution, so sharing one across threads would race.
thread_local DropoutDescriptor RNNDescriptors::dropout_desc {};
// Each state is 0.75 MB so there is no problem in caching all of them for each thread.
// Lazily allocated on first use of a nonzero dropout rate; null until then.
thread_local std::unique_ptr<DropoutState> RNNDescriptors::dropout_states { nullptr };
319+
242320Tensor permute_wei_for_miopen (Tensor wei, int64_t mode)
243321{
244322 if (mode < 2 )
@@ -492,7 +570,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
492570 auto handle = getMiopenHandle ();
493571 miopenRNNAlgo_t algo = miopenRNNdefault;
494572 fn.rnn .set_algo (algo);
495-
573+ fn. rnn . set_dropout (fn_dropout);
496574 RNNDescriptors descs (fn, handle, x, y, hx, cx);
497575
498576 FilterDescriptor w_desc;
@@ -551,7 +629,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
551629 }
552630
553631 return std::make_tuple (output, hy, cy, reserve, weight_buf);
554-
555632}
556633
557634std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input (
@@ -626,6 +703,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
626703
627704 miopenRNNAlgo_t algo = miopenRNNdefault;
628705 fn.rnn .set_algo (algo);
706+ fn.rnn .set_dropout (fn_dropout);
629707 RNNDescriptors descs (fn, handle, x, y, hx, cx);
630708
631709 FilterDescriptor w_desc;
@@ -720,6 +798,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(
720798
721799 miopenRNNAlgo_t algo = miopenRNNdefault;
722800 fn.rnn .set_algo (algo);
801+ fn.rnn .set_dropout (fn_dropout);
723802 RNNDescriptors descs (fn, handle, x, y, hx, cx);
724803
725804 FilterDescriptor w_desc;
@@ -909,6 +988,6 @@ REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
909988REGISTER_CUDA_DISPATCH (lstm_packed_miopen_stub, &lstm_packed_miopen)
910989
911990} // anonymous namespace
912- }} // namespace native.
991+ } // namespace at:: native
913992
914993#endif
0 commit comments