diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h
index e32adf55c551..7729545f02b9 100644
--- a/aten/src/ATen/miopen/Descriptors.h
+++ b/aten/src/ATen/miopen/Descriptors.h
@@ -121,6 +121,18 @@ struct ConvolutionDescriptor
   }
 };
 
+struct DropoutDescriptor
+  : public Descriptor<miopenDropoutDescriptor, &miopenCreateDropoutDescriptor, &miopenDestroyDropoutDescriptor>
+{
+  void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
+           unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
+    MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
+  }
+
+  void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
+               unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
+    MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
+  }
+};
+
 struct RNNDescriptor
   : public Descriptor<miopenRNNDescriptor, &miopenCreateRNNDescriptor, &miopenDestroyRNNDescriptor>
 {
   void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
-           miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
+      miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
     MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
   }
+
+  void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction,
+                      miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
+    MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
+  }
 };
 
 union Constant
diff --git a/aten/src/ATen/native/miopen/RNN_miopen.cpp b/aten/src/ATen/native/miopen/RNN_miopen.cpp
index 86ef2fb707d5..52c472246aa0 100644
--- a/aten/src/ATen/native/miopen/RNN_miopen.cpp
+++ b/aten/src/ATen/native/miopen/RNN_miopen.cpp
@@ -57,6 +57,10 @@ namespace at { namespace native {
 
 #include <...>
+#include <c10/hip/HIPCachingAllocator.h>
+
+#include <rocrand/rocrand_xorwow.h>
+
 #include <...>
 #include <...>
 #include <...>
 
@@ -66,12 +70,35 @@ namespace at { namespace native {
 #include <...>
 #include <...>
 
-namespace at { namespace native {
+namespace at::native {
+
+namespace {
+
+  struct DropoutState {
+    DropoutState(size_t size) : size(size), data(NULL) {
+      data = c10::hip::HIPCachingAllocator::raw_alloc(size);
+    }
+    DropoutState(const DropoutState&) = delete;
+    DropoutState(DropoutState&&) = default;
+    DropoutState& operator=(DropoutState&&) = default;
+    ~DropoutState() {
+      if (data) {
+        c10::hip::HIPCachingAllocator::raw_delete(data);
+      }
+    }
+
+    size_t size;
+    void* data;
+  };
+
+} // anonymous namespace
 
 //RNNDescriptor.
 struct RNNDescriptorParams {
   int64_t hidden_size;
   int64_t num_layers;
+  double dropout_rate;
+  uint64_t dropout_seed;
   miopenRNNDirectionMode_t direction;
   miopenRNNMode_t rnn_mode;
   miopenDataType_t datatype;
@@ -114,6 +141,12 @@ struct RNNDescriptorParams {
     }
   }
 
+  void set_dropout(double dropout_rate, uint64_t dropout_seed = 0) {
+    this->dropout_rate = dropout_rate;
+    // TODO: Implement seed setting for RNN dropout
+    this->dropout_seed = dropout_seed;
+  }
+
   void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
     this->set_mode(mode);
     this->hidden_size = hidden_size;
@@ -128,12 +161,18 @@ struct RNNDescriptorParams {
     rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
     return rnn_desc;
   }
+
+  RNNDescriptor descriptorWithDropout(DropoutDescriptor& dropout_desc) const {
+    RNNDescriptor rnn_desc;
+    rnn_desc.setWithDropout(dropout_desc, hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
+    return rnn_desc;
+  }
 };
 
 //TensorDescriptor list.
 std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
   std::vector<TensorDescriptor> descriptors(batch_sizes.size());
-  size_t i =0;
+  size_t i = 0;
   auto batch_tensor_size = tensor.sizes().vec();
 
   for (auto batch_size : batch_sizes) {
@@ -204,6 +243,8 @@ struct RNNParams {
 
 struct RNNDescriptors {
   RNNDescriptor rnn_desc;
+  static thread_local DropoutDescriptor dropout_desc;
+  static thread_local std::unique_ptr<DropoutState> dropout_states;
   std::vector<TensorDescriptor> x_descs;
   std::vector<TensorDescriptor> y_descs;
   TensorDescriptor hx_desc;
@@ -212,7 +253,39 @@ struct RNNDescriptors {
   TensorDescriptor cy_desc;
 
   RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
-    rnn_desc = fn.rnn.descriptor();
+    if (fn.rnn.dropout_rate == 0.0) {
+      rnn_desc = fn.rnn.descriptor();
+    } else {
+      if (!dropout_states) {
+        size_t states_size_in_bytes = 0;
+        MIOPEN_CHECK(miopenDropoutGetStatesSize(handle, &states_size_in_bytes));
+        size_t states_size = states_size_in_bytes / sizeof(rocrand_state_xorwow);
+
+        dropout_states = std::make_unique<DropoutState>(states_size * sizeof(rocrand_state_xorwow));
+
+        dropout_desc.set(handle,
+                         fn.rnn.dropout_rate,
+                         dropout_states->data,
+                         dropout_states->size,
+                         fn.rnn.dropout_seed,
+                         false,
+                         false,
+                         miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
+      } else {
+        dropout_desc.restore(handle,
+                             fn.rnn.dropout_rate,
+                             dropout_states->data,
+                             dropout_states->size,
+                             fn.rnn.dropout_seed,
+                             // use_mask flag must be true in order to continue from a saved RNG state
+                             true,
+                             false,
+                             miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
+      }
+
+      rnn_desc = fn.rnn.descriptorWithDropout(dropout_desc);
+    }
+
     x_descs = fn.tensors.descriptors(x);
     y_descs = fn.tensors.descriptors(y);
     hx_desc.set(hx, 5);
@@ -239,6 +312,11 @@ struct RNNDescriptors {
   }
 };
 
+// Store both the dropout descriptor and its state buffer thread-locally to avoid multithreading issues.
+thread_local DropoutDescriptor RNNDescriptors::dropout_desc {};
+// Each states buffer is 0.75 MB, so caching one per thread is not a problem.
+thread_local std::unique_ptr<DropoutState> RNNDescriptors::dropout_states { nullptr };
+
 Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
 {
     if (mode < 2)
@@ -492,7 +570,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
     auto handle = getMiopenHandle();
     miopenRNNAlgo_t algo = miopenRNNdefault;
     fn.rnn.set_algo(algo);
-
+    fn.rnn.set_dropout(fn_dropout);
     RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
     FilterDescriptor w_desc;
@@ -551,7 +629,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
     }
 
     return std::make_tuple(output, hy, cy, reserve, weight_buf);
-
 }
 
 std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
@@ -626,6 +703,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
 
     miopenRNNAlgo_t algo = miopenRNNdefault;
     fn.rnn.set_algo(algo);
+    fn.rnn.set_dropout(fn_dropout);
     RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
     FilterDescriptor w_desc;
@@ -720,6 +798,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(
 
     miopenRNNAlgo_t algo = miopenRNNdefault;
     fn.rnn.set_algo(algo);
+    fn.rnn.set_dropout(fn_dropout);
     RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
     FilterDescriptor w_desc;
@@ -909,6 +988,6 @@ REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen);
 REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen);
 } // anonymous namespace
 
-}} //namespace native.
+} // namespace at::native
 
 #endif
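
Reviewer note (not part of the patch): a minimal sketch of how the new dropout path is expected to be exercised. It assumes a ROCm build of PyTorch, where torch.cuda dispatches to HIP/MIOpen and torch.version.hip is set; the layer sizes, sequence length, and 0.5 dropout rate are illustrative.

import torch

# Only meaningful on a ROCm build, where torch.cuda is backed by HIP/MIOpen.
if torch.cuda.is_available() and torch.version.hip is not None:
    # A non-zero dropout rate makes RNNDescriptors take the new
    # descriptorWithDropout() path instead of the plain descriptor() path.
    rnn = torch.nn.LSTM(input_size=16, hidden_size=32, num_layers=2, dropout=0.5).cuda()
    rnn.train()  # dropout is only applied in training mode
    x = torch.randn(8, 4, 16, device="cuda")  # (seq_len, batch, input_size)
    out, (hy, cy) = rnn(x)
    out.sum().backward()  # also exercises miopen_rnn_backward_input / _weight

With dropout left at 0.0 the descriptors fall back to the pre-existing descriptor() path, so existing callers are unaffected; repeated non-zero-dropout calls on the same thread reuse the cached thread-local dropout state via restore(), so the state allocation cost is paid once per thread.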