Skip to content

Commit e1e4d51

Browse files
committed
feat(cuda): implement softmax and causal softmax forward kernels
Implements softmax_forward and causal_softmax_forward using CUDA kernels.
1 parent 27fddb7 commit e1e4d51

1 file changed

Lines changed: 187 additions & 0 deletions

File tree

cuda/KERNAL/softmax_forward.cu

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#include "../includes/softmax.cuh"
2+
3+
#include "../includes/runtime.cuh"
4+
#include "../includes/utils.cuh"
5+
6+
#include <cmath>
7+
#include <limits>
8+
9+
namespace quadtrix {
10+
namespace cuda {
11+
namespace {
12+
13+
constexpr int kSoftmaxBlockSize = 256;
14+
15+
bool fits_int(std::int64_t value) {
16+
return value > 0 && value <= std::numeric_limits<int>::max();
17+
}
18+
19+
bool valid_same_shape_f32(const TensorView& a, const TensorView& b) {
20+
if (a.data == nullptr || b.data == nullptr || a.device != DeviceKind::CUDA || b.device != DeviceKind::CUDA ||
21+
a.device_id != b.device_id || a.dtype != DType::F32 || b.dtype != DType::F32 ||
22+
a.shape.rank != b.shape.rank || !a.shape.is_contiguous() || !b.shape.is_contiguous() ||
23+
a.numel() != b.numel()) {
24+
return false;
25+
}
26+
for (int i = 0; i < a.shape.rank; ++i) {
27+
if (a.shape.dims[i] != b.shape.dims[i]) {
28+
return false;
29+
}
30+
}
31+
return true;
32+
}
33+
34+
__device__ float block_sum(float value, float* shared) {
35+
value = warp_sum(value);
36+
const int lane = threadIdx.x & (kWarpSize - 1);
37+
const int warp = threadIdx.x / kWarpSize;
38+
if (lane == 0) {
39+
shared[warp] = value;
40+
}
41+
__syncthreads();
42+
const int warp_count = (blockDim.x + kWarpSize - 1) / kWarpSize;
43+
value = threadIdx.x < warp_count ? shared[lane] : 0.0f;
44+
if (warp == 0) {
45+
value = warp_sum(value);
46+
}
47+
if (threadIdx.x == 0) {
48+
shared[0] = value;
49+
}
50+
__syncthreads();
51+
return shared[0];
52+
}
53+
54+
__device__ float block_max(float value, float* shared) {
55+
value = warp_max(value);
56+
const int lane = threadIdx.x & (kWarpSize - 1);
57+
const int warp = threadIdx.x / kWarpSize;
58+
if (lane == 0) {
59+
shared[warp] = value;
60+
}
61+
__syncthreads();
62+
const int warp_count = (blockDim.x + kWarpSize - 1) / kWarpSize;
63+
value = threadIdx.x < warp_count ? shared[lane] : -INFINITY;
64+
if (warp == 0) {
65+
value = warp_max(value);
66+
}
67+
if (threadIdx.x == 0) {
68+
shared[0] = value;
69+
}
70+
__syncthreads();
71+
return shared[0];
72+
}
73+
74+
__global__ void softmax_forward_kernel(
75+
const float* __restrict__ logits,
76+
float* __restrict__ probs,
77+
int rows,
78+
int cols,
79+
int valid_cols) {
80+
extern __shared__ float shared[];
81+
const int row = blockIdx.x;
82+
if (row >= rows) {
83+
return;
84+
}
85+
86+
const float* __restrict__ logits_row = logits + row * cols;
87+
float* __restrict__ probs_row = probs + row * cols;
88+
float local_max = -INFINITY;
89+
for (int col = threadIdx.x; col < valid_cols; col += blockDim.x) {
90+
local_max = fmaxf(local_max, logits_row[col]);
91+
}
92+
const float max_val = block_max(local_max, shared);
93+
94+
float local_sum = 0.0f;
95+
for (int col = threadIdx.x; col < valid_cols; col += blockDim.x) {
96+
const float value = expf(logits_row[col] - max_val);
97+
probs_row[col] = value;
98+
local_sum += value;
99+
}
100+
const float sum = block_sum(local_sum, shared);
101+
const float inv_sum = sum == 0.0f ? 0.0f : 1.0f / sum;
102+
103+
for (int col = threadIdx.x; col < cols; col += blockDim.x) {
104+
probs_row[col] = col < valid_cols ? probs_row[col] * inv_sum : 0.0f;
105+
}
106+
}
107+
108+
__global__ void causal_softmax_row_kernel(
109+
const float* __restrict__ preatt,
110+
float* __restrict__ att,
111+
int rows,
112+
int time) {
113+
extern __shared__ float shared[];
114+
const int row = blockIdx.x;
115+
if (row >= rows) {
116+
return;
117+
}
118+
const int t = row % time;
119+
const int valid_cols = t + 1;
120+
const float* __restrict__ preatt_row = preatt + row * time;
121+
float* __restrict__ att_row = att + row * time;
122+
123+
float local_max = -INFINITY;
124+
for (int col = threadIdx.x; col < valid_cols; col += blockDim.x) {
125+
local_max = fmaxf(local_max, preatt_row[col]);
126+
}
127+
const float max_val = block_max(local_max, shared);
128+
129+
float local_sum = 0.0f;
130+
for (int col = threadIdx.x; col < valid_cols; col += blockDim.x) {
131+
const float value = expf(preatt_row[col] - max_val);
132+
att_row[col] = value;
133+
local_sum += value;
134+
}
135+
const float sum = block_sum(local_sum, shared);
136+
const float inv_sum = sum == 0.0f ? 0.0f : 1.0f / sum;
137+
138+
for (int col = threadIdx.x; col < time; col += blockDim.x) {
139+
att_row[col] = col < valid_cols ? att_row[col] * inv_sum : 0.0f;
140+
}
141+
}
142+
143+
} // namespace
144+
145+
Status softmax_forward(const TensorView& logits, TensorView probs, int valid_cols, cudaStream_t stream) {
146+
if (!valid_same_shape_f32(logits, probs) || logits.shape.rank != 2 || !fits_int(logits.shape.dims[0]) ||
147+
!fits_int(logits.shape.dims[1])) {
148+
return Status::failure(cudaErrorInvalidValue, "invalid softmax_forward tensors");
149+
}
150+
const int rows = static_cast<int>(logits.shape.dims[0]);
151+
const int cols = static_cast<int>(logits.shape.dims[1]);
152+
if (valid_cols <= 0 || valid_cols > cols) {
153+
return Status::failure(cudaErrorInvalidValue, "invalid softmax_forward valid_cols");
154+
}
155+
156+
DeviceGuard guard(logits.device_id);
157+
const std::size_t shared_bytes = ((kSoftmaxBlockSize + kWarpSize - 1) / kWarpSize) * sizeof(float);
158+
softmax_forward_kernel<<<rows, kSoftmaxBlockSize, shared_bytes, stream>>>(
159+
logits.data_as<const float>(),
160+
probs.data_as<float>(),
161+
rows,
162+
cols,
163+
valid_cols);
164+
return QUADTRIX_CUDA_CHECK(cudaGetLastError());
165+
}
166+
167+
Status causal_softmax_forward(const TensorView& preatt, TensorView att, cudaStream_t stream) {
168+
if (!valid_same_shape_f32(preatt, att) || preatt.shape.rank != 4 || !fits_int(preatt.shape.dims[0]) ||
169+
!fits_int(preatt.shape.dims[1]) || !fits_int(preatt.shape.dims[2]) ||
170+
preatt.shape.dims[2] != preatt.shape.dims[3]) {
171+
return Status::failure(cudaErrorInvalidValue, "invalid causal_softmax_forward tensors");
172+
}
173+
const int rows = static_cast<int>(preatt.shape.dims[0] * preatt.shape.dims[1] * preatt.shape.dims[2]);
174+
const int time = static_cast<int>(preatt.shape.dims[2]);
175+
176+
DeviceGuard guard(preatt.device_id);
177+
const std::size_t shared_bytes = ((kSoftmaxBlockSize + kWarpSize - 1) / kWarpSize) * sizeof(float);
178+
causal_softmax_row_kernel<<<rows, kSoftmaxBlockSize, shared_bytes, stream>>>(
179+
preatt.data_as<const float>(),
180+
att.data_as<float>(),
181+
rows,
182+
time);
183+
return QUADTRIX_CUDA_CHECK(cudaGetLastError());
184+
}
185+
186+
} // namespace cuda
187+
} // namespace quadtrix

0 commit comments

Comments
 (0)