Skip to content

Commit 38bd551

Browse files
committed
feat(cuda): implement causal_mask_forward kernel and host wrapper
1 parent e88b5e5 commit 38bd551

1 file changed

Lines changed: 44 additions & 0 deletions

File tree

cuda/KERNAL/trimat_forward.cu

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
#include "../includes/trimat.cuh"
2+
3+
#include "../includes/runtime.cuh"
4+
#include "../includes/utils.cuh"
5+
6+
namespace quadtrix {
7+
namespace cuda {
8+
namespace {
9+
10+
// Reports whether `matrix` can be handed to the causal-mask kernel in this
// file. All of the following must hold:
//   * the data pointer is non-null,
//   * the tensor lives on a CUDA device and stores F32 elements,
//   * the shape has rank 4 and reports itself contiguous,
//   * the last two dimensions (dims[2] and dims[3]) are equal.
// Checks are evaluated in that order and stop at the first failure.
bool valid_square_f32_cuda(const TensorView& matrix) {
    // Storage: must be CUDA-resident float32 memory.
    const bool cuda_f32_storage =
        matrix.data != nullptr && matrix.device == DeviceKind::CUDA && matrix.dtype == DType::F32;
    if (!cuda_f32_storage) {
        return false;
    }

    // Layout: rank-4 and contiguous, so the kernel can use flat indexing.
    const auto& shape = matrix.shape;
    const bool rank4_contiguous = shape.rank == 4 && shape.is_contiguous();
    if (!rank4_contiguous) {
        return false;
    }

    // The trailing two dimensions must form a square matrix.
    return shape.dims[2] == shape.dims[3];
}
14+
15+
// Writes `masked_value` into every element that lies strictly above the main
// diagonal of a `time` x `time` matrix; all other elements are left untouched.
//
// Layout: `matrix` is treated as a flat buffer of `n` floats in which the
// fastest-varying index is the column and the next one is the row, both of
// extent `time`. Any slower-varying indices just enumerate further matrices.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. The loop strides by the
// total number of launched threads, so every element below `n` is visited for
// any grid size, including a grid smaller than ceil(n / blockDim.x).
//
// Preconditions: `matrix` points to at least `n` device floats, and `time` is
// positive whenever `n` is non-zero (it is used as a divisor).
__global__ void causal_mask_kernel(float* matrix, std::size_t n, int time, float masked_value) {
    // Widen before multiplying so the flat index cannot overflow 32 bits.
    const std::size_t first = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const std::size_t stride = static_cast<std::size_t>(gridDim.x) * blockDim.x;
    // Convert once so the index arithmetic below is unsigned throughout.
    const std::size_t extent = static_cast<std::size_t>(time);

    for (std::size_t idx = first; idx < n; idx += stride) {
        const std::size_t col = idx % extent;
        const std::size_t row = (idx / extent) % extent;
        // Column beyond the row index: a position the row must not see.
        if (col > row) {
            matrix[idx] = masked_value;
        }
    }
}
26+
27+
} // namespace
28+
29+
// Applies a causal mask in place: for each square matrix formed by the last
// two dimensions of `matrix`, every element whose column index exceeds its
// row index is overwritten with `masked_value`.
//
// Parameters:
//   matrix       - must satisfy valid_square_f32_cuda (non-null, CUDA, F32,
//                  rank 4, contiguous, dims[2] == dims[3]).
//   masked_value - value stored at the masked positions.
//   stream       - stream the kernel is enqueued on. This function does not
//                  synchronize, so the result is only ready once prior work
//                  on `stream` and this kernel have completed.
//
// Returns:
//   A failure Status carrying cudaErrorInvalidValue if validation fails;
//   otherwise the Status produced by QUADTRIX_CUDA_CHECK from
//   cudaGetLastError() after the launch. That reports launch-time errors only;
//   faults during kernel execution surface at a later synchronizing call.
Status causal_mask_forward(TensorView matrix, float masked_value, cudaStream_t stream) {
    if (!valid_square_f32_cuda(matrix)) {
        return Status::failure(cudaErrorInvalidValue, "invalid causal_mask_forward tensor");
    }

    const auto element_count = matrix.numel();
    if (element_count == 0) {
        // Nothing to mask. Skip the launch: a grid with zero blocks is rejected
        // by the runtime as an invalid configuration, which would turn a
        // harmless empty tensor into an error.
        return QUADTRIX_CUDA_CHECK(cudaSuccess);
    }

    // NOTE(review): presumably makes matrix.device_id the current device for
    // the rest of this scope - confirm against DeviceGuard's definition.
    DeviceGuard guard(matrix.device_id);
    causal_mask_kernel<<<one_dim_grid(element_count), kDefaultBlockSize, 0, stream>>>(
        matrix.data_as<float>(),
        element_count,
        static_cast<int>(matrix.shape.dims[2]),  // side length of each square matrix
        masked_value);
    return QUADTRIX_CUDA_CHECK(cudaGetLastError());
}
42+
43+
} // namespace cuda
44+
} // namespace quadtrix

0 commit comments

Comments
 (0)