Skip to content

Commit 7cc6c72

Browse files
committed
feat(cuda): implement backward pass for GELU activation kernels
1 parent 2cc9a9e commit 7cc6c72

1 file changed

Lines changed: 87 additions & 0 deletions

File tree

cuda/KERNAL/gelu_backward.cu

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include "../includes/gelu.cuh"
2+
3+
#include "../includes/runtime.cuh"
4+
#include "../includes/utils.cuh"
5+
6+
#include <cmath>
7+
8+
namespace quadtrix {
9+
namespace cuda {
10+
namespace {
11+
12+
// sqrt(1/2): multiplies x to form the erf argument in the exact-mode CDF.
constexpr float kSqrtHalf{0.70710678118654752440f};
// sqrt(2/pi): leading factor of the inner polynomial in the tanh approximation.
constexpr float kSqrtTwoOverPi{0.79788456080286535588f};
// Coefficient of the cubic term in the tanh approximation's inner polynomial.
constexpr float kGeluCoeff{0.044715f};
15+
16+
// Host-side precondition check for the GELU backward launch.
//
// Returns true only when all three views:
//   - carry a non-null data pointer,
//   - live on a CUDA device, and on the same device id,
//   - have dtype F32,
//   - are contiguous,
//   - agree in rank, element count, and every individual dimension.
// Any violation yields false; nothing is modified.
bool valid_backward_tensors(const TensorView& grad_output, const TensorView& input, const TensorView& grad_input) {
    // All buffers must be present.
    if (grad_output.data == nullptr || input.data == nullptr || grad_input.data == nullptr) {
        return false;
    }

    // All buffers must be CUDA-resident.
    if (grad_output.device != DeviceKind::CUDA || input.device != DeviceKind::CUDA ||
        grad_input.device != DeviceKind::CUDA) {
        return false;
    }

    // ...and on one and the same device.
    if (grad_output.device_id != input.device_id || grad_output.device_id != grad_input.device_id) {
        return false;
    }

    // The kernel reads and writes raw float buffers, so only F32 is accepted.
    if (grad_output.dtype != DType::F32 || input.dtype != DType::F32 || grad_input.dtype != DType::F32) {
        return false;
    }

    // Ranks must match before dimensions can be compared index by index.
    if (grad_output.shape.rank != input.shape.rank || grad_output.shape.rank != grad_input.shape.rank) {
        return false;
    }

    // The kernel indexes linearly, which requires contiguous storage.
    if (!grad_output.shape.is_contiguous() || !input.shape.is_contiguous() ||
        !grad_input.shape.is_contiguous()) {
        return false;
    }

    // Element counts must agree.
    if (grad_output.numel() != input.numel() || grad_output.numel() != grad_input.numel()) {
        return false;
    }

    // Finally, every axis must have the same extent in all three views.
    for (int axis = 0; axis < input.shape.rank; ++axis) {
        const bool matches_input = grad_output.shape.dims[axis] == input.shape.dims[axis];
        const bool matches_grad_input = grad_output.shape.dims[axis] == grad_input.shape.dims[axis];
        if (!(matches_input && matches_grad_input)) {
            return false;
        }
    }
    return true;
}
34+
35+
// Derivative of GELU with respect to its input, evaluated at x.
//
// mode == GeluMode::Exact:
//     d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
//     with Phi the standard normal CDF (via erff) and phi its density.
// any other mode (tanh approximation):
//     with u = sqrt(2/pi) * (x + c * x^3) and t = tanh(u),
//     d/dx [0.5 * x * (1 + t)] = 0.5 * (1 + t) + 0.5 * x * (1 - t^2) * du/dx.
__device__ __forceinline__ float gelu_grad(float x, GeluMode mode) {
    const bool use_exact = (mode == GeluMode::Exact);

    if (!use_exact) {
        const float sq = x * x;
        const float cube = sq * x;
        const float u = kSqrtTwoOverPi * (x + kGeluCoeff * cube);
        const float th = tanhf(u);
        // sech^2(u) expressed through tanh: 1 - tanh^2(u).
        const float one_minus_th2 = 1.0f - th * th;
        // du/dx of the inner polynomial.
        const float du_dx = kSqrtTwoOverPi * (1.0f + 3.0f * kGeluCoeff * sq);
        return 0.5f * (1.0f + th) + 0.5f * x * one_minus_th2 * du_dx;
    }

    // 1 / sqrt(2 * pi): normalization of the standard normal density.
    constexpr float kInvSqrtTwoPi = 0.39894228040143267794f;
    const float normal_cdf = 0.5f * (1.0f + erff(x * kSqrtHalf));
    const float normal_pdf = kInvSqrtTwoPi * expf(-0.5f * x * x);
    return normal_cdf + x * normal_pdf;
}
50+
51+
// Elementwise GELU backward pass over n floats.
//
// For every index i in [0, n):
//     grad_input[i] += grad_output[i] * gelu_grad(input[i], mode)
//
// Parameters:
//   grad_output  upstream gradient, read-only, n elements.
//   input        forward-pass input to GELU, read-only, n elements.
//   grad_input   destination gradient, n elements. It is ACCUMULATED into
//                (+=), not overwritten, so its prior contents contribute to
//                the result.
//                NOTE(review): confirm callers zero or pre-populate this
//                buffer on purpose; an uninitialized buffer yields garbage.
//   n            number of elements in each buffer.
//   mode         which GELU derivative formula gelu_grad evaluates.
//
// Launch layout: 1-D grid of 1-D blocks. The loop strides by the total number
// of launched threads, so any grid size produces a complete result, including
// a grid that has fewer threads than n. No shared memory is used.
// The pointers are declared __restrict__, so the buffers must not alias.
__global__ void gelu_backward_kernel(
    const float* __restrict__ grad_output,
    const float* __restrict__ input,
    float* __restrict__ grad_input,
    std::size_t n,
    GeluMode mode) {
    // Widen before multiplying so the index math cannot wrap in 32 bits.
    const std::size_t stride = static_cast<std::size_t>(gridDim.x) * blockDim.x;
    const std::size_t first = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (std::size_t idx = first; idx < n; idx += stride) {
        grad_input[idx] += grad_output[idx] * gelu_grad(input[idx], mode);
    }
}
62+
63+
} // namespace
64+
65+
// Enqueues the GELU backward kernel on `stream`.
//
// Parameters:
//   grad_output  upstream gradient.
//   input        forward-pass input to GELU.
//   grad_input   destination gradient; the kernel accumulates into it (+=).
//   mode         which GELU derivative formula to use.
//   stream       CUDA stream the kernel is launched on.
//
// The three views must pass valid_backward_tensors (non-null, CUDA, same
// device, F32, contiguous, identical shape); otherwise a failure Status with
// cudaErrorInvalidValue is returned and nothing is launched.
//
// Returns the Status produced from cudaGetLastError() after the launch. This
// function does not synchronize, so the kernel may still be pending on
// `stream` when it returns.
Status gelu_backward(
    const TensorView& grad_output,
    const TensorView& input,
    TensorView grad_input,
    GeluMode mode,
    cudaStream_t stream) {
    if (!valid_backward_tensors(grad_output, input, grad_input)) {
        return Status::failure(cudaErrorInvalidValue, "invalid gelu_backward arguments");
    }

    // Presumably makes input.device_id the current device for the scope of
    // this call — DeviceGuard is defined outside this file; verify there.
    DeviceGuard guard(input.device_id);
    const std::size_t n = input.numel();

    // An empty tensor has nothing to compute. Skip the launch so a grid with
    // zero blocks is never requested (the runtime rejects such a launch).
    // NOTE(review): one_dim_grid's result for n == 0 is not visible here.
    if (n > 0) {
        gelu_backward_kernel<<<one_dim_grid(n), kDefaultBlockSize, 0, stream>>>(
            grad_output.data_as<const float>(),
            input.data_as<const float>(),
            grad_input.data_as<float>(),
            n,
            mode);
    }
    return QUADTRIX_CUDA_CHECK(cudaGetLastError());
}
85+
86+
} // namespace cuda
87+
} // namespace quadtrix

0 commit comments

Comments
 (0)