Skip to content

Commit f42a13a

Browse files
committed
Don't call intrinsics in 3d dim/idx functions.
Instead, call the Rust functions that have the range constraints. That way the 3d versions get the same range constraints as the 1d versions. It also avoids the need for some `unsafe` blocks.
1 parent 162e738 commit f42a13a

File tree

1 file changed

+20
-28
lines changed

1 file changed

+20
-28
lines changed

crates/cuda_std/src/thread.rs

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -193,54 +193,46 @@ pub fn grid_dim_z() -> u32 {
193193
#[gpu_only]
194194
#[inline(always)]
195195
pub fn thread_idx() -> UVec3 {
196-
unsafe {
197-
UVec3::new(
198-
__nvvm_thread_idx_x(),
199-
__nvvm_thread_idx_y(),
200-
__nvvm_thread_idx_z(),
201-
)
202-
}
196+
UVec3::new(
197+
thread_idx_x(),
198+
thread_idx_y(),
199+
thread_idx_z(),
200+
)
203201
}
204202

205203
/// Gets the 3d index of the block that the thread currently executing the kernel is located in.
206204
#[gpu_only]
207205
#[inline(always)]
208206
pub fn block_idx() -> UVec3 {
209-
unsafe {
210-
UVec3::new(
211-
__nvvm_block_idx_x(),
212-
__nvvm_block_idx_y(),
213-
__nvvm_block_idx_z(),
214-
)
215-
}
207+
UVec3::new(
208+
block_idx_x(),
209+
block_idx_y(),
210+
block_idx_z(),
211+
)
216212
}
217213

218214
/// Gets the 3d layout of the thread blocks executing this kernel. In other words,
219215
/// how many threads exist in each thread block in every direction.
220216
#[gpu_only]
221217
#[inline(always)]
222218
pub fn block_dim() -> UVec3 {
223-
unsafe {
224-
UVec3::new(
225-
__nvvm_block_dim_x(),
226-
__nvvm_block_dim_y(),
227-
__nvvm_block_dim_z(),
228-
)
229-
}
219+
UVec3::new(
220+
block_dim_x(),
221+
block_dim_y(),
222+
block_dim_z(),
223+
)
230224
}
231225

232226
/// Gets the 3d layout of the block grids executing this kernel. In other words,
233227
/// how many thread blocks exist in each grid in every direction.
234228
#[gpu_only]
235229
#[inline(always)]
236230
pub fn grid_dim() -> UVec3 {
237-
unsafe {
238-
UVec3::new(
239-
__nvvm_grid_dim_x(),
240-
__nvvm_grid_dim_y(),
241-
__nvvm_grid_dim_z(),
242-
)
243-
}
231+
UVec3::new(
232+
grid_dim_x(),
233+
grid_dim_y(),
234+
grid_dim_z(),
235+
)
244236
}
245237

246238
/// Gets the overall thread index, accounting for 1d/2d/3d block/grid dimensions. This

0 commit comments

Comments
 (0)