Skip to content

Commit 450af34

Browse files
committed
Remove __nvvm_{thread,block,grid}_{idx,dim}_[xyz] intrinsics.
`core` has equivalents, might as well use them instead.
1 parent f42a13a commit 450af34

File tree

4 files changed

+15
-123
lines changed

4 files changed

+15
-123
lines changed

crates/cuda_std/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#![allow(internal_features)]
2525
#![cfg_attr(
2626
target_os = "cuda",
27-
feature(alloc_error_handler, asm_experimental_arch, link_llvm_intrinsics)
27+
feature(alloc_error_handler, asm_experimental_arch, link_llvm_intrinsics, stdarch_nvptx)
2828
)]
2929

3030
extern crate alloc;

crates/cuda_std/src/thread.rs

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,6 @@ use glam::{UVec2, UVec3};
6363
// different calling conventions dont exist in nvptx, so we just use C as a placeholder.
6464
extern "C" {
6565
// defined in libintrinsics.ll
66-
fn __nvvm_thread_idx_x() -> u32;
67-
fn __nvvm_thread_idx_y() -> u32;
68-
fn __nvvm_thread_idx_z() -> u32;
69-
70-
fn __nvvm_block_dim_x() -> u32;
71-
fn __nvvm_block_dim_y() -> u32;
72-
fn __nvvm_block_dim_z() -> u32;
73-
74-
fn __nvvm_block_idx_x() -> u32;
75-
fn __nvvm_block_idx_y() -> u32;
76-
fn __nvvm_block_idx_z() -> u32;
77-
78-
fn __nvvm_grid_dim_x() -> u32;
79-
fn __nvvm_grid_dim_y() -> u32;
80-
fn __nvvm_grid_dim_z() -> u32;
81-
8266
fn __nvvm_warp_size() -> u32;
8367

8468
fn __nvvm_block_barrier();
@@ -92,8 +76,8 @@ extern "C" {
9276
macro_rules! in_range {
9377
// The bounds were taken mostly from the cuda C++ programming guide. I also
9478
// double-checked with what cuda clang does by checking its emitted llvm ir's scalar metadata.
95-
($func_name:ident, $range:expr) => {{
96-
let val = unsafe { $func_name() };
79+
($func_name:path, $range:expr) => {{
80+
let val = unsafe { $func_name() as u32 };
9781
if !$range.contains(&val) {
9882
// SAFETY: this condition is declared unreachable by compute capability max bound.
9983
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
@@ -109,84 +93,84 @@ macro_rules! in_range {
10993
#[inline(always)]
11094
pub fn thread_idx_x() -> u32 {
11195
// The range is derived from the `block_idx_x` range.
112-
in_range!(__nvvm_thread_idx_x, 0..1024)
96+
in_range!(core::arch::nvptx::_thread_idx_x, 0..1024)
11397
}
11498

11599
#[gpu_only]
116100
#[inline(always)]
117101
pub fn thread_idx_y() -> u32 {
118102
// The range is derived from the `block_idx_y` range.
119-
in_range!(__nvvm_thread_idx_y, 0..1024)
103+
in_range!(core::arch::nvptx::_thread_idx_y, 0..1024)
120104
}
121105

122106
#[gpu_only]
123107
#[inline(always)]
124108
pub fn thread_idx_z() -> u32 {
125109
// The range is derived from the `block_idx_z` range.
126-
in_range!(__nvvm_thread_idx_z, 0..64)
110+
in_range!(core::arch::nvptx::_thread_idx_z, 0..64)
127111
}
128112

129113
#[gpu_only]
130114
#[inline(always)]
131115
pub fn block_idx_x() -> u32 {
132116
// The range is derived from the `grid_idx_x` range.
133-
in_range!(__nvvm_block_idx_x, 0..2147483647)
117+
in_range!(core::arch::nvptx::_block_idx_x, 0..2147483647)
134118
}
135119

136120
#[gpu_only]
137121
#[inline(always)]
138122
pub fn block_idx_y() -> u32 {
139123
// The range is derived from the `grid_idx_y` range.
140-
in_range!(__nvvm_block_idx_y, 0..65535)
124+
in_range!(core::arch::nvptx::_block_idx_y, 0..65535)
141125
}
142126

143127
#[gpu_only]
144128
#[inline(always)]
145129
pub fn block_idx_z() -> u32 {
146130
// The range is derived from the `grid_idx_z` range.
147-
in_range!(__nvvm_block_idx_z, 0..65535)
131+
in_range!(core::arch::nvptx::_block_idx_z, 0..65535)
148132
}
149133

150134
#[gpu_only]
151135
#[inline(always)]
152136
pub fn block_dim_x() -> u32 {
153137
// CUDA Compute Capabilities: "Maximum x- or y-dimensionality of a block" is 1024.
154-
in_range!(__nvvm_block_dim_x, 1..=1024)
138+
in_range!(core::arch::nvptx::_block_dim_x, 1..=1024)
155139
}
156140

157141
#[gpu_only]
158142
#[inline(always)]
159143
pub fn block_dim_y() -> u32 {
160144
// CUDA Compute Capabilities: "Maximum x- or y-dimensionality of a block" is 1024.
161-
in_range!(__nvvm_block_dim_y, 1..=1024)
145+
in_range!(core::arch::nvptx::_block_dim_y, 1..=1024)
162146
}
163147

164148
#[gpu_only]
165149
#[inline(always)]
166150
pub fn block_dim_z() -> u32 {
167151
// CUDA Compute Capabilities: "Maximum z-dimension of a block" is 64.
168-
in_range!(__nvvm_block_dim_z, 1..=64)
152+
in_range!(core::arch::nvptx::_block_dim_z, 1..=64)
169153
}
170154

171155
#[gpu_only]
172156
#[inline(always)]
173157
pub fn grid_dim_x() -> u32 {
174158
// CUDA Compute Capabilities: "Maximum x-dimension of a grid of thread blocks" is 2^32 - 1.
175-
in_range!(__nvvm_grid_dim_x, 1..=2147483647)
159+
in_range!(core::arch::nvptx::_grid_dim_x, 1..=2147483647)
176160
}
177161

178162
#[gpu_only]
179163
#[inline(always)]
180164
pub fn grid_dim_y() -> u32 {
181165
// CUDA Compute Capabilities: "Maximum y- or z-dimension of a grid of thread blocks" is 65535.
182-
in_range!(__nvvm_grid_dim_y, 1..=65535)
166+
in_range!(core::arch::nvptx::_grid_dim_y, 1..=65535)
183167
}
184168

185169
#[gpu_only]
186170
#[inline(always)]
187171
pub fn grid_dim_z() -> u32 {
188172
// CUDA Compute Capabilities: "Maximum y- or z-dimension of a grid of thread blocks" is 65535.
189-
in_range!(__nvvm_grid_dim_z, 1..=65535)
173+
in_range!(core::arch::nvptx::_grid_dim_z, 1..=65535)
190174
}
191175

192176
/// Gets the 3d index of the thread currently executing the kernel.
-620 Bytes
Binary file not shown.

crates/rustc_codegen_nvvm/libintrinsics.ll

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -8,86 +8,6 @@ source_filename = "libintrinsics"
88
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
99
target triple = "nvptx64-nvidia-cuda"
1010

11-
; thread ----
12-
13-
define i32 @__nvvm_thread_idx_x() #0 {
14-
start:
15-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
16-
ret i32 %0
17-
}
18-
19-
define i32 @__nvvm_thread_idx_y() #0 {
20-
start:
21-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
22-
ret i32 %0
23-
}
24-
25-
define i32 @__nvvm_thread_idx_z() #0 {
26-
start:
27-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
28-
ret i32 %0
29-
}
30-
31-
; block dimension ----
32-
33-
define i32 @__nvvm_block_dim_x() #0 {
34-
start:
35-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
36-
ret i32 %0
37-
}
38-
39-
define i32 @__nvvm_block_dim_y() #0 {
40-
start:
41-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
42-
ret i32 %0
43-
}
44-
45-
define i32 @__nvvm_block_dim_z() #0 {
46-
start:
47-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
48-
ret i32 %0
49-
}
50-
51-
; block idx ----
52-
53-
define i32 @__nvvm_block_idx_x() #0 {
54-
start:
55-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
56-
ret i32 %0
57-
}
58-
59-
define i32 @__nvvm_block_idx_y() #0 {
60-
start:
61-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
62-
ret i32 %0
63-
}
64-
65-
define i32 @__nvvm_block_idx_z() #0 {
66-
start:
67-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
68-
ret i32 %0
69-
}
70-
71-
; grid dimension ----
72-
73-
define i32 @__nvvm_grid_dim_x() #0 {
74-
start:
75-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
76-
ret i32 %0
77-
}
78-
79-
define i32 @__nvvm_grid_dim_y() #0 {
80-
start:
81-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
82-
ret i32 %0
83-
}
84-
85-
define i32 @__nvvm_grid_dim_z() #0 {
86-
start:
87-
%0 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
88-
ret i32 %0
89-
}
90-
9111
; warp ----
9212

9313
define i32 @__nvvm_warp_size() #0 {
@@ -96,18 +16,6 @@ start:
9616
ret i32 %0
9717
}
9818

99-
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
100-
declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
101-
declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
102-
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
103-
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
104-
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
105-
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
106-
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
107-
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
108-
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
109-
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
110-
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
11119
declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
11220

11321
; other ----

0 commit comments

Comments
 (0)