@@ -86,22 +86,24 @@ function transform_gpu!(def, constargs, force_inbounds)
86
86
end
87
87
end
88
88
pushfirst! (def[:args ], :__ctx__ )
89
- body = def[:body ]
89
+ new_stmts = Expr[]
90
+ body = MacroTools. flatten (def[:body ])
91
+ stmts = body. args
92
+ push! (new_stmts, Expr (:aliasscope ))
93
+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
90
94
if force_inbounds
91
- body = quote
92
- @inbounds $ (body)
93
- end
95
+ push! (new_stmts, Expr (:inbounds , true ))
94
96
end
95
- body = quote
96
- if $ __validindex (__ctx__)
97
- $ (body)
98
- end
99
- return nothing
97
+ append! (new_stmts, split (emit_gpu, body. args))
98
+ if force_inbounds
99
+ push! (new_stmts, Expr (:inbounds , :pop ))
100
100
end
101
+ push! (new_stmts, Expr (:popaliasscope ))
102
+ push! (new_stmts, :(return nothing ))
101
103
def[:body ] = Expr (
102
104
:let ,
103
105
Expr (:block , let_constargs... ),
104
- body ,
106
+ Expr ( :block , new_stmts ... ) ,
105
107
)
106
108
return
107
109
end
@@ -127,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
127
129
if force_inbounds
128
130
push! (new_stmts, Expr (:inbounds , true ))
129
131
end
130
- append! (new_stmts, split (body. args))
132
+ append! (new_stmts, split (emit_cpu, body. args))
131
133
if force_inbounds
132
134
push! (new_stmts, Expr (:inbounds , :pop ))
133
135
end
@@ -147,6 +149,7 @@ struct WorkgroupLoop
147
149
allocations:: Vector{Any}
148
150
private_allocations:: Vector{Any}
149
151
private:: Set{Symbol}
152
+ terminated_in_sync:: Bool
150
153
end
151
154
152
155
# Report whether `expr` matches either synchronization pattern:
# a bare `@synchronize()` call or `@synchronize(a_)` with one argument.
function is_sync(expr)
    return @capture(expr, @synchronize() | @synchronize(a_))
end
167
170
168
171
# TODO proper handling of LineInfo
169
172
function split (
173
+ emit,
170
174
stmts,
171
175
indicies = Any[], private = Set {Symbol} (),
172
176
)
@@ -182,7 +186,7 @@ function split(
182
186
for stmt in stmts
183
187
has_sync = find_sync (stmt)
184
188
if has_sync
185
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
189
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private), is_sync (stmt) )
186
190
push! (new_stmts, emit (loop))
187
191
allocations = Any[]
188
192
private_allocations = Any[]
@@ -197,7 +201,7 @@ function split(
197
201
function recurse (expr:: Expr )
198
202
expr = unblock (expr)
199
203
if is_scope_construct (expr) && any (find_sync, expr. args)
200
- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
204
+ new_args = unblock (split (emit, expr. args, deepcopy (indicies), deepcopy (private)))
201
205
return Expr (expr. head, new_args... )
202
206
else
203
207
return Expr (expr. head, map (recurse, expr. args)... )
@@ -240,13 +244,13 @@ function split(
240
244
241
245
# everything since the last `@synchronize`
242
246
if ! isempty (current)
243
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
247
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private), false )
244
248
push! (new_stmts, emit (loop))
245
249
end
246
250
return new_stmts
247
251
end
248
252
249
- function emit (loop)
253
+ function emit_cpu (loop)
250
254
idx = gensym (:I )
251
255
for stmt in loop. indicies
252
256
# splice index into the i = @index(Cartesian, $idx)
@@ -300,3 +304,23 @@ function emit(loop)
300
304
301
305
return unblock (Expr (:block , stmts... ))
302
306
end
307
+
308
# Build the GPU expression for one `loop` segment.
#
# The segment's `allocations` and `private_allocations` are spliced in
# unconditionally; its index statements (`loop.indicies`) and body statements
# (`loop.stmts`) are spliced inside an `if __active_lane__` guard. When
# `loop.terminated_in_sync` is true, a call to `__synchronize()` is appended
# after the guarded section. The result is the `unblock`ed enclosing block.
function emit_gpu(loop)
    # Allocations stay outside the guard; indices and the body go inside it.
    guarded = quote
        $(loop.allocations...)
        $(loop.private_allocations...)
        if __active_lane__
            $(loop.indicies...)
            $(unblock(Expr(:block, loop.stmts...)))
        end
    end

    pieces = Any[guarded]
    # The synchronization call is emitted after (outside) the lane guard.
    loop.terminated_in_sync && push!(pieces, :($__synchronize()))

    return unblock(Expr(:block, pieces...))
end
0 commit comments