Skip to content

Commit 31d5b44

Browse files
authored
Forbid divergent execution of work-group barriers (#564)
1 parent b435bb2 commit 31d5b44

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

src/KernelAbstractions.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ end
284284
After a `@synchronize` statement all read and writes to global and local memory
285285
from each thread in the workgroup are visible in from all other threads in the
286286
workgroup.
287+
288+
!!! note
289+
`@synchronize()` must be encountered by all workitems of a work-group executing the kernel or by none at all.
287290
"""
288291
macro synchronize()
289292
return quote
@@ -301,10 +304,15 @@ workgroup. `cond` is not allowed to have any visible sideffects.
301304
# Platform differences
302305
- `GPU`: This synchronization will only occur if the `cond` evaluates.
303306
- `CPU`: This synchronization will always occur.
307+
308+
!!! warn
309+
This variant of the `@synchronize` macro violates the requirement that `@synchronize` must be encountered
310+
by all workitems of a work-group executing the kernel or by none at all.
311+
Since v`0.9.34` this version of the macro is deprecated and lowers to `@synchronize()`
304312
"""
305313
macro synchronize(cond)
306314
return quote
307-
$(esc(cond)) && $__synchronize()
315+
$__synchronize()
308316
end
309317
end
310318

src/macros.jl

+39-15
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,24 @@ function transform_gpu!(def, constargs, force_inbounds)
8686
end
8787
end
8888
pushfirst!(def[:args], :__ctx__)
89-
body = def[:body]
89+
new_stmts = Expr[]
90+
body = MacroTools.flatten(def[:body])
91+
stmts = body.args
92+
push!(new_stmts, Expr(:aliasscope))
93+
push!(new_stmts, :(__active_lane__ = $__validindex(__ctx__)))
9094
if force_inbounds
91-
body = quote
92-
@inbounds $(body)
93-
end
95+
push!(new_stmts, Expr(:inbounds, true))
9496
end
95-
body = quote
96-
if $__validindex(__ctx__)
97-
$(body)
98-
end
99-
return nothing
97+
append!(new_stmts, split(emit_gpu, body.args))
98+
if force_inbounds
99+
push!(new_stmts, Expr(:inbounds, :pop))
100100
end
101+
push!(new_stmts, Expr(:popaliasscope))
102+
push!(new_stmts, :(return nothing))
101103
def[:body] = Expr(
102104
:let,
103105
Expr(:block, let_constargs...),
104-
body,
106+
Expr(:block, new_stmts...),
105107
)
106108
return
107109
end
@@ -127,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
127129
if force_inbounds
128130
push!(new_stmts, Expr(:inbounds, true))
129131
end
130-
append!(new_stmts, split(body.args))
132+
append!(new_stmts, split(emit_cpu, body.args))
131133
if force_inbounds
132134
push!(new_stmts, Expr(:inbounds, :pop))
133135
end
@@ -147,6 +149,7 @@ struct WorkgroupLoop
147149
allocations::Vector{Any}
148150
private_allocations::Vector{Any}
149151
private::Set{Symbol}
152+
terminated_in_sync::Bool
150153
end
151154

152155
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
@@ -167,6 +170,7 @@ end
167170

168171
# TODO proper handling of LineInfo
169172
function split(
173+
emit,
170174
stmts,
171175
indicies = Any[], private = Set{Symbol}(),
172176
)
@@ -182,7 +186,7 @@ function split(
182186
for stmt in stmts
183187
has_sync = find_sync(stmt)
184188
if has_sync
185-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
189+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
186190
push!(new_stmts, emit(loop))
187191
allocations = Any[]
188192
private_allocations = Any[]
@@ -197,7 +201,7 @@ function split(
197201
function recurse(expr::Expr)
198202
expr = unblock(expr)
199203
if is_scope_construct(expr) && any(find_sync, expr.args)
200-
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
204+
new_args = unblock(split(emit, expr.args, deepcopy(indicies), deepcopy(private)))
201205
return Expr(expr.head, new_args...)
202206
else
203207
return Expr(expr.head, map(recurse, expr.args)...)
@@ -240,13 +244,13 @@ function split(
240244

241245
# everything since the last `@synchronize`
242246
if !isempty(current)
243-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
247+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), false)
244248
push!(new_stmts, emit(loop))
245249
end
246250
return new_stmts
247251
end
248252

249-
function emit(loop)
253+
function emit_cpu(loop)
250254
idx = gensym(:I)
251255
for stmt in loop.indicies
252256
# splice index into the i = @index(Cartesian, $idx)
@@ -300,3 +304,23 @@ function emit(loop)
300304

301305
return unblock(Expr(:block, stmts...))
302306
end
307+
308+
function emit_gpu(loop)
309+
stmts = Any[]
310+
311+
body = Expr(:block, loop.stmts...)
312+
loopexpr = quote
313+
$(loop.allocations...)
314+
$(loop.private_allocations...)
315+
if __active_lane__
316+
$(loop.indicies...)
317+
$(unblock(body))
318+
end
319+
end
320+
push!(stmts, loopexpr)
321+
if loop.terminated_in_sync
322+
push!(stmts, :($__synchronize()))
323+
end
324+
325+
return unblock(Expr(:block, stmts...))
326+
end

0 commit comments

Comments
 (0)