|
| 1 | +#include "../includes/gelu.cuh" |
| 2 | + |
| 3 | +#include "../includes/runtime.cuh" |
| 4 | +#include "../includes/utils.cuh" |
| 5 | + |
| 6 | +#include <cmath> |
| 7 | + |
| 8 | +namespace quadtrix { |
| 9 | +namespace cuda { |
| 10 | +namespace { |
| 11 | + |
// sqrt(1/2): scales x into the erf argument of the exact GELU derivative.
constexpr float kSqrtHalf{0.70710678118654752440f};
// sqrt(2/pi): outer factor of the tanh-approximation argument.
constexpr float kSqrtTwoOverPi{0.79788456080286535588f};
// Coefficient of the cubic term inside the tanh approximation.
constexpr float kGeluCoeff{0.044715f};
| 15 | + |
// Validates the tensor triple handed to gelu_backward.
//
// Returns true only when all three tensors:
//   * have a non-null data pointer,
//   * are CUDA tensors sharing one device id,
//   * are F32,
//   * have the same rank, are contiguous, and hold the same number of elements,
//   * agree on the extent of every dimension.
// Returns false as soon as any of these checks fails.
bool valid_backward_tensors(const TensorView& grad_output, const TensorView& input, const TensorView& grad_input) {
    // Backing storage must exist for every tensor.
    if (grad_output.data == nullptr || input.data == nullptr || grad_input.data == nullptr) {
        return false;
    }

    // All tensors must live on CUDA ...
    if (grad_output.device != DeviceKind::CUDA || input.device != DeviceKind::CUDA ||
        grad_input.device != DeviceKind::CUDA) {
        return false;
    }

    // ... and on the very same device.
    if (grad_output.device_id != input.device_id || grad_output.device_id != grad_input.device_id) {
        return false;
    }

    // The kernel reads and writes raw float buffers, so only F32 is accepted.
    if (grad_output.dtype != DType::F32 || input.dtype != DType::F32 || grad_input.dtype != DType::F32) {
        return false;
    }

    // Ranks must match before the per-dimension comparison below is meaningful.
    if (grad_output.shape.rank != input.shape.rank || grad_output.shape.rank != grad_input.shape.rank) {
        return false;
    }

    // Flat indexing in the kernel requires contiguous layouts.
    if (!grad_output.shape.is_contiguous() || !input.shape.is_contiguous() ||
        !grad_input.shape.is_contiguous()) {
        return false;
    }

    // Element counts must agree.
    if (grad_output.numel() != input.numel() || grad_output.numel() != grad_input.numel()) {
        return false;
    }

    // Finally compare the extents dimension by dimension.
    for (int axis = 0; axis < input.shape.rank; ++axis) {
        const auto extent = grad_output.shape.dims[axis];
        if (extent != input.shape.dims[axis] || extent != grad_input.shape.dims[axis]) {
            return false;
        }
    }
    return true;
}
| 34 | + |
// Derivative of GELU with respect to its input, evaluated at x.
//
// GeluMode::Exact:  d/dx [x * Phi(x)] = Phi(x) + x * phi(x), where Phi is the
//                   standard normal CDF (computed via erff) and phi its PDF.
// Any other mode:   derivative of 0.5 * x * (1 + tanh(u)) with
//                   u = sqrt(2/pi) * (x + kGeluCoeff * x^3).
__device__ __forceinline__ float gelu_grad(float x, GeluMode mode) {
    const bool use_exact = (mode == GeluMode::Exact);

    if (!use_exact) {
        const float x_sq = x * x;
        const float x_cu = x_sq * x;
        const float u = kSqrtTwoOverPi * (x + kGeluCoeff * x_cu);
        const float tanh_u = tanhf(u);
        // sech^2(u) = 1 - tanh^2(u) is the derivative of tanh(u).
        const float sech_sq = 1.0f - tanh_u * tanh_u;
        // du/dx
        const float du_dx = kSqrtTwoOverPi * (1.0f + 3.0f * kGeluCoeff * x_sq);
        return 0.5f * (1.0f + tanh_u) + 0.5f * x * sech_sq * du_dx;
    }

    // 1 / sqrt(2 * pi): normalisation of the standard normal density.
    constexpr float kInvSqrtTwoPi = 0.39894228040143267794f;
    const float normal_cdf = 0.5f * (1.0f + erff(x * kSqrtHalf));
    const float normal_pdf = kInvSqrtTwoPi * expf(-0.5f * x * x);
    return normal_cdf + x * normal_pdf;
}
| 50 | + |
// Element-wise GELU backward kernel.
//
// For every index i in [0, n):
//     grad_input[i] += grad_output[i] * gelu_grad(input[i], mode)
//
// Parameters:
//   grad_output  upstream gradient, n floats (read-only).
//   input        forward-pass input of GELU, n floats (read-only).
//   grad_input   destination gradient, n floats; the result is ACCUMULATED
//                into it (+=), so its prior contents contribute to the output.
//                NOTE(review): confirm callers initialise grad_input as intended.
//   n            number of elements in each buffer.
//   mode         selects the exact or tanh-approximated derivative.
//
// Launch layout: 1-D grid of 1-D blocks. The loop below strides by the total
// number of launched threads, so any grid size >= 1 processes all n elements;
// the result no longer depends on the grid being large enough to give each
// element its own thread.
//
// The pointers are declared __restrict__, so the three buffers must not alias.
__global__ void gelu_backward_kernel(
    const float* __restrict__ grad_output,
    const float* __restrict__ input,
    float* __restrict__ grad_input,
    std::size_t n,
    GeluMode mode) {
    // Widen before multiplying so the index math cannot overflow 32 bits.
    const std::size_t stride = static_cast<std::size_t>(gridDim.x) * blockDim.x;
    for (std::size_t idx = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < n;
         idx += stride) {
        grad_input[idx] += grad_output[idx] * gelu_grad(input[idx], mode);
    }
}
| 62 | + |
| 63 | +} // namespace |
| 64 | + |
// Launches the GELU backward pass on `stream`.
//
// Accumulates grad_output * gelu'(input) into grad_input, element by element.
//
// Parameters:
//   grad_output  upstream gradient.
//   input        tensor that was fed to the forward GELU.
//   grad_input   tensor receiving the accumulated gradient.
//   mode         exact or tanh-approximated GELU derivative.
//   stream       CUDA stream the kernel is enqueued on.
//
// Returns:
//   A failure Status carrying cudaErrorInvalidValue when the tensors do not
//   pass valid_backward_tensors; otherwise the Status produced from
//   cudaGetLastError() after the launch. The launch is asynchronous, so only
//   launch-time errors are reported here; faults during kernel execution
//   surface at a later synchronising call.
Status gelu_backward(
    const TensorView& grad_output,
    const TensorView& input,
    TensorView grad_input,
    GeluMode mode,
    cudaStream_t stream) {
    if (!valid_backward_tensors(grad_output, input, grad_input)) {
        return Status::failure(cudaErrorInvalidValue, "invalid gelu_backward arguments");
    }

    // NOTE(review): presumably makes input.device_id current for this scope — confirm in runtime.cuh.
    DeviceGuard guard(input.device_id);
    const std::size_t n = input.numel();

    // An empty tensor has nothing to compute. Skipping the launch avoids
    // relying on one_dim_grid(0) yielding a valid (non-zero) grid dimension,
    // which would otherwise make a no-op call fail with a configuration error.
    if (n > 0) {
        gelu_backward_kernel<<<one_dim_grid(n), kDefaultBlockSize, 0, stream>>>(
            grad_output.data_as<const float>(),
            input.data_as<const float>(),
            grad_input.data_as<float>(),
            n,
            mode);
    }
    return QUADTRIX_CUDA_CHECK(cudaGetLastError());
}
| 85 | + |
| 86 | +} // namespace cuda |
| 87 | +} // namespace quadtrix |
0 commit comments