@@ -1256,23 +1256,47 @@ function cfg_simplify!(ir::IRCode)
12561256 return finish (compact)
12571257end
12581258
1259- function is_allocation (stmt)
1259+ # function is_known_fcall(stmt::Expr, @nospecialize(func))
1260+ # isexpr(stmt, :foreigncall) || return false
1261+ # s = stmt.args[1]
1262+ # isa(s, QuoteNode) && (s = s.value)
1263+ # return s === func
1264+ # end
1265+
1266+ function is_known_fcall (stmt:: Expr , funcs:: Vector{Symbol} )
12601267 isexpr (stmt, :foreigncall ) || return false
12611268 s = stmt. args[1 ]
12621269 isa (s, QuoteNode) && (s = s. value)
1263- return s === :jl_alloc_array_1d
1270+ # return any(e -> s === e, funcs)
1271+ return true in map (e -> s === e, funcs)
1272+ end
1273+
1274+ function is_allocation (stmt:: Expr )
1275+ isexpr (stmt, :foreigncall ) || return false
1276+ s = stmt. args[1 ]
1277+ isa (s, QuoteNode) && (s = s. value)
1278+ return (s === :jl_alloc_array_1d
1279+ || s === :jl_alloc_array_2d
1280+ || s === :jl_alloc_array_3d
1281+ || s === :jl_new_array )
12641282end
12651283
12661284function memory_opt! (ir:: IRCode )
12671285 compact = IncrementalCompact (ir, false )
12681286 uses = IdDict {Int, Vector{Int}} ()
1269- relevant = IdSet {Int} ()
1270- revisit = Int[]
1271- function mark_val (val)
1287+ relevant = IdSet {Int} () # allocations
1288+ revisit = Int[] # potential targets for a mutating_arrayfreeze drop-in
1289+
1290+ function mark_escape (@nospecialize val)
12721291 isa (val, SSAValue) || return
1292+ # println(val.id, " escaped.")
12731293 val. id in relevant && pop! (relevant, val. id)
12741294 end
1295+
12751296 for ((_, idx), stmt) in compact
1297+
1298+ # println("idx: ", idx, " = ", stmt)
1299+
12761300 if isa (stmt, ReturnNode)
12771301 isdefined (stmt, :val ) || continue
12781302 val = stmt. val
@@ -1281,51 +1305,171 @@ function memory_opt!(ir::IRCode)
12811305 push! (uses[val. id], idx)
12821306 end
12831307 continue
1308+
1309+ # check for phinodes that are possibly allocations
1310+ elseif isa (stmt, PhiNode)
1311+
1312+ # this loop seems like a waste, but using map here didn't go well
1313+ defined = true
1314+ for i = 1 : length (stmt. values)
1315+ if ! isassigned (stmt. values, i)
1316+ defined = false
1317+ end
1318+ end
1319+
1320+ defined || continue
1321+
1322+ for val in stmt. values
1323+ if isa (val, SSAValue) && val. id in relevant
1324+ # println("Adding ", idx ," to relevant from PhiNode: " , stmt)
1325+ push! (relevant, idx)
1326+ end
1327+ end
12841328 end
1329+
12851330 (isexpr (stmt, :call ) || isexpr (stmt, :foreigncall )) || continue
1331+
12861332 if is_allocation (stmt)
12871333 push! (relevant, idx)
12881334 # TODO : Mark everything else here
12891335 continue
12901336 end
1291- # TODO : Replace this by interprocedural escape analysis
1292- if is_known_call (stmt, arrayset, compact)
1337+
1338+ if is_known_call (stmt, arrayset, compact) && length (stmt . args) >= 5
12931339 # The value being set escapes, everything else doesn't
1294- mark_val (stmt. args[4 ])
1340+ mark_escape (stmt. args[4 ])
1341+
1342+ arr = stmt. args[3 ]
1343+
1344+ if isa (arr, SSAValue) && arr. id in relevant
1345+ (haskey (uses, arr. id)) || (uses[arr. id] = Int[])
1346+ push! (uses[arr. id], idx)
1347+ end
1348+
1349+ elseif is_known_call (stmt, arrayref, compact) && length (stmt. args) == 4
12951350 arr = stmt. args[3 ]
1351+
1352+ if isa (arr, SSAValue) && arr. id in relevant
1353+ (haskey (uses, arr. id)) || (uses[arr. id] = Int[])
1354+ push! (uses[arr. id], idx)
1355+ end
1356+
1357+ elseif is_known_call (stmt, setindex!, compact) && length (stmt. args) == 4
1358+ # println("setindex!: ", stmt.args)
1359+ # handle similarly to arrayset
1360+ # escape the value being set
1361+ val = stmt. args[3 ]
1362+ mark_escape (val)
1363+ # track usage of arr for dominance analysis
1364+ arr = stmt. args[2 ]
1365+ if isa (arr, SSAValue) && arr. id in relevant
1366+ (haskey (uses, arr. id)) || (uses[arr. id] = Int[])
1367+ push! (uses[arr. id], idx)
1368+ end
1369+
1370+ # these foreigncalls have similar structure and don't escape our array, so handle them all at once
1371+ elseif is_known_fcall (stmt, [:jl_array_ptr , :jl_array_copy ]) && length (stmt. args) == 6
1372+ # println("is_known_fcall: ", stmt)
1373+ arr = stmt. args[6 ]
1374+
1375+ # just record usage info
12961376 if isa (arr, SSAValue) && arr. id in relevant
12971377 (haskey (uses, arr. id)) || (uses[arr. id] = Int[])
12981378 push! (uses[arr. id], idx)
12991379 end
1380+
1381+ # elseif is_known_fcall(stmt, :jl_array_ptr) && length(stmt.args) == 6
1382+ # arr = stmt.args[6]
1383+
1384+ # if isa(arr, SSAValue) && arr.id in relevant
1385+ # (haskey(uses, arr.id)) || (uses[arr.id] = Int[])
1386+ # push!(uses[arr.id], idx)
1387+ # end
1388+
1389+ # elseif is_known_fcall(stmt, :jl_array_copy) && length(stmt.args) == 6
1390+ # arr = stmt.args[6]
1391+
1392+ # if isa(arr, SSAValue) && arr.id in relevant
1393+ # (haskey(uses, arr.id)) || (uses[arr.id] = Int[])
1394+ # push!(uses[arr.id], idx)
1395+ # end
1396+
1397+ elseif is_known_call (stmt, arraysize, compact) && isa (stmt. args[2 ], SSAValue) # && isa(stmt.args[3], Number)
1398+ arr = stmt. args[2 ]
1399+ # dim = stmt.args[3]
1400+ # typ = types(compact)[arr]
1401+
1402+ # if isa(typ, Core.Const)
1403+ # typ = typ.val
1404+ # end
1405+
1406+ # NEW: since exceptions no longer escape arrays, we can just assume no escape
1407+
1408+ if arr. id in relevant
1409+ (haskey (uses, arr. id)) || (uses[arr. id] = Int[])
1410+ push! (uses[arr. id], idx)
1411+ end
1412+
1413+ # # make sure this call isn't going to throw
1414+ # if isa(typ, Type) && typ <: AbstractArray && dim >= 1
1415+ # # don't escape the array, but mark usage for dom analysis
1416+ # if arr.id in relevant
1417+ # (haskey(uses, arr.id)) || (uses[arr.id] = Int[])
1418+ # push!(uses[arr.id], idx)
1419+ # end
1420+
1421+ # else # if this call throws or we can't tell, assume all uses escape
1422+ # for ur in userefs(stmt)
1423+ # mark_escape(ur[])
1424+ # end
1425+ # end
1426+
13001427 elseif is_known_call (stmt, Core. arrayfreeze, compact) && isa (stmt. args[2 ], SSAValue)
1428+ # mark these for potential replacement with mutating_arrayfreeze
13011429 push! (revisit, idx)
1430+
13021431 else
1303- # For now we assume everything escapes
1304- # TODO : We could handle PhiNodes specially and improve this
1432+ # Assume everything else escapes
13051433 for ur in userefs (stmt)
1306- mark_val (ur[])
1434+ mark_escape (ur[])
13071435 end
13081436 end
13091437 end
1438+
13101439 ir = finish (compact)
13111440 isempty (revisit) && return ir
1441+
13121442 domtree = construct_domtree (ir. cfg. blocks)
1443+
13131444 for idx in revisit
13141445 # Make sure that the value we reference didn't escape
1315- id = ir. stmts[idx][:inst ]. args[2 ]. id
1446+ stmt = ir. stmts[idx][:inst ]:: Expr
1447+ id = (stmt. args[2 ]:: SSAValue ). id
1448+ # print("Relevant: ")
1449+ # for id in relevant
1450+ # print(id, " ")
1451+ # end
1452+ # println(" ")
1453+ # print("Revisit: ")
1454+ # Main.Base.show(revisit)
1455+ # println()
13161456 (id in relevant) || continue
13171457
1458+ # println("Revisiting ", stmt)
1459+
13181460 # We're ok to steal the memory if we don't dominate any uses
13191461 ok = true
1320- for use in uses[id]
1321- if ssadominates (ir, domtree, idx, use)
1322- ok = false
1323- break
1462+ if haskey (uses, id)
1463+ for use in uses[id]
1464+ if ssadominates (ir, domtree, idx, use)
1465+ ok = false
1466+ break
1467+ end
13241468 end
13251469 end
13261470 ok || continue
1327-
1328- ir . stmts[idx][ :inst ] . args[1 ] = Core. mutating_arrayfreeze
1471+ # println("Optimization of ", stmt)
1472+ stmt . args[1 ] = Core. mutating_arrayfreeze
13291473 end
13301474 return ir
13311475end
0 commit comments