From 25635ea42e8d012d95ebdd14c3b9f9b4c69d9c81 Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki
Date: Fri, 14 Jan 2022 05:21:10 +0900
Subject: [PATCH 1/5] optimizer: Julia-level escape analysis

This commit ports [EscapeAnalysis.jl](https://github.com/aviatesk/EscapeAnalysis.jl)
into Julia base. You can find the documentation of this escape analysis at
[this GitHub page](https://aviatesk.github.io/EscapeAnalysis.jl/dev/)[^1].

[^1]: The same documentation will be included in Julia's developer
documentation by this commit.

This escape analysis will hopefully be an enabling technology for various
memory-related optimizations in Julia's high-level compilation pipeline.
Possible target optimizations include alias-aware SROA (#43888), array SROA
(#43909), the `mutating_arrayfreeze` optimization (#42465), stack allocation
of mutables, finalizer elision, and so on[^2].

[^2]: It would also be interesting if LLVM-level optimizations could consume
IPO information derived by this escape analysis to broaden optimization
possibilities.

The primary motivation for porting EA in this PR is to check its impact on
latency as well as to get feedback from a broader range of developers. The
plan is to first introduce EA in this commit, and then merge the dependent
PRs built on top of this commit, like #43888, #43909 and #42465.

This commit simply defines and runs EA inside the Julia base compiler and
enables the existing test suite with it. In this commit, we just run EA
before inlining to generate the IPO cache. In the dependent PRs, EA will be
invoked again after inlining to be used for various local optimizations.
---
 base/boot.jl                                  |   60 +-
 base/compiler/bootstrap.jl                    |   10 +-
 base/compiler/compiler.jl                     |    2 +
 base/compiler/optimize.jl                     |  103 +-
 .../ssair/EscapeAnalysis/EscapeAnalysis.jl    | 1913 ++++++++++++++
 .../ssair/EscapeAnalysis/disjoint_set.jl      |  143 ++
 .../ssair/EscapeAnalysis/interprocedural.jl   |  151 ++
 base/compiler/ssair/driver.jl                 |    6 +-
 base/compiler/tfuncs.jl                       |    2 +-
 base/compiler/typeinfer.jl                    |    2 +-
 base/compiler/types.jl                        |   12 +-
 base/compiler/utilities.jl                    |    4 +
 doc/make.jl                                   |    1 +
 doc/src/devdocs/EscapeAnalysis.md             |  363 +++
 doc/src/devdocs/llvm.md                       |    2 +-
 src/dump.c                                    |    4 +
 src/gf.c                                      |   19 +-
 src/jltypes.c                                 |   12 +-
 src/julia.h                                   |    1 +
 test/choosetests.jl                           |    5 +-
 test/compiler/EscapeAnalysis/EAUtils.jl       |  385 +++
 .../EscapeAnalysis/interprocedural.jl         |  264 ++
 test/compiler/EscapeAnalysis/local.jl         | 2206 +++++++++++++++++
 test/compiler/EscapeAnalysis/setup.jl         |   71 +
 24 files changed, 5654 insertions(+), 87 deletions(-)
 create mode 100644 base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
 create mode 100644 base/compiler/ssair/EscapeAnalysis/disjoint_set.jl
 create mode 100644 base/compiler/ssair/EscapeAnalysis/interprocedural.jl
 create mode 100644 doc/src/devdocs/EscapeAnalysis.md
 create mode 100644 test/compiler/EscapeAnalysis/EAUtils.jl
 create mode 100644 test/compiler/EscapeAnalysis/interprocedural.jl
 create mode 100644 test/compiler/EscapeAnalysis/local.jl
 create mode 100644 test/compiler/EscapeAnalysis/setup.jl

diff --git a/base/boot.jl b/base/boot.jl
index ecc037407685e..290a98cbf2bbd 100644
--- a/base/boot.jl
+++ b/base/boot.jl
@@ -401,33 +401,39 @@ _new(:QuoteNode, :Any)
 _new(:SSAValue, :Int)
 _new(:Argument, :Int)
 _new(:ReturnNode, :Any)
-eval(Core, :(ReturnNode() = $(Expr(:new, :ReturnNode)))) # unassigned val indicates unreachable
-eval(Core, :(GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))))
-eval(Core, :(LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))))
-eval(Core, :(LineNumberNode(l::Int, @nospecialize(f)) = $(Expr(:new, :LineNumberNode, :l, :f)))) -LineNumberNode(l::Int, f::String) = LineNumberNode(l, Symbol(f)) -eval(Core, :(GlobalRef(m::Module, s::Symbol) = $(Expr(:new, :GlobalRef, :m, :s)))) -eval(Core, :(SlotNumber(n::Int) = $(Expr(:new, :SlotNumber, :n)))) -eval(Core, :(TypedSlot(n::Int, @nospecialize(t)) = $(Expr(:new, :TypedSlot, :n, :t)))) -eval(Core, :(PhiNode(edges::Array{Int32, 1}, values::Array{Any, 1}) = $(Expr(:new, :PhiNode, :edges, :values)))) -eval(Core, :(PiNode(val, typ) = $(Expr(:new, :PiNode, :val, :typ)))) -eval(Core, :(PhiCNode(values::Array{Any, 1}) = $(Expr(:new, :PhiCNode, :values)))) -eval(Core, :(UpsilonNode(val) = $(Expr(:new, :UpsilonNode, :val)))) -eval(Core, :(UpsilonNode() = $(Expr(:new, :UpsilonNode)))) -eval(Core, :(LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int, inlined_at::Int) = - $(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at)))) -eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const), - @nospecialize(inferred), const_flags::Int32, - min_world::UInt, max_world::UInt, ipo_effects::UInt8, effects::UInt8, - relocatability::UInt8) = - ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt, UInt8, UInt8, UInt8), - mi, rettype, inferred_const, inferred, const_flags, min_world, max_world, ipo_effects, effects, relocatability))) -eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)))) -eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields)))) -eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source)))) -eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype)))) -eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = - $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)))) +eval(Core, quote + ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable + GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest)) + LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing)) + function LineNumberNode(l::Int, @nospecialize(f)) + isa(f, String) && (f = Symbol(f)) + return $(Expr(:new, :LineNumberNode, :l, :f)) + end + LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int, inlined_at::Int) = + $(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at)) + GlobalRef(m::Module, s::Symbol) = $(Expr(:new, :GlobalRef, :m, :s)) + SlotNumber(n::Int) = $(Expr(:new, :SlotNumber, :n)) + TypedSlot(n::Int, @nospecialize(t)) = $(Expr(:new, :TypedSlot, :n, :t)) + PhiNode(edges::Array{Int32, 1}, values::Array{Any, 1}) = $(Expr(:new, :PhiNode, :edges, :values)) + PiNode(@nospecialize(val), @nospecialize(typ)) = $(Expr(:new, :PiNode, :val, :typ)) + PhiCNode(values::Array{Any, 1}) = $(Expr(:new, :PhiCNode, :values)) + UpsilonNode(@nospecialize(val)) = $(Expr(:new, :UpsilonNode, :val)) + UpsilonNode() = $(Expr(:new, :UpsilonNode)) + function CodeInstance( + mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const), + @nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt, + ipo_effects::UInt8, effects::UInt8, 
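+        # `argescapes` is the new field that caches the interprocedural escape
+        # information (`ArgEscapeCache`) computed by EA on the call arguments,
+        # so that it can be loaded back at cached callsites (see `get_escape_cache`)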
@nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#), + relocatability::UInt8) + return ccall(:jl_new_codeinst, Ref{CodeInstance}, + (Any, Any, Any, Any, Int32, UInt, UInt, UInt8, UInt8, Any, UInt8), + mi, rettype, inferred_const, inferred, const_flags, min_world, max_world, ipo_effects, effects, argescapes, relocatability) + end + Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)) + PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields)) + PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source)) + InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype)) + MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)) +end) Module(name::Symbol=:anonymous, std_imports::Bool=true, default_names::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool, Bool), name, std_imports, default_names) diff --git a/base/compiler/bootstrap.jl b/base/compiler/bootstrap.jl index 2517b181d2804..1989d8aa57393 100644 --- a/base/compiler/bootstrap.jl +++ b/base/compiler/bootstrap.jl @@ -11,10 +11,11 @@ let world = get_world_counter() interp = NativeInterpreter(world) + analyze_escapes_tt = Tuple{typeof(analyze_escapes), IRCode, Int, Bool, typeof(get_escape_cache(code_cache(interp)))} fs = Any[ # we first create caches for the optimizer, because they contain many loop constructions # and they're better to not run in interpreter even during bootstrapping - run_passes, + analyze_escapes_tt, run_passes, # then we create caches for inference entries typeinf_ext, typeinf, typeinf_edge, ] @@ -32,7 +33,12 @@ let end starttime = time() for f in fs - for m in _methods_by_ftype(Tuple{typeof(f), Vararg{Any}}, 10, typemax(UInt)) + if isa(f, DataType) && f.name === typename(Tuple) + tt = f + else + tt = Tuple{typeof(f), Vararg{Any}} + end + for m in _methods_by_ftype(tt, 10, typemax(UInt)) # remove any TypeVars from the intersection typ = Any[m.spec_types.parameters...] for i = 1:length(typ) diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 41e045773fb06..d13fb9e21b483 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -97,6 +97,8 @@ ntuple(f, n) = (Any[f(i) for i = 1:n]...,) # core docsystem include("docs/core.jl") +import Core.Compiler.CoreDocs +Core.atdoc!(CoreDocs.docm) # sorting function sort end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 58f20b5ef2a0c..635e53a9e1f1d 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -1,5 +1,35 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +############# +# constants # +############# + +# The slot has uses that are not statically dominated by any assignment +# This is implied by `SLOT_USEDUNDEF`. +# If this is not set, all the uses are (statically) dominated by the defs. +# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. 
+const SLOT_STATICUNDEF = 1 # slot might be used before it is defined (structurally) +const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once +const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError +# const SLOT_CALLED = 64 + +# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c + +const IR_FLAG_NULL = 0x00 +# This statement is marked as @inbounds by user. +# Ff replaced by inlining, any contained boundschecks may be removed. +const IR_FLAG_INBOUNDS = 0x01 << 0 +# This statement is marked as @inline by user +const IR_FLAG_INLINE = 0x01 << 1 +# This statement is marked as @noinline by user +const IR_FLAG_NOINLINE = 0x01 << 2 +const IR_FLAG_THROW_BLOCK = 0x01 << 3 +# This statement may be removed if its result is unused. In particular it must +# thus be both pure and effect free. +const IR_FLAG_EFFECT_FREE = 0x01 << 4 + +const TOP_TUPLE = GlobalRef(Core, :tuple) + ##################### # OptimizationState # ##################### @@ -21,10 +51,10 @@ function push!(et::EdgeTracker, ci::CodeInstance) push!(et, ci.def) end -struct InliningState{S <: Union{EdgeTracker, Nothing}, T, I<:AbstractInterpreter} +struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInterpreter} params::OptimizationParams et::S - mi_cache::T + mi_cache::MICache # TODO move this to `OptimizationState` (as used by EscapeAnalysis as well) interp::I end @@ -52,7 +82,34 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f return nothing end +function argextype end # imported by EscapeAnalysis +function stmt_effect_free end # imported by EscapeAnalysis +function alloc_array_ndims end # imported by EscapeAnalysis include("compiler/ssair/driver.jl") +using .EscapeAnalysis +import .EscapeAnalysis: EscapeState, ArgEscapeCache, is_ipo_profitable + +""" + cache_escapes!(caller::InferenceResult, estate::EscapeState) + +Transforms escape information of call arguments of `caller`, +and then caches it into a global cache for later interprocedural propagation. +""" +cache_escapes!(caller::InferenceResult, estate::EscapeState) = + caller.argescapes = ArgEscapeCache(estate) + +function get_escape_cache(mi_cache::MICache) where MICache + return function (linfo::Union{InferenceResult,MethodInstance}) + if isa(linfo, InferenceResult) + argescapes = linfo.argescapes + else + codeinst = get(mi_cache, linfo, nothing) + isa(codeinst, CodeInstance) || return nothing + argescapes = codeinst.argescapes + end + return argescapes !== nothing ? argescapes::ArgEscapeCache : nothing + end +end mutable struct OptimizationState linfo::MethodInstance @@ -121,36 +178,6 @@ function ir_to_codeinf!(opt::OptimizationState) return src end -############# -# constants # -############# - -# The slot has uses that are not statically dominated by any assignment -# This is implied by `SLOT_USEDUNDEF`. -# If this is not set, all the uses are (statically) dominated by the defs. -# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. -const SLOT_STATICUNDEF = 1 # slot might be used before it is defined (structurally) -const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once -const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError -# const SLOT_CALLED = 64 - -# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c - -const IR_FLAG_NULL = 0x00 -# This statement is marked as @inbounds by user. -# Ff replaced by inlining, any contained boundschecks may be removed. 
-const IR_FLAG_INBOUNDS = 0x01 << 0 -# This statement is marked as @inline by user -const IR_FLAG_INLINE = 0x01 << 1 -# This statement is marked as @noinline by user -const IR_FLAG_NOINLINE = 0x01 << 2 -const IR_FLAG_THROW_BLOCK = 0x01 << 3 -# This statement may be removed if its result is unused. In particular it must -# thus be both pure and effect free. -const IR_FLAG_EFFECT_FREE = 0x01 << 4 - -const TOP_TUPLE = GlobalRef(Core, :tuple) - ######### # logic # ######### @@ -503,15 +530,23 @@ end # run the optimization work function optimize(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, caller::InferenceResult) - @timeit "optimizer" ir = run_passes(opt.src, opt) + @timeit "optimizer" ir = run_passes(opt.src, opt, caller) return finish(interp, opt, params, ir, caller) end -function run_passes(ci::CodeInfo, sv::OptimizationState) +function run_passes(ci::CodeInfo, sv::OptimizationState, caller::InferenceResult) @timeit "convert" ir = convert_to_ircode(ci, sv) @timeit "slot2reg" ir = slot2reg(ir, ci, sv) # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) + nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end + get_escape_cache = (@__MODULE__).get_escape_cache(sv.inlining.mi_cache) + if is_ipo_profitable(ir, nargs) + @timeit "IPO EA" begin + state = analyze_escapes(ir, nargs, false, get_escape_cache) + cache_escapes!(caller, state) + end + end @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl new file mode 100644 index 0000000000000..0cb34e76c36bb --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -0,0 +1,1913 @@ +baremodule EscapeAnalysis + +export + analyze_escapes, + getaliases, + isaliased, + has_no_escape, + has_arg_escape, + has_return_escape, + has_thrown_escape, + has_all_escape + +const _TOP_MOD = ccall(:jl_base_relative_to, Any, (Any,), EscapeAnalysis)::Module + +# imports +import ._TOP_MOD: ==, getindex, setindex! +# usings +import Core: + MethodInstance, Const, Argument, SSAValue, PiNode, PhiNode, UpsilonNode, PhiCNode, + ReturnNode, GotoNode, GotoIfNot, SimpleVector, MethodMatch, CodeInstance, + sizeof, ifelse, arrayset, arrayref, arraysize +import ._TOP_MOD: # Base definitions + @__MODULE__, @eval, @assert, @specialize, @nospecialize, @inbounds, @inline, @noinline, + @label, @goto, !, !==, !=, ≠, +, -, *, ≤, <, ≥, >, &, |, <<, error, missing, copy, + Vector, BitSet, IdDict, IdSet, UnitRange, Csize_t, Callable, ∪, ⊆, ∩, :, ∈, ∉, =>, + in, length, get, first, last, haskey, keys, get!, isempty, isassigned, + pop!, push!, pushfirst!, empty!, delete!, max, min, enumerate, unwrap_unionall, + ismutabletype +import Core.Compiler: # Core.Compiler specific definitions + Bottom, InferenceResult, IRCode, IR_FLAG_EFFECT_FREE, + isbitstype, isexpr, is_meta_expr_head, println, widenconst, argextype, singleton_type, + fieldcount_noerror, try_compute_field, try_compute_fieldidx, hasintersect, ⊑, + intrinsic_nothrow, array_builtin_common_typecheck, arrayset_typecheck, + setfield!_nothrow, alloc_array_ndims, stmt_effect_free, check_effect_free! 
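+
+# NOTE a rough sketch of how the optimizer drives this module (cf. `run_passes` in
+# optimize.jl), where `ir::IRCode` and `nargs` come from the method under optimization:
+#
+#     estate = analyze_escapes(ir, nargs, #=call_resolved=#false, get_escape_cache)
+#     estate[SSAValue(pc)] # => `EscapeInfo` imposed on the value of statement `pc`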
+
+include(x) = _TOP_MOD.include(@__MODULE__, x)
+if _TOP_MOD === Core.Compiler
+    include("compiler/ssair/EscapeAnalysis/disjoint_set.jl")
+else
+    include("disjoint_set.jl")
+end
+
+const AInfo = IdSet{Any}
+const LivenessSet = BitSet
+
+"""
+    x::EscapeInfo
+
+A lattice for escape information, which holds the following properties:
+- `x.Analyzed::Bool`: not formally part of the lattice; only indicates whether `x` has been analyzed
+- `x.ReturnEscape::Bool`: indicates `x` can escape to the caller via return
+- `x.ThrownEscape::BitSet`: records SSA statement numbers where `x` can be thrown as an exception:
+  * `isempty(x.ThrownEscape)`: `x` will never be thrown in this call frame (the bottom)
+  * `pc ∈ x.ThrownEscape`: `x` may be thrown at the SSA statement at `pc`
+  * `-1 ∈ x.ThrownEscape`: `x` may be thrown at arbitrary points of this call frame (the top)
+  This information will be used by `escape_exception!` to propagate potential escapes via exceptions.
+- `x.AliasInfo::Union{Bool,IndexableFields,IndexableElements,Unindexable}`: maintains all possible values
+  that can be aliased to fields or array elements of `x`:
+  * `x.AliasInfo === false` indicates the fields/elements of `x` aren't analyzed yet
+  * `x.AliasInfo === true` indicates the fields/elements of `x` can't be analyzed,
+    e.g. the type of `x` is not known or is not concrete and thus its fields/elements
+    can't be known precisely
+  * `x.AliasInfo::IndexableFields` records all the possible values that can be aliased to fields of object `x` with precise index information
+  * `x.AliasInfo::IndexableElements` records all the possible values that can be aliased to elements of array `x` with precise index information
+  * `x.AliasInfo::Unindexable` records all the possible values that can be aliased to fields/elements of `x` without precise index information
+- `x.Liveness::BitSet`: records SSA statement numbers where `x` should be live, e.g.
+  to be used as a call argument, to be returned to a caller, or preserved for `:foreigncall`:
+  * `isempty(x.Liveness)`: `x` is never used in this call frame (the bottom)
+  * `0 ∈ x.Liveness` has the special meaning that `x` is a call argument of the currently
+    analyzed call frame (and thus visible from the caller immediately)
+  * `pc ∈ x.Liveness`: `x` may be used at the SSA statement at `pc`
+  * `-1 ∈ x.Liveness`: `x` may be used at arbitrary points of this call frame (the top)
+
+There are utility constructors to create common `EscapeInfo`s, e.g.,
+- `NoEscape()`: the bottom(-like) element of this lattice, meaning it won't escape anywhere
+- `AllEscape()`: the topmost element of this lattice, meaning it will escape everywhere
+
+`analyze_escapes` will transition these elements from the bottom to the top,
+in the same direction as Julia's native type inference routine.
+An abstract state will be initialized with the bottom(-like) elements: +- the call arguments are initialized as `ArgEscape()`, whose `Liveness` property includes `0` + to indicate that it is passed as a call argument and visible from a caller immediately +- the other states are initialized as `NotAnalyzed()`, which is a special lattice element that + is slightly lower than `NoEscape`, but at the same time doesn't represent any meaning + other than it's not analyzed yet (thus it's not formally part of the lattice) +""" +struct EscapeInfo + Analyzed::Bool + ReturnEscape::Bool + ThrownEscape::LivenessSet + AliasInfo #::Union{IndexableFields,IndexableElements,Unindexable,Bool} + Liveness::LivenessSet + + function EscapeInfo( + Analyzed::Bool, + ReturnEscape::Bool, + ThrownEscape::LivenessSet, + AliasInfo#=::Union{IndexableFields,IndexableElements,Unindexable,Bool}=#, + Liveness::LivenessSet, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end + function EscapeInfo( + x::EscapeInfo, + # non-concrete fields should be passed as default arguments + # in order to avoid allocating non-concrete `NamedTuple`s + AliasInfo#=::Union{IndexableFields,IndexableElements,Unindexable,Bool}=# = x.AliasInfo; + Analyzed::Bool = x.Analyzed, + ReturnEscape::Bool = x.ReturnEscape, + ThrownEscape::LivenessSet = x.ThrownEscape, + Liveness::LivenessSet = x.Liveness, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end +end + +# precomputed default values in order to eliminate computations at each callsite + +const BOT_THROWN_ESCAPE = LivenessSet() +# NOTE the lattice operations should try to avoid actual set computations on this top value, +# and e.g. LivenessSet(0:1000000) should also work without incurring excessive computations +const TOP_THROWN_ESCAPE = LivenessSet(-1) + +const BOT_LIVENESS = LivenessSet() +# NOTE the lattice operations should try to avoid actual set computations on this top value, +# and e.g. 
LivenessSet(0:1000000) should also work without incurring excessive computations +const TOP_LIVENESS = LivenessSet(-1:0) +const ARG_LIVENESS = LivenessSet(0) + +# the constructors +NotAnalyzed() = EscapeInfo(false, false, BOT_THROWN_ESCAPE, false, BOT_LIVENESS) # not formally part of the lattice +NoEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, false, BOT_LIVENESS) +ArgEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, true, ARG_LIVENESS) +ReturnEscape(pc::Int) = EscapeInfo(true, true, BOT_THROWN_ESCAPE, false, LivenessSet(pc)) +AllReturnEscape() = EscapeInfo(true, true, BOT_THROWN_ESCAPE, false, TOP_LIVENESS) +ThrownEscape(pc::Int) = EscapeInfo(true, false, LivenessSet(pc), false, BOT_LIVENESS) +AllEscape() = EscapeInfo(true, true, TOP_THROWN_ESCAPE, true, TOP_LIVENESS) + +const ⊥, ⊤ = NotAnalyzed(), AllEscape() + +# Convenience names for some ⊑ₑ queries +has_no_escape(x::EscapeInfo) = !x.ReturnEscape && isempty(x.ThrownEscape) && 0 ∉ x.Liveness +has_arg_escape(x::EscapeInfo) = 0 in x.Liveness +has_return_escape(x::EscapeInfo) = x.ReturnEscape +has_return_escape(x::EscapeInfo, pc::Int) = x.ReturnEscape && (-1 ∈ x.Liveness || pc in x.Liveness) +has_thrown_escape(x::EscapeInfo) = !isempty(x.ThrownEscape) +has_thrown_escape(x::EscapeInfo, pc::Int) = -1 ∈ x.ThrownEscape || pc in x.ThrownEscape +has_all_escape(x::EscapeInfo) = ⊤ ⊑ₑ x + +# utility lattice constructors +ignore_argescape(x::EscapeInfo) = EscapeInfo(x; Liveness=delete!(copy(x.Liveness), 0)) +ignore_thrownescapes(x::EscapeInfo) = EscapeInfo(x; ThrownEscape=BOT_THROWN_ESCAPE) +ignore_aliasinfo(x::EscapeInfo) = EscapeInfo(x, false) +ignore_liveness(x::EscapeInfo) = EscapeInfo(x; Liveness=BOT_LIVENESS) + +# AliasInfo +struct IndexableFields + infos::Vector{AInfo} +end +struct IndexableElements + infos::IdDict{Int,AInfo} +end +struct Unindexable + info::AInfo +end +IndexableFields(nflds::Int) = IndexableFields(AInfo[AInfo() for _ in 1:nflds]) +Unindexable() = Unindexable(AInfo()) + +merge_to_unindexable(AliasInfo::IndexableFields) = Unindexable(merge_to_unindexable(AliasInfo.infos)) +merge_to_unindexable(AliasInfo::Unindexable, AliasInfos::IndexableFields) = Unindexable(merge_to_unindexable(AliasInfo.info, AliasInfos.infos)) +merge_to_unindexable(infos::Vector{AInfo}) = merge_to_unindexable(AInfo(), infos) +function merge_to_unindexable(info::AInfo, infos::Vector{AInfo}) + for i = 1:length(infos) + info = info ∪ infos[i] + end + return info +end +merge_to_unindexable(AliasInfo::IndexableElements) = Unindexable(merge_to_unindexable(AliasInfo.infos)) +merge_to_unindexable(AliasInfo::Unindexable, AliasInfos::IndexableElements) = Unindexable(merge_to_unindexable(AliasInfo.info, AliasInfos.infos)) +merge_to_unindexable(infos::IdDict{Int,AInfo}) = merge_to_unindexable(AInfo(), infos) +function merge_to_unindexable(info::AInfo, infos::IdDict{Int,AInfo}) + for idx in keys(infos) + info = info ∪ infos[idx] + end + return info +end + +# we need to make sure this `==` operator corresponds to lattice equality rather than object equality, +# otherwise `propagate_changes` can't detect the convergence +x::EscapeInfo == y::EscapeInfo = begin + # fast pass: better to avoid top comparison + x === y && return true + x.Analyzed === y.Analyzed || return false + x.ReturnEscape === y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt === TOP_THROWN_ESCAPE || return false + elseif yt === TOP_THROWN_ESCAPE + return false # x.ThrownEscape === TOP_THROWN_ESCAPE + else + xt == yt || return false + 
end + xa, ya = x.AliasInfo, y.AliasInfo + if isa(xa, Bool) + xa === ya || return false + elseif isa(xa, IndexableFields) + isa(ya, IndexableFields) || return false + xa.infos == ya.infos || return false + elseif isa(xa, IndexableElements) + isa(ya, IndexableElements) || return false + xa.infos == ya.infos || return false + else + xa = xa::Unindexable + isa(ya, Unindexable) || return false + xa.info == ya.info || return false + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS + yl === TOP_LIVENESS || return false + elseif yl === TOP_LIVENESS + return false # x.Liveness === TOP_LIVENESS + else + xl == yl || return false + end + return true +end + +""" + x::EscapeInfo ⊑ₑ y::EscapeInfo -> Bool + +The non-strict partial order over `EscapeInfo`. +""" +x::EscapeInfo ⊑ₑ y::EscapeInfo = begin + # fast pass: better to avoid top comparison + if y === ⊤ + return true + elseif x === ⊤ + return false # return y === ⊤ + elseif x === ⊥ + return true + elseif y === ⊥ + return false # return x === ⊥ + end + x.Analyzed ≤ y.Analyzed || return false + x.ReturnEscape ≤ y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt !== TOP_THROWN_ESCAPE && return false + elseif yt !== TOP_THROWN_ESCAPE + xt ⊆ yt || return false + end + xa, ya = x.AliasInfo, y.AliasInfo + if isa(xa, Bool) + xa && ya !== true && return false + elseif isa(xa, IndexableFields) + if isa(ya, IndexableFields) + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + xn > yn && return false + for i in 1:xn + xinfos[i] ⊆ yinfos[i] || return false + end + elseif isa(ya, IndexableElements) + return false + elseif isa(ya, Unindexable) + xinfos, yinfo = xa.infos, ya.info + for i = length(xinfos) + xinfos[i] ⊆ yinfo || return false + end + else + ya === true || return false + end + elseif isa(xa, IndexableElements) + if isa(ya, IndexableElements) + xinfos, yinfos = xa.infos, ya.infos + keys(xinfos) ⊆ keys(yinfos) || return false + for idx in keys(xinfos) + xinfos[idx] ⊆ yinfos[idx] || return false + end + elseif isa(ya, IndexableFields) + return false + elseif isa(ya, Unindexable) + xinfos, yinfo = xa.infos, ya.info + for idx in keys(xinfos) + xinfos[idx] ⊆ yinfo || return false + end + else + ya === true || return false + end + else + xa = xa::Unindexable + if isa(ya, Unindexable) + xinfo, yinfo = xa.info, ya.info + xinfo ⊆ yinfo || return false + else + ya === true || return false + end + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS + yl !== TOP_LIVENESS && return false + elseif yl !== TOP_LIVENESS + xl ⊆ yl || return false + end + return true +end + +""" + x::EscapeInfo ⊏ₑ y::EscapeInfo -> Bool + +The strict partial order over `EscapeInfo`. +This is defined as the irreflexive kernel of `⊏ₑ`. +""" +x::EscapeInfo ⊏ₑ y::EscapeInfo = x ⊑ₑ y && !(y ⊑ₑ x) + +""" + x::EscapeInfo ⋤ₑ y::EscapeInfo -> Bool + +This order could be used as a slightly more efficient version of the strict order `⊏ₑ`, +where we can safely assume `x ⊑ₑ y` holds. +""" +x::EscapeInfo ⋤ₑ y::EscapeInfo = !(y ⊑ₑ x) + +""" + x::EscapeInfo ⊔ₑ y::EscapeInfo -> EscapeInfo + +Computes the join of `x` and `y` in the partial order defined by `EscapeInfo`. 
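+
+For example, joining `ReturnEscape(1)` with `ThrownEscape(2)` yields an `EscapeInfo`
+whose `ReturnEscape` is `true`, `ThrownEscape` is `{2}`, and `Liveness` is `{1}`.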
+""" +x::EscapeInfo ⊔ₑ y::EscapeInfo = begin + # fast pass: better to avoid top join + if x === ⊤ || y === ⊤ + return ⊤ + elseif x === ⊥ + return y + elseif y === ⊥ + return x + end + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE || yt === TOP_THROWN_ESCAPE + ThrownEscape = TOP_THROWN_ESCAPE + elseif xt === BOT_THROWN_ESCAPE + ThrownEscape = yt + elseif yt === BOT_THROWN_ESCAPE + ThrownEscape = xt + else + ThrownEscape = xt ∪ yt + end + AliasInfo = merge_alias_info(x.AliasInfo, y.AliasInfo) + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS || yl === TOP_LIVENESS + Liveness = TOP_LIVENESS + elseif xl === BOT_LIVENESS + Liveness = yl + elseif yl === BOT_LIVENESS + Liveness = xl + else + Liveness = xl ∪ yl + end + return EscapeInfo( + x.Analyzed | y.Analyzed, + x.ReturnEscape | y.ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) +end + +function merge_alias_info(@nospecialize(xa), @nospecialize(ya)) + if xa === true || ya === true + return true + elseif xa === false + return ya + elseif ya === false + return xa + elseif isa(xa, IndexableFields) + if isa(ya, IndexableFields) + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + nmax, nmin = max(xn, yn), min(xn, yn) + infos = Vector{AInfo}(undef, nmax) + for i in 1:nmax + if i > nmin + infos[i] = (xn > yn ? xinfos : yinfos)[i] + else + infos[i] = xinfos[i] ∪ yinfos[i] + end + end + return IndexableFields(infos) + elseif isa(ya, Unindexable) + xinfos, yinfo = xa.infos, ya.info + return merge_to_unindexable(ya, xa) + else + return true # handle conflicting case conservatively + end + elseif isa(xa, IndexableElements) + if isa(ya, IndexableElements) + xinfos, yinfos = xa.infos, ya.infos + infos = IdDict{Int,AInfo}() + for idx in keys(xinfos) + if !haskey(yinfos, idx) + infos[idx] = xinfos[idx] + else + infos[idx] = xinfos[idx] ∪ yinfos[idx] + end + end + for idx in keys(yinfos) + haskey(xinfos, idx) && continue # unioned already + infos[idx] = yinfos[idx] + end + return IndexableElements(infos) + elseif isa(ya, Unindexable) + return merge_to_unindexable(ya, xa) + else + return true # handle conflicting case conservatively + end + else + xa = xa::Unindexable + if isa(ya, IndexableFields) + return merge_to_unindexable(xa, ya) + elseif isa(ya, IndexableElements) + return merge_to_unindexable(xa, ya) + else + ya = ya::Unindexable + xinfo, yinfo = xa.info, ya.info + info = xinfo ∪ yinfo + return Unindexable(info) + end + end +end + +const AliasSet = IntDisjointSet{Int} + +const ArrayInfo = IdDict{Int,Vector{Int}} + +""" + estate::EscapeState + +Extended lattice that maps arguments and SSA values to escape information represented as [`EscapeInfo`](@ref). +Escape information imposed on SSA IR element `x` can be retrieved by `estate[x]`. +""" +struct EscapeState + escapes::Vector{EscapeInfo} + aliasset::AliasSet + nargs::Int + arrayinfo::Union{Nothing,ArrayInfo} +end +function EscapeState(nargs::Int, nstmts::Int, arrayinfo::Union{Nothing,ArrayInfo}) + escapes = EscapeInfo[ + 1 ≤ i ≤ nargs ? ArgEscape() : ⊥ for i in 1:(nargs+nstmts)] + aliasset = AliasSet(nargs+nstmts) + return EscapeState(escapes, aliasset, nargs, arrayinfo) +end +function getindex(estate::EscapeState, @nospecialize(x)) + xidx = iridx(x, estate) + return xidx === nothing ? 
nothing : estate.escapes[xidx] +end +function setindex!(estate::EscapeState, v::EscapeInfo, @nospecialize(x)) + xidx = iridx(x, estate) + if xidx !== nothing + estate.escapes[xidx] = v + end + return estate +end + +""" + iridx(x, estate::EscapeState) -> xidx::Union{Int,Nothing} + +Tries to convert analyzable IR element `x::Union{Argument,SSAValue}` to +its unique identifier number `xidx` that is valid in the analysis context of `estate`. +Returns `nothing` if `x` isn't maintained by `estate` and thus unanalyzable (e.g. `x::GlobalRef`). + +`irval` is the inverse function of `iridx` (not formally), i.e. +`irval(iridx(x::Union{Argument,SSAValue}, state), state) === x`. +""" +function iridx(@nospecialize(x), estate::EscapeState) + if isa(x, Argument) + xidx = x.n + @assert 1 ≤ xidx ≤ estate.nargs "invalid Argument" + elseif isa(x, SSAValue) + xidx = x.id + estate.nargs + else + return nothing + end + return xidx +end + +""" + irval(xidx::Int, estate::EscapeState) -> x::Union{Argument,SSAValue} + +Converts its unique identifier number `xidx` to the original IR element `x::Union{Argument,SSAValue}` +that is analyzable in the context of `estate`. + +`iridx` is the inverse function of `irval` (not formally), i.e. +`iridx(irval(xidx, state), state) === xidx`. +""" +function irval(xidx::Int, estate::EscapeState) + x = xidx > estate.nargs ? SSAValue(xidx-estate.nargs) : Argument(xidx) + return x +end + +function getaliases(x::Union{Argument,SSAValue}, estate::EscapeState) + xidx = iridx(x, estate) + aliases = getaliases(xidx, estate) + aliases === nothing && return nothing + return Union{Argument,SSAValue}[irval(aidx, estate) for aidx in aliases] +end +function getaliases(xidx::Int, estate::EscapeState) + aliasset = estate.aliasset + root = find_root!(aliasset, xidx) + if xidx ≠ root || aliasset.ranks[xidx] > 0 + # the size of this alias set containing `key` is larger than 1, + # collect the entire alias set + aliases = Int[] + for aidx in 1:length(aliasset.parents) + if aliasset.parents[aidx] == root + push!(aliases, aidx) + end + end + return aliases + else + return nothing + end +end + +isaliased(x::Union{Argument,SSAValue}, y::Union{Argument,SSAValue}, estate::EscapeState) = + isaliased(iridx(x, estate), iridx(y, estate), estate) +isaliased(xidx::Int, yidx::Int, estate::EscapeState) = + in_same_set(estate.aliasset, xidx, yidx) + +struct ArgEscapeInfo + EscapeBits::UInt8 +end +function ArgEscapeInfo(x::EscapeInfo) + x === ⊤ && return ArgEscapeInfo(ARG_ALL_ESCAPE) + EscapeBits = 0x00 + has_return_escape(x) && (EscapeBits |= ARG_RETURN_ESCAPE) + has_thrown_escape(x) && (EscapeBits |= ARG_THROWN_ESCAPE) + return ArgEscapeInfo(EscapeBits) +end + +const ARG_ALL_ESCAPE = 0x01 << 0 +const ARG_RETURN_ESCAPE = 0x01 << 1 +const ARG_THROWN_ESCAPE = 0x01 << 2 + +has_no_escape(x::ArgEscapeInfo) = !has_all_escape(x) && !has_return_escape(x) && !has_thrown_escape(x) +has_all_escape(x::ArgEscapeInfo) = x.EscapeBits & ARG_ALL_ESCAPE ≠ 0 +has_return_escape(x::ArgEscapeInfo) = x.EscapeBits & ARG_RETURN_ESCAPE ≠ 0 +has_thrown_escape(x::ArgEscapeInfo) = x.EscapeBits & ARG_THROWN_ESCAPE ≠ 0 + +struct ArgAliasing + aidx::Int + bidx::Int +end + +struct ArgEscapeCache + argescapes::Vector{ArgEscapeInfo} + argaliases::Vector{ArgAliasing} +end + +function ArgEscapeCache(estate::EscapeState) + nargs = estate.nargs + argescapes = Vector{ArgEscapeInfo}(undef, nargs) + argaliases = ArgAliasing[] + for i = 1:nargs + info = estate.escapes[i] + @assert info.AliasInfo === true + argescapes[i] = ArgEscapeInfo(info) + for j = 
(i+1):nargs + if isaliased(i, j, estate) + push!(argaliases, ArgAliasing(i, j)) + end + end + end + return ArgEscapeCache(argescapes, argaliases) +end + +""" + is_ipo_profitable(ir::IRCode, nargs::Int) -> Bool + +Heuristically checks if there is any profitability to run the escape analysis on `ir` +and generate IPO escape information cache. Specifically, this function examines +if any call argument is "interesting" in terms of their escapability. +""" +function is_ipo_profitable(ir::IRCode, nargs::Int) + for i = 1:nargs + t = unwrap_unionall(widenconst(ir.argtypes[i])) + t <: IO && return false # bail out IO-related functions + is_ipo_profitable_type(t) && return true + end + return false +end +function is_ipo_profitable_type(@nospecialize t) + if isa(t, Union) + return is_ipo_profitable_type(t.a) && is_ipo_profitable_type(t.b) + end + (t === String || t === Symbol || t === Module || t === SimpleVector) && return false + return ismutabletype(t) +end + +abstract type Change end +struct EscapeChange <: Change + xidx::Int + xinfo::EscapeInfo +end +struct AliasChange <: Change + xidx::Int + yidx::Int +end +struct ArgAliasChange <: Change + xidx::Int + yidx::Int +end +struct LivenessChange <: Change + xidx::Int + livepc::Int +end +const Changes = Vector{Change} + +struct AnalysisState{T<:Callable} + ir::IRCode + estate::EscapeState + changes::Changes + get_escape_cache::T +end + +function getinst(ir::IRCode, idx::Int) + nstmts = length(ir.stmts) + if idx ≤ nstmts + return ir.stmts[idx] + else + return ir.new_nodes.stmts[idx - nstmts] + end +end + +""" + analyze_escapes(ir::IRCode, nargs::Int, call_resolved::Bool, get_escape_cache::Callable) + -> estate::EscapeState + +Analyzes escape information in `ir`: +- `nargs`: the number of actual arguments of the analyzed call +- `call_resolved`: if interprocedural calls are already resolved by `ssa_inlining_pass!` +- `get_escape_cache(::Union{InferenceResult,MethodInstance}) -> Union{Nothing,ArgEscapeCache}`: + retrieves cached argument escape information +""" +function analyze_escapes(ir::IRCode, nargs::Int, call_resolved::Bool, get_escape_cache::T) where T<:Callable + stmts = ir.stmts + nstmts = length(stmts) + length(ir.new_nodes.stmts) + + tryregions, arrayinfo, callinfo = compute_frameinfo(ir, call_resolved) + estate = EscapeState(nargs, nstmts, arrayinfo) + changes = Changes() # keeps changes that happen at current statement + astate = AnalysisState(ir, estate, changes, get_escape_cache) + + local debug_itr_counter = 0 + while true + local anyupdate = false + + for pc in nstmts:-1:1 + stmt = getinst(ir, pc)[:inst] + + # collect escape information + if isa(stmt, Expr) + head = stmt.head + if head === :call + if callinfo !== nothing + escape_call!(astate, pc, stmt.args, callinfo) + else + escape_call!(astate, pc, stmt.args) + end + elseif head === :invoke + escape_invoke!(astate, pc, stmt.args) + elseif head === :new || head === :splatnew + escape_new!(astate, pc, stmt.args) + elseif head === :(=) + lhs, rhs = stmt.args + if isa(lhs, GlobalRef) # global store + add_escape_change!(astate, rhs, ⊤) + else + unexpected_assignment!(ir, pc) + end + elseif head === :foreigncall + escape_foreigncall!(astate, pc, stmt.args) + elseif head === :throw_undef_if_not # XXX when is this expression inserted ? 
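+                # conservatively taint the checked value with `ThrownEscape`, since
+                # this expression may raise an `UndefVarError` involving it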
+ add_escape_change!(astate, stmt.args[1], ThrownEscape(pc)) + elseif is_meta_expr_head(head) + # meta expressions doesn't account for any usages + continue + elseif head === :enter || head === :leave || head === :the_exception || head === :pop_exception + # ignore these expressions since escapes via exceptions are handled by `escape_exception!` + # `escape_exception!` conservatively propagates `AllEscape` anyway, + # and so escape information imposed on `:the_exception` isn't computed + continue + elseif head === :static_parameter || # this exists statically, not interested in its escape + head === :copyast || # XXX can this account for some escapes? + head === :undefcheck || # XXX can this account for some escapes? + head === :isdefined || # just returns `Bool`, nothing accounts for any escapes + head === :gc_preserve_begin || # `GC.@preserve` expressions themselves won't be used anywhere + head === :gc_preserve_end # `GC.@preserve` expressions themselves won't be used anywhere + continue + else + add_conservative_changes!(astate, pc, stmt.args) + end + elseif isa(stmt, ReturnNode) + if isdefined(stmt, :val) + add_escape_change!(astate, stmt.val, ReturnEscape(pc)) + end + elseif isa(stmt, PhiNode) + escape_edges!(astate, pc, stmt.values) + elseif isa(stmt, PiNode) + escape_val_ifdefined!(astate, pc, stmt) + elseif isa(stmt, PhiCNode) + escape_edges!(astate, pc, stmt.values) + elseif isa(stmt, UpsilonNode) + escape_val_ifdefined!(astate, pc, stmt) + elseif isa(stmt, GlobalRef) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + elseif isa(stmt, SSAValue) + escape_val!(astate, pc, stmt) + elseif isa(stmt, Argument) + escape_val!(astate, pc, stmt) + else # otherwise `stmt` can be GotoNode, GotoIfNot, and inlined values etc. + continue + end + + isempty(changes) && continue + + anyupdate |= propagate_changes!(estate, changes) + + empty!(changes) + end + + tryregions !== nothing && escape_exception!(astate, tryregions) + + debug_itr_counter += 1 + + anyupdate || break + end + + # if debug_itr_counter > 2 + # println("[EA] excessive iteration count found ", debug_itr_counter, " (", singleton_type(ir.argtypes[1]), ")") + # end + + return estate +end + +""" + compute_frameinfo(ir::IRCode, call_resolved::Bool) -> (tryregions, arrayinfo, callinfo) + +A preparatory linear scan before the escape analysis on `ir` to find: +- `tryregions::Union{Nothing,Vector{UnitRange{Int}}}`: regions in which potential `throw`s can be caught (used by `escape_exception!`) +- `arrayinfo::Union{Nothing,IdDict{Int,Vector{Int}}}`: array allocations whose dimensions are known precisely (with some very simple local analysis) +- `callinfo::`: when `!call_resolved`, `compute_frameinfo` additionally returns `callinfo::Vector{Union{MethodInstance,InferenceResult}}`, + which contains information about statically resolved callsites. + The inliner will use essentially equivalent interprocedural information to inline callees as well as resolve static callsites, + this additional information won't be required when analyzing post-inlining IR. + +!!! note + This array dimension analysis to compute `arrayinfo` is very local and doesn't account + for flow-sensitivity nor complex aliasing. + Ideally this dimension analysis should be done as a part of type inference that + propagates array dimenstions in a flow sensitive way. 
+""" +function compute_frameinfo(ir::IRCode, call_resolved::Bool) + nstmts, nnewnodes = length(ir.stmts), length(ir.new_nodes.stmts) + tryregions, arrayinfo = nothing, nothing + if !call_resolved + callinfo = Vector{Any}(undef, nstmts+nnewnodes) + else + callinfo = nothing + end + for idx in 1:nstmts+nnewnodes + inst = getinst(ir, idx) + stmt = inst[:inst] + if !call_resolved + # TODO don't call `check_effect_free!` in the inlinear + check_effect_free!(ir, idx, stmt, inst[:type]) + end + if callinfo !== nothing && isexpr(stmt, :call) + callinfo[idx] = resolve_call(ir, stmt, inst[:info]) + elseif isexpr(stmt, :enter) + @assert idx ≤ nstmts "try/catch inside new_nodes unsupported" + tryregions === nothing && (tryregions = UnitRange{Int}[]) + leave_block = stmt.args[1]::Int + leave_pc = first(ir.cfg.blocks[leave_block].stmts) + push!(tryregions, idx:leave_pc) + elseif isexpr(stmt, :foreigncall) + args = stmt.args + name = args[1] + nn = normalize(name) + isa(nn, Symbol) || @goto next_stmt + ndims = alloc_array_ndims(nn) + ndims === nothing && @goto next_stmt + if ndims ≠ 0 + length(args) ≥ ndims+6 || @goto next_stmt + dims = Int[] + for i in 1:ndims + dim = argextype(args[i+6], ir) + isa(dim, Const) || @goto next_stmt + dim = dim.val + isa(dim, Int) || @goto next_stmt + push!(dims, dim) + end + else + length(args) ≥ 7 || @goto next_stmt + dims = argextype(args[7], ir) + if isa(dims, Const) + dims = dims.val + isa(dims, Tuple{Vararg{Int}}) || @goto next_stmt + dims = collect(Int, dims) + else + dims === Tuple{} || @goto next_stmt + dims = Int[] + end + end + if arrayinfo === nothing + arrayinfo = ArrayInfo() + end + arrayinfo[idx] = dims + elseif arrayinfo !== nothing + # TODO this super limited alias analysis is able to handle only very simple cases + # this should be replaced with a proper forward dimension analysis + if isa(stmt, PhiNode) + values = stmt.values + local dims = nothing + for i = 1:length(values) + if isassigned(values, i) + val = values[i] + if isa(val, SSAValue) && haskey(arrayinfo, val.id) + if dims === nothing + dims = arrayinfo[val.id] + continue + elseif dims == arrayinfo[val.id] + continue + end + end + end + @goto next_stmt + end + if dims !== nothing + arrayinfo[idx] = dims + end + elseif isa(stmt, PiNode) + if isdefined(stmt, :val) + val = stmt.val + if isa(val, SSAValue) && haskey(arrayinfo, val.id) + arrayinfo[idx] = arrayinfo[val.id] + end + end + end + end + @label next_stmt + end + return tryregions, arrayinfo, callinfo +end + +# define resolve_call +if _TOP_MOD === Core.Compiler + include("compiler/ssair/EscapeAnalysis/interprocedural.jl") +else + include("interprocedural.jl") +end + +# propagate changes, and check convergence +function propagate_changes!(estate::EscapeState, changes::Changes) + local anychanged = false + for change in changes + if isa(change, EscapeChange) + anychanged |= propagate_escape_change!(estate, change) + elseif isa(change, LivenessChange) + anychanged |= propagate_liveness_change!(estate, change) + else + change = change::AliasChange + anychanged |= propagate_alias_change!(estate, change) + end + end + return anychanged +end + +@inline propagate_escape_change!(estate::EscapeState, change::EscapeChange) = + propagate_escape_change!(⊔ₑ, estate, change) + +# allows this to work as lattice join as well as lattice meet +@inline function propagate_escape_change!(@specialize(op), + estate::EscapeState, change::EscapeChange) + (; xidx, xinfo) = change + anychanged = _propagate_escape_change!(op, estate, xidx, xinfo) + # COMBAK is there a 
more efficient method of escape information equalization on aliasset? + aliases = getaliases(xidx, estate) + if aliases !== nothing + for aidx in aliases + anychanged |= _propagate_escape_change!(op, estate, aidx, xinfo) + end + end + return anychanged +end + +@inline function _propagate_escape_change!(@specialize(op), + estate::EscapeState, xidx::Int, info::EscapeInfo) + old = estate.escapes[xidx] + new = op(old, info) + if old ≠ new + estate.escapes[xidx] = new + return true + end + return false +end + +# propagate Liveness changes separately in order to avoid constructing too many LivenessSet +@inline function propagate_liveness_change!(estate::EscapeState, change::LivenessChange) + (; xidx, livepc) = change + info = estate.escapes[xidx] + Liveness = info.Liveness + Liveness === TOP_LIVENESS && return false + livepc in Liveness && return false + if Liveness === BOT_LIVENESS || Liveness === ARG_LIVENESS + # if this Liveness is a constant, we shouldn't modify it and propagate this change as a new EscapeInfo + Liveness = copy(Liveness) + push!(Liveness, livepc) + estate.escapes[xidx] = EscapeInfo(info; Liveness) + return true + else + # directly modify Liveness property in order to avoid excessive copies + push!(Liveness, livepc) + return true + end +end + +@inline function propagate_alias_change!(estate::EscapeState, change::AliasChange) + anychange = false + (; xidx, yidx) = change + aliasset = estate.aliasset + xroot = find_root!(aliasset, xidx) + yroot = find_root!(aliasset, yidx) + if xroot ≠ yroot + union!(aliasset, xroot, yroot) + return true + end + return false +end + +function add_escape_change!(astate::AnalysisState, @nospecialize(x), xinfo::EscapeInfo, + force::Bool = false) + xinfo === ⊥ && return nothing # performance optimization + xidx = iridx(x, astate.estate) + if xidx !== nothing + if force || !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, EscapeChange(xidx, xinfo)) + end + end + return nothing +end + +function add_liveness_change!(astate::AnalysisState, @nospecialize(x), livepc::Int) + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, LivenessChange(xidx, livepc)) + end + end + return nothing +end + +function add_alias_change!(astate::AnalysisState, @nospecialize(x), @nospecialize(y)) + if isa(x, GlobalRef) + return add_escape_change!(astate, y, ⊤) + elseif isa(y, GlobalRef) + return add_escape_change!(astate, x, ⊤) + end + estate = astate.estate + xidx = iridx(x, estate) + yidx = iridx(y, estate) + if xidx !== nothing && yidx !== nothing + if !isaliased(xidx, yidx, astate.estate) + pushfirst!(astate.changes, AliasChange(xidx, yidx)) + end + # add new escape change here so that it's shared among the expanded `aliasset` in `propagate_escape_change!` + xinfo = estate.escapes[xidx] + yinfo = estate.escapes[yidx] + add_escape_change!(astate, x, xinfo ⊔ₑ yinfo, #=force=#true) + end + return nothing +end + +struct LocalDef + idx::Int +end +struct LocalUse + idx::Int +end + +function add_alias_escapes!(astate::AnalysisState, @nospecialize(v), ainfo::AInfo) + estate = astate.estate + for x in ainfo + isa(x, LocalUse) || continue # ignore def + x = SSAValue(x.idx) # obviously this won't be true once we implement interprocedural AliasInfo + add_alias_change!(astate, v, x) + end +end + +function add_thrown_escapes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + info = ThrownEscape(pc) + for i in first_idx:last_idx + 
add_escape_change!(astate, args[i], info)
+    end
+end
+
+function add_liveness_changes!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    first_idx::Int = 1, last_idx::Int = length(args))
+    for i in first_idx:last_idx
+        arg = args[i]
+        add_liveness_change!(astate, arg, pc)
+    end
+end
+
+function add_fallback_changes!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    first_idx::Int = 1, last_idx::Int = length(args))
+    info = ThrownEscape(pc)
+    for i in first_idx:last_idx
+        arg = args[i]
+        add_escape_change!(astate, arg, info)
+        add_liveness_change!(astate, arg, pc)
+    end
+end
+
+function add_conservative_changes!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    first_idx::Int = 1, last_idx::Int = length(args))
+    for i in first_idx:last_idx
+        add_escape_change!(astate, args[i], ⊤)
+    end
+    add_escape_change!(astate, SSAValue(pc), ⊤) # it may return GlobalRef etc.
+    return nothing
+end
+
+function escape_edges!(astate::AnalysisState, pc::Int, edges::Vector{Any})
+    ret = SSAValue(pc)
+    for i in 1:length(edges)
+        if isassigned(edges, i)
+            v = edges[i]
+            add_alias_change!(astate, ret, v)
+        end
+    end
+end
+
+function escape_val_ifdefined!(astate::AnalysisState, pc::Int, x)
+    if isdefined(x, :val)
+        escape_val!(astate, pc, x.val)
+    end
+end
+
+function escape_val!(astate::AnalysisState, pc::Int, @nospecialize(val))
+    ret = SSAValue(pc)
+    add_alias_change!(astate, ret, val)
+end
+
+function escape_unanalyzable_obj!(astate::AnalysisState, @nospecialize(obj), objinfo::EscapeInfo)
+    objinfo = EscapeInfo(objinfo, true)
+    add_escape_change!(astate, obj, objinfo)
+    return objinfo
+end
+
+@noinline function unexpected_assignment!(ir::IRCode, pc::Int)
+    @eval Main (ir = $ir; pc = $pc)
+    error("unexpected assignment found: inspect `Main.ir` and `Main.pc`")
+end
+
+is_effect_free(ir::IRCode, pc::Int) = getinst(ir, pc)[:flag] & IR_FLAG_EFFECT_FREE ≠ 0
+
+# NOTE if we don't maintain the alias set that is separated from the lattice state, we can do
+# something like below: it essentially incorporates forward escape propagation in our default
+# backward propagation, and leads to inefficient convergence that requires more iterations
+# # lhs = rhs: propagate escape information of `rhs` to `lhs`
+# function escape_alias!(astate::AnalysisState, @nospecialize(lhs), @nospecialize(rhs))
+#     if isa(rhs, SSAValue) || isa(rhs, Argument)
+#         vinfo = astate.estate[rhs]
+#     else
+#         return
+#     end
+#     add_escape_change!(astate, lhs, vinfo)
+# end
+
+"""
+    escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}})
+
+Propagates escapes via exceptions that can happen in `tryregions`.
+
+Naively it seems enough to propagate escape information imposed on the `:the_exception`
+object, but actually there are several other ways to access the exception object, such as
+`Base.current_exceptions` and a manual catch of a `rethrow`n object.
+For example, escape analysis needs to account for the potential escape of the allocated
+object via the `rethrow_escape!()` call in the example below:
+```julia
+const Gx = Ref{Any}()
+@noinline function rethrow_escape!()
+    try
+        rethrow()
+    catch err
+        Gx[] = err
+    end
+end
+unsafeget(x) = isassigned(x) ?
x[] : throw(x) + +code_escapes() do + r = Ref{String}() + try + t = unsafeget(r) + catch err + t = typeof(err) # `err` (which `r` may alias to) doesn't escape here + rethrow_escape!() # `r` can escape here + end + return t +end +``` + +As indicated by the above example, it requires a global analysis in addition to a base escape +analysis to reason about all possible escapes via existing exception interfaces correctly. +For now we conservatively always propagate `AllEscape` to all potentially thrown objects, +since such an additional analysis might not be worthwhile to do given that exception handlings +and error paths usually don't need to be very performance sensitive, and optimizations of +error paths might be very ineffective anyway since they are sometimes "unoptimized" +intentionally for latency reasons. +""" +function escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}}) + estate = astate.estate + # NOTE if `:the_exception` is the only way to access the exception, we can do: + # exc = SSAValue(pc) + # excinfo = estate[exc] + excinfo = ⊤ + escapes = estate.escapes + for i in 1:length(escapes) + x = escapes[i] + xt = x.ThrownEscape + xt === TOP_THROWN_ESCAPE && @goto propagate_exception_escape # fast pass + for pc in xt + for region in tryregions + pc in region && @goto propagate_exception_escape # early break because of AllEscape + end + end + continue + @label propagate_exception_escape + xval = irval(i, estate) + add_escape_change!(astate, xval, excinfo) + end +end + +# escape statically-resolved call, i.e. `Expr(:invoke, ::MethodInstance, ...)` +escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any}) = + escape_invoke!(astate, pc, args, first(args)::MethodInstance, 2) + +function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any}, + linfo::Linfo, first_idx::Int, last_idx::Int = length(args)) + if isa(linfo, InferenceResult) + cache = astate.get_escape_cache(linfo) + linfo = linfo.linfo + else + cache = astate.get_escape_cache(linfo) + end + if cache === nothing + return add_conservative_changes!(astate, pc, args, 2) + else + cache = cache::ArgEscapeCache + end + ret = SSAValue(pc) + retinfo = astate.estate[ret] # escape information imposed on the call statement + method = linfo.def::Method + nargs = Int(method.nargs) + for (i, argidx) in enumerate(first_idx:last_idx) + arg = args[argidx] + if i > nargs + # handle isva signature + # COMBAK will this be invalid once we take alias information into account? 
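+            # e.g. given a callee `f(x, ys...)` invoked as `f(a, b, c)`, the trailing
+            # arguments `b` and `c` both correspond to the vararg `ys`, so they share
+            # the escape information cached at the last argument index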
+ i = nargs + end + arginfo = cache.argescapes[i] + info = from_interprocedural(arginfo, pc) + if has_return_escape(arginfo) + # if this argument can be "returned", in addition to propagating + # the escape information imposed on this call argument within the callee, + # we should also account for possible aliasing of this argument and the returned value + add_escape_change!(astate, arg, info) + add_alias_change!(astate, ret, arg) + else + # if this is simply passed as the call argument, we can just propagate + # the escape information imposed on this call argument within the callee + add_escape_change!(astate, arg, info) + end + end + for (; aidx, bidx) in cache.argaliases + add_alias_change!(astate, args[aidx-(first_idx-1)], args[bidx-(first_idx-1)]) + end + # we should disable the alias analysis on this newly introduced object + add_escape_change!(astate, ret, EscapeInfo(retinfo, true)) +end + +""" + from_interprocedural(arginfo::ArgEscapeInfo, pc::Int) -> x::EscapeInfo + +Reinterprets the escape information imposed on the call argument which is cached as `arginfo` +in the context of the caller frame, where `pc` is the SSA statement number of the return value. +""" +function from_interprocedural(arginfo::ArgEscapeInfo, pc::Int) + has_all_escape(arginfo) && return ⊤ + + ThrownEscape = has_thrown_escape(arginfo) ? LivenessSet(pc) : BOT_THROWN_ESCAPE + + return EscapeInfo( + #=Analyzed=#true, #=ReturnEscape=#false, ThrownEscape, + # FIXME implement interprocedural memory effect-analysis + # currently, this essentially disables the entire field analysis + # it might be okay from the SROA point of view, since we can't remove the allocation + # as far as it's passed to a callee anyway, but still we may want some field analysis + # for e.g. stack allocation or some other IPO optimizations + #=AliasInfo=#true, #=Liveness=#LivenessSet(pc)) +end + +# escape every argument `(args[6:length(args[3])])` and the name `args[1]` +# TODO: we can apply a similar strategy like builtin calls to specialize some foreigncalls +function escape_foreigncall!(astate::AnalysisState, pc::Int, args::Vector{Any}) + nargs = length(args) + if nargs < 6 + # invalid foreigncall, just escape everything + add_conservative_changes!(astate, pc, args) + return + end + argtypes = args[3]::SimpleVector + nargs = length(argtypes) + name = args[1] + nn = normalize(name) + if isa(nn, Symbol) + boundserror_ninds = array_resize_info(nn) + if boundserror_ninds !== nothing + boundserror, ninds = boundserror_ninds + escape_array_resize!(boundserror, ninds, astate, pc, args) + return + end + if is_array_copy(nn) + escape_array_copy!(astate, pc, args) + return + elseif is_array_isassigned(nn) + escape_array_isassigned!(astate, pc, args) + return + end + # if nn === :jl_gc_add_finalizer_th + # # TODO add `FinalizerEscape` ? + # end + end + # NOTE array allocations might have been proven as nothrow (https://github.com/JuliaLang/julia/pull/43565) + nothrow = is_effect_free(astate.ir, pc) + name_info = nothrow ? ⊥ : ThrownEscape(pc) + add_escape_change!(astate, name, name_info) + add_liveness_change!(astate, name, pc) + for i = 1:nargs + # we should escape this argument if it is directly called, + # otherwise just impose ThrownEscape if not nothrow + if argtypes[i] === Any + arg_info = ⊤ + else + arg_info = nothrow ? 
⊥ : ThrownEscape(pc)
+        end
+        add_escape_change!(astate, args[5+i], arg_info)
+        add_liveness_change!(astate, args[5+i], pc)
+    end
+    for i = (5+nargs):length(args)
+        arg = args[i]
+        add_escape_change!(astate, arg, ⊥)
+        add_liveness_change!(astate, arg, pc)
+    end
+end
+
+normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x
+
+function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any}, callinfo::Vector{Any})
+    info = callinfo[pc]
+    if isa(info, Bool)
+        info && return # known to be no escape
+        # now cascade to the builtin handling
+        escape_call!(astate, pc, args)
+        return
+    elseif isa(info, CallInfo)
+        for linfo in info.linfos
+            escape_invoke!(astate, pc, args, linfo, 1)
+        end
+        # accounts for a potential escape via MethodError
+        info.nothrow || add_thrown_escapes!(astate, pc, args)
+        return
+    else
+        @assert info === missing
+        # if this call couldn't be analyzed, escape it conservatively
+        add_conservative_changes!(astate, pc, args)
+    end
+end
+
+function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any})
+    ir = astate.ir
+    ft = argextype(first(args), ir, ir.sptypes, ir.argtypes)
+    f = singleton_type(ft)
+    if isa(f, Core.IntrinsicFunction)
+        # XXX somehow `:call` expression can creep in here, ideally we should be able to do:
+        # argtypes = Any[argextype(args[i], astate.ir) for i = 2:length(args)]
+        argtypes = Any[]
+        for i = 2:length(args)
+            arg = args[i]
+            push!(argtypes, isexpr(arg, :call) ? Any : argextype(arg, ir))
+        end
+        if intrinsic_nothrow(f, argtypes)
+            add_liveness_changes!(astate, pc, args, 2)
+        else
+            add_fallback_changes!(astate, pc, args, 2)
+        end
+        return # TODO account for pointer operations?
+    end
+    result = escape_builtin!(f, astate, pc, args)
+    if result === missing
+        # if this call hasn't been handled by any of the pre-defined handlers, escape it conservatively
+        add_conservative_changes!(astate, pc, args)
+        return
+    elseif result === true
+        add_liveness_changes!(astate, pc, args, 2)
+        return # ThrownEscape is already checked
+    else
+        # we escape statements with the `ThrownEscape` property using the effect-freeness
+        # computed by `stmt_effect_free` invoked within inlining
+        # TODO throwness ≠ "effect-free-ness"
+        if is_effect_free(astate.ir, pc)
+            add_liveness_changes!(astate, pc, args, 2)
+        else
+            add_fallback_changes!(astate, pc, args, 2)
+        end
+        return
+    end
+end
+
+escape_builtin!(@nospecialize(f), _...) = return missing
+
+# safe builtins
+escape_builtin!(::typeof(isa), _...) = return false
+escape_builtin!(::typeof(typeof), _...) = return false
+escape_builtin!(::typeof(sizeof), _...) = return false
+escape_builtin!(::typeof(===), _...) = return false
+# not really safe, but `ThrownEscape` will be imposed later
+escape_builtin!(::typeof(isdefined), _...) = return false
+escape_builtin!(::typeof(throw), _...)
= return false + +function escape_builtin!(::typeof(ifelse), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 4 || return false + f, cond, th, el = args + ret = SSAValue(pc) + condt = argextype(cond, astate.ir) + if isa(condt, Const) && (cond = condt.val; isa(cond, Bool)) + if cond + add_alias_change!(astate, th, ret) + else + add_alias_change!(astate, el, ret) + end + else + add_alias_change!(astate, th, ret) + add_alias_change!(astate, el, ret) + end + return false +end + +function escape_builtin!(::typeof(typeassert), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + f, obj, typ = args + ret = SSAValue(pc) + add_alias_change!(astate, ret, obj) + return false +end + +function escape_new!(astate::AnalysisState, pc::Int, args::Vector{Any}) + obj = SSAValue(pc) + objinfo = astate.estate[obj] + AliasInfo = objinfo.AliasInfo + nargs = length(args) + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + typ = widenconst(argextype(obj, astate.ir)) + nflds = fieldcount_noerror(typ) + if nflds === nothing + AliasInfo = Unindexable() + @goto escape_unindexable_def + else + AliasInfo = IndexableFields(nflds) + @goto escape_indexable_def + end + elseif isa(AliasInfo, IndexableFields) + @label escape_indexable_def + # fields are known precisely: propagate escape information imposed on recorded possibilities to the exact field values + infos = AliasInfo.infos + nf = length(infos) + objinfo′ = ignore_aliasinfo(objinfo) + for i in 2:nargs + i-1 > nf && break # may happen when e.g. ϕ-node merges values with different types + arg = args[i] + add_alias_escapes!(astate, arg, infos[i-1]) + push!(infos[i-1], LocalDef(pc)) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo′) + add_liveness_change!(astate, arg, pc) + end + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + elseif isa(AliasInfo, Unindexable) + @label escape_unindexable_def + # fields are known partially: propagate escape information imposed on recorded possibilities to all fields values + info = AliasInfo.info + objinfo′ = ignore_aliasinfo(objinfo) + for i in 2:nargs + arg = args[i] + add_alias_escapes!(astate, arg, info) + push!(info, LocalDef(pc)) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo′) + add_liveness_change!(astate, arg, pc) + end + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + else + # this object has been used as array, but it is allocated as struct here (i.e. 
should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the fields couldn't be analyzed precisely: propagate the entire escape information + # of this object to all its fields as the most conservative propagation + for i in 2:nargs + arg = args[i] + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + end + if !is_effect_free(astate.ir, pc) + add_thrown_escapes!(astate, pc, args) + end +end + +function escape_builtin!(::typeof(tuple), astate::AnalysisState, pc::Int, args::Vector{Any}) + escape_new!(astate, pc, args) + return false +end + +function analyze_fields(ir::IRCode, @nospecialize(typ), @nospecialize(fld)) + nflds = fieldcount_noerror(typ) + if nflds === nothing + return Unindexable(), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return Unindexable(), 0 + end + return IndexableFields(nflds), fidx +end + +function reanalyze_fields(ir::IRCode, AliasInfo::IndexableFields, @nospecialize(typ), @nospecialize(fld)) + nflds = fieldcount_noerror(typ) + if nflds === nothing + return merge_to_unindexable(AliasInfo), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return merge_to_unindexable(AliasInfo), 0 + end + infos = AliasInfo.infos + ninfos = length(infos) + if nflds > ninfos + for _ in 1:(nflds-ninfos) + push!(infos, AInfo()) + end + end + return AliasInfo, fidx +end + +function escape_builtin!(::typeof(getfield), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 3 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + typ = widenconst(argextype(obj, ir)) + if hasintersect(typ, Module) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + end + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + return false + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, IndexableFields) + @goto record_indexable_use + else + @goto record_unindexable_use + end + elseif isa(AliasInfo, IndexableFields) + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto record_unindexable_use + @label record_indexable_use + push!(AliasInfo.infos[fidx], LocalUse(pc)) + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + elseif isa(AliasInfo, Unindexable) + @label record_unindexable_use + push!(AliasInfo.info, LocalUse(pc)) + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + else + # this object has been used as array, but it is used as struct here (i.e. 
should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # at the extreme case, a field of `obj` may point to `obj` itself + # so add the alias change here as the most conservative propagation + add_alias_change!(astate, obj, SSAValue(pc)) + end + return false +end + +function escape_builtin!(::typeof(setfield!), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + val = args[4] + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + @goto add_thrown_escapes + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, IndexableFields) + @goto escape_indexable_def + else + @goto escape_unindexable_def + end + elseif isa(AliasInfo, IndexableFields) + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto escape_unindexable_def + @label escape_indexable_def + add_alias_escapes!(astate, val, AliasInfo.infos[fidx]) + push!(AliasInfo.infos[fidx], LocalDef(pc)) + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) # update with new AliasInfo + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + elseif isa(AliasInfo, Unindexable) + info = AliasInfo.info + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, LocalDef(pc)) + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) # update with new AliasInfo + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + else + # this object has been used as array, but it is used as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the field couldn't be analyzed: alias this object to the value being assigned + # as the most conservative propagation (as required for ArgAliasing) + add_alias_change!(astate, val, obj) + end + # also propagate escape information imposed on the return value of this `setfield!` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, val, ssainfo) + # compute the throwness of this setfield! 
call here since builtin_nothrow doesn't account for that + @label add_thrown_escapes + argtypes = Any[] + for i = 2:length(args) + push!(argtypes, argextype(args[i], ir)) + end + setfield!_nothrow(argtypes) || add_thrown_escapes!(astate, pc, args, 2) + return true +end + +function escape_builtin!(::typeof(arrayref), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + # check potential thrown escapes from this arrayref call + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + if !array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 3) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and thus don't know what values + # can be referenced here: directly propagate the escape information imposed on the return + # value of this `arrayref` call to the array itself as the most conservative propagation + # but also with updated index information + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this array hasn't been analyzed yet: set AliasInfo now + idx = array_nd_index(astate, ary, args[4:end]) + if isa(idx, Int) + AliasInfo = IndexableElements(IdDict{Int,AInfo}()) + @goto record_indexable_use + end + AliasInfo = Unindexable() + @goto record_unindexable_use + elseif isa(AliasInfo, IndexableElements) + idx = array_nd_index(astate, ary, args[4:end]) + if !isa(idx, Int) + AliasInfo = merge_to_unindexable(AliasInfo) + @goto record_unindexable_use + end + @label record_indexable_use + info = get!(()->AInfo(), AliasInfo.infos, idx) + push!(info, LocalUse(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + elseif isa(AliasInfo, Unindexable) + @label record_unindexable_use + push!(AliasInfo.info, LocalUse(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + # at the extreme case, an element of `ary` may point to `ary` itself + # so add the alias change here as the most conservative propagation + add_alias_change!(astate, ary, SSAValue(pc)) + end + return true +end + +function escape_builtin!(::typeof(arrayset), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 5 || return false + # check potential escapes from this arrayset call + # NOTE here we essentially only need to account for TypeError, assuming that + # UndefRefError or BoundsError don't capture any of the arguments here + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + valt = argtypes[3] + if !(array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 4) && + arrayset_typecheck(aryt, valt)) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + val = args[4] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + 
# we don't track precise index information about this array and won't record what value + # is being assigned here: directly propagate the escape information of this array to + # the value being assigned as the most conservative propagation + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this array hasn't been analyzed yet: set AliasInfo now + idx = array_nd_index(astate, ary, args[5:end]) + if isa(idx, Int) + AliasInfo = IndexableElements(IdDict{Int,AInfo}()) + @goto escape_indexable_def + end + AliasInfo = Unindexable() + @goto escape_unindexable_def + elseif isa(AliasInfo, IndexableElements) + idx = array_nd_index(astate, ary, args[5:end]) + if !isa(idx, Int) + AliasInfo = merge_to_unindexable(AliasInfo) + @goto escape_unindexable_def + end + @label escape_indexable_def + info = get!(()->AInfo(), AliasInfo.infos, idx) + add_alias_escapes!(astate, val, info) + push!(info, LocalDef(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + # propagate the escape information of this array ignoring elements information + add_escape_change!(astate, val, ignore_aliasinfo(aryinfo)) + elseif isa(AliasInfo, Unindexable) + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, LocalDef(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + # propagate the escape information of this array ignoring elements information + add_escape_change!(astate, val, ignore_aliasinfo(aryinfo)) + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + add_alias_change!(astate, val, ary) + end + # also propagate escape information imposed on the return value of this `arrayset` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, ary, ssainfo) + return true +end + +# NOTE this function models and thus should be synced with the implementation of: +# size_t array_nd_index(jl_array_t *a, jl_value_t **args, size_t nidxs, ...) +function array_nd_index(astate::AnalysisState, @nospecialize(ary), args::Vector{Any}, nidxs::Int = length(args)) + isa(ary, SSAValue) || return nothing + aryid = ary.id + arrayinfo = astate.estate.arrayinfo + isa(arrayinfo, ArrayInfo) || return nothing + haskey(arrayinfo, aryid) || return nothing + dims = arrayinfo[aryid] + local i = 0 + local k, stride = 0, 1 + local nd = length(dims) + while k < nidxs + arg = args[k+1] + argval = argextype(arg, astate.ir) + isa(argval, Const) || return nothing + argval = argval.val + isa(argval, Int) || return nothing + ii = argval - 1 + i += ii * stride + d = k ≥ nd ? 
1 : dims[k+1]
+        k < nidxs - 1 && ii ≥ d && return nothing # BoundsError
+        stride *= d
+        k += 1
+    end
+    while k < nd
+        stride *= dims[k+1]
+        k += 1
+    end
+    i ≥ stride && return nothing # BoundsError
+    return i
+end
+
+function escape_builtin!(::typeof(arraysize), astate::AnalysisState, pc::Int, args::Vector{Any})
+    length(args) == 3 || return false
+    ary = args[2]
+    dim = args[3]
+    if !arraysize_typecheck(ary, dim, astate.ir)
+        add_escape_change!(astate, ary, ThrownEscape(pc))
+        add_escape_change!(astate, dim, ThrownEscape(pc))
+    end
+    # NOTE we may still see "arraysize: dimension out of range", but it doesn't capture anything
+    return true
+end
+
+function arraysize_typecheck(@nospecialize(ary), @nospecialize(dim), ir::IRCode)
+    aryt = argextype(ary, ir)
+    aryt ⊑ Array || return false
+    dimt = argextype(dim, ir)
+    dimt ⊑ Int || return false
+    return true
+end
+
+# returns nothing if this isn't an array resizing operation,
+# otherwise returns true if it can throw BoundsError and false if not
+function array_resize_info(name::Symbol)
+    if name === :jl_array_grow_beg || name === :jl_array_grow_end
+        return false, 1
+    elseif name === :jl_array_del_beg || name === :jl_array_del_end
+        return true, 1
+    elseif name === :jl_array_grow_at || name === :jl_array_del_at
+        return true, 2
+    else
+        return nothing
+    end
+end
+
+# NOTE may potentially throw "cannot resize array with shared data" error,
+# but just ignore it since it doesn't capture anything
+function escape_array_resize!(boundserror::Bool, ninds::Int,
+    astate::AnalysisState, pc::Int, args::Vector{Any})
+    length(args) ≥ 6+ninds || return add_fallback_changes!(astate, pc, args)
+    ary = args[6]
+    aryt = argextype(ary, astate.ir)
+    aryt ⊑ Array || return add_fallback_changes!(astate, pc, args)
+    for i in 1:ninds
+        ind = args[i+6]
+        indt = argextype(ind, astate.ir)
+        indt ⊑ Integer || return add_fallback_changes!(astate, pc, args)
+    end
+    if boundserror
+        # this array resizing can potentially throw `BoundsError`, impose it now
+        add_escape_change!(astate, ary, ThrownEscape(pc))
+    end
+    # give up indexing analysis whenever we see array resizing
+    # (since we track array dimensions only globally)
+    mark_unindexable!(astate, ary)
+    add_liveness_changes!(astate, pc, args, 6)
+end
+
+function mark_unindexable!(astate::AnalysisState, @nospecialize(ary))
+    isa(ary, SSAValue) || return
+    aryinfo = astate.estate[ary]
+    AliasInfo = aryinfo.AliasInfo
+    isa(AliasInfo, IndexableElements) || return
+    AliasInfo = merge_to_unindexable(AliasInfo)
+    add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo))
+end
+
+is_array_copy(name::Symbol) = name === :jl_array_copy
+
+# FIXME this implementation is very conservative, improve the accuracy and solve broken test cases
+function escape_array_copy!(astate::AnalysisState, pc::Int, args::Vector{Any})
+    length(args) ≥ 6 || return add_fallback_changes!(astate, pc, args)
+    ary = args[6]
+    aryt = argextype(ary, astate.ir)
+    aryt ⊑ Array || return add_fallback_changes!(astate, pc, args)
+    if isa(ary, SSAValue) || isa(ary, Argument)
+        newary = SSAValue(pc)
+        aryinfo = astate.estate[ary]
+        newaryinfo = astate.estate[newary]
+        add_escape_change!(astate, newary, aryinfo)
+        add_escape_change!(astate, ary, newaryinfo)
+    end
+    add_liveness_changes!(astate, pc, args, 6)
+end
+
+is_array_isassigned(name::Symbol) = name === :jl_array_isassigned
+
+function escape_array_isassigned!(astate::AnalysisState, pc::Int, args::Vector{Any})
+    if !array_isassigned_nothrow(args, astate.ir)
+        add_thrown_escapes!(astate, pc, args)
+    end
+    add_liveness_changes!(astate, pc, args, 6)
+end
+
+function array_isassigned_nothrow(args::Vector{Any}, src::IRCode)
+    # if !validate_foreigncall_args(args,
+    #     :jl_array_isassigned, Cint, svec(Any,Csize_t), 0, :ccall)
+    #     return false
+    # end
+    length(args) ≥ 7 || return false
+    arytype = argextype(args[6], src)
+    arytype ⊑ Array || return false
+    idxtype = argextype(args[7], src)
+    idxtype ⊑ Csize_t || return false
+    return true
+end
+
+# # COMBAK do we want to enable this (and also backport this to Base for array allocations?)
+# import Core.Compiler: Cint, svec
+# function validate_foreigncall_args(args::Vector{Any},
+#     name::Symbol, @nospecialize(rt), argtypes::SimpleVector, nreq::Int, convention::Symbol)
+#     length(args) ≥ 5 || return false
+#     normalize(args[1]) === name || return false
+#     args[2] === rt || return false
+#     args[3] === argtypes || return false
+#     args[4] === vararg || return false
+#     normalize(args[5]) === convention || return false
+#     return true
+# end
+
+if isdefined(Core, :ImmutableArray)
+
+import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw
+
+escape_builtin!(::typeof(arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) =
+    is_safe_immutable_array_op(Array, astate, args)
+escape_builtin!(::typeof(mutating_arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) =
+    is_safe_immutable_array_op(Array, astate, args)
+escape_builtin!(::typeof(arraythaw), astate::AnalysisState, pc::Int, args::Vector{Any}) =
+    is_safe_immutable_array_op(ImmutableArray, astate, args)
+function is_safe_immutable_array_op(@nospecialize(arytype), astate::AnalysisState, args::Vector{Any})
+    length(args) == 2 || return false
+    argextype(args[2], astate.ir) ⊑ arytype || return false
+    return true
+end
+
+end # if isdefined(Core, :ImmutableArray)
+
+if _TOP_MOD !== Core.Compiler
+    # NOTE define fancy package utilities when developing EA as an external package
+    include("EAUtils.jl")
+    using .EAUtils
+    export code_escapes, @code_escapes, __clear_cache!
+end
+
+end # baremodule EscapeAnalysis
diff --git a/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl
new file mode 100644
index 0000000000000..915bc214d7c3c
--- /dev/null
+++ b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl
@@ -0,0 +1,143 @@
+# A disjoint set implementation adapted from
+# https://github.com/JuliaCollections/DataStructures.jl/blob/f57330a3b46f779b261e6c07f199c88936f28839/src/disjoint_set.jl
+# under the MIT license: https://github.com/JuliaCollections/DataStructures.jl/blob/master/License.md
+
+# imports
+import ._TOP_MOD:
+    length,
+    eltype,
+    union!,
+    push!
+# usings
+import ._TOP_MOD:
+    OneTo, collect, zero, zeros, one, typemax
+
+# Disjoint-Set
+
+############################################################
+#
+#   A forest of disjoint sets of integers
+#
+#   Since each element is an integer, we can use arrays
+#   instead of a dictionary (for efficiency)
+#
+#   Disjoint sets over other key types can be implemented
+#   based on an IntDisjointSet through a map from the key
+#   to an integer index
+#
+############################################################

+_intdisjointset_bounds_err_msg(T) = "the maximum number of elements in IntDisjointSet{$T} is $(typemax(T))"
+
+"""
+    IntDisjointSet{T<:Integer}(n::Integer)
+
+A forest of disjoint sets of integers, which is a data structure
+(also called a union–find data structure or merge–find set)
+that tracks a set of elements partitioned
+into a number of disjoint (non-overlapping) subsets.
+"""
+mutable struct IntDisjointSet{T<:Integer}
+    parents::Vector{T}
+    ranks::Vector{T}
+    ngroups::T
+end
+
+IntDisjointSet(n::T) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(n)), zeros(T, n), n)
+IntDisjointSet{T}(n::Integer) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(T(n))), zeros(T, T(n)), T(n))
+length(s::IntDisjointSet) = length(s.parents)
+
+"""
+    num_groups(s::IntDisjointSet)
+
+Get the number of groups.
+"""
+num_groups(s::IntDisjointSet) = s.ngroups
+eltype(::Type{IntDisjointSet{T}}) where {T<:Integer} = T
+
+# find the root element of the subset that contains x
+# path compression is implemented here
+function find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer}
+    p = parents[x]
+    @inbounds if parents[p] != p
+        parents[x] = p = _find_root_impl!(parents, p)
+    end
+    return p
+end
+
+# unsafe version of the above
+function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer}
+    @inbounds p = parents[x]
+    @inbounds if parents[p] != p
+        parents[x] = p = _find_root_impl!(parents, p)
+    end
+    return p
+end
+
+"""
+    find_root!(s::IntDisjointSet{T}, x::T)
+
+Find the root element of the subset that contains a member `x`.
+Path compression happens here.
+"""
+find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x)
+
+"""
+    in_same_set(s::IntDisjointSet{T}, x::T, y::T)
+
+Returns `true` if `x` and `y` belong to the same subset in `s`, and `false` otherwise.
+"""
+in_same_set(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} = find_root!(s, x) == find_root!(s, y)
+
+"""
+    union!(s::IntDisjointSet{T}, x::T, y::T)
+
+Merge the subset containing `x` and that containing `y` into one
+and return the root of the new set.
+"""
+function union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer}
+    parents = s.parents
+    xroot = find_root_impl!(parents, x)
+    yroot = find_root_impl!(parents, y)
+    return xroot != yroot ? root_union!(s, xroot, yroot) : xroot
+end
+
+"""
+    root_union!(s::IntDisjointSet{T}, x::T, y::T)
+
+Form a new set that is the union of the two sets whose root elements are
+`x` and `y` and return the root of the new set.
+Assume `x ≠ y` (unsafe).
+"""
+function root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer}
+    parents = s.parents
+    rks = s.ranks
+    @inbounds xrank = rks[x]
+    @inbounds yrank = rks[y]
+
+    if xrank < yrank
+        x, y = y, x
+    elseif xrank == yrank
+        rks[x] += one(T)
+    end
+    @inbounds parents[y] = x
+    s.ngroups -= one(T)
+    return x
+end
+
+"""
+    push!(s::IntDisjointSet{T})
+
+Make a new subset with an automatically chosen new element `x`.
+Return the new element. Throw an `ArgumentError` if the
+capacity of the set would be exceeded.
+"""
+function push!(s::IntDisjointSet{T}) where {T<:Integer}
+    l = length(s)
+    l < typemax(T) || throw(ArgumentError(_intdisjointset_bounds_err_msg(T)))
+    x = l + one(T)
+    push!(s.parents, x)
+    push!(s.ranks, zero(T))
+    s.ngroups += one(T)
+    return x
+end
diff --git a/base/compiler/ssair/EscapeAnalysis/interprocedural.jl b/base/compiler/ssair/EscapeAnalysis/interprocedural.jl
new file mode 100644
index 0000000000000..9880c13db4ad1
--- /dev/null
+++ b/base/compiler/ssair/EscapeAnalysis/interprocedural.jl
@@ -0,0 +1,151 @@
+# TODO this file duplicates much of the inlining analysis code; factor the common parts out
+
+import Core.Compiler:
+    MethodInstance, InferenceResult, Signature, ConstResult,
+    MethodResultPure, MethodMatchInfo, UnionSplitInfo, ConstCallInfo, InvokeCallInfo,
+    call_sig, argtypes_to_type, is_builtin, is_return_type, istopfunction, validate_sparams,
+    specialize_method, invoke_rewrite
+
+const Linfo = Union{MethodInstance,InferenceResult}
+struct CallInfo
+    linfos::Vector{Linfo}
+    nothrow::Bool
+end
+
+function resolve_call(ir::IRCode, stmt::Expr, @nospecialize(info))
+    sig = call_sig(ir, stmt)
+    if sig === nothing
+        return missing
+    end
+    # TODO handle _apply_iterate
+    if is_builtin(sig) && sig.f !== invoke
+        return false
+    end
+    # handling corresponding to late_inline_special_case!
+    (; f, argtypes) = sig
+    if length(argtypes) == 3 && istopfunction(f, :!==)
+        return true
+    elseif length(argtypes) == 3 && istopfunction(f, :(>:))
+        return true
+    elseif f === TypeVar && 2 ≤ length(argtypes) ≤ 4 && (argtypes[2] ⊑ Symbol)
+        return true
+    elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar)
+        return true
+    elseif is_return_type(f)
+        return true
+    end
+    if info isa MethodResultPure
+        return true
+    elseif info === false
+        return missing
+    end
+    # TODO handle OpaqueClosureCallInfo
+    if sig.f === invoke
+        isa(info, InvokeCallInfo) || return missing
+        return analyze_invoke_call(sig, info)
+    elseif isa(info, ConstCallInfo)
+        return analyze_const_call(sig, info)
+    elseif isa(info, MethodMatchInfo)
+        infos = MethodMatchInfo[info]
+    elseif isa(info, UnionSplitInfo)
+        infos = info.matches
+    else # isa(info, ReturnTypeCallInfo), etc.
+        return missing
+    end
+    return analyze_call(sig, infos)
+end
+
+function analyze_invoke_call(sig::Signature, info::InvokeCallInfo)
+    match = info.match
+    if !match.fully_covers
+        # TODO: We could union split out the signature check and continue on
+        return missing
+    end
+    result = info.result
+    if isa(result, InferenceResult)
+        return CallInfo(Linfo[result], true)
+    else
+        argtypes = invoke_rewrite(sig.argtypes)
+        mi = analyze_match(match, length(argtypes))
+        mi === nothing && return missing
+        return CallInfo(Linfo[mi], true)
+    end
+end
+
+function analyze_const_call(sig::Signature, cinfo::ConstCallInfo)
+    linfos = Linfo[]
+    (; call, results) = cinfo
+    infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
+    local nothrow = true # required to account for potential escape via MethodError
+    local j = 0
+    for i in 1:length(infos)
+        meth = infos[i].results
+        nothrow &= !meth.ambig
+        nmatch = Core.Compiler.length(meth)
+        if nmatch == 0 # no applicable methods
+            # mark that this call may potentially throw, and try the next union split
+            nothrow = false
+            continue
+        end
+        for i = 1:nmatch
+            j += 1
+            result = results[j]
+            match = Core.Compiler.getindex(meth, i)
+            if result === nothing
+                mi = analyze_match(match, length(sig.argtypes))
+                mi === nothing && return missing
+                push!(linfos, mi)
+            elseif isa(result, ConstResult)
+                # TODO we may want to feed back the information that this call always throws if !isdefined(result, :result)
+                push!(linfos, result.mi)
+            else
+                push!(linfos, result)
+            end
+            nothrow &= match.fully_covers
+        end
+    end
+    return CallInfo(linfos, nothrow)
+end
+
+function analyze_call(sig::Signature, infos::Vector{MethodMatchInfo})
+    linfos = Linfo[]
+    local nothrow = true # required to account for potential escape via MethodError
+    for i in 1:length(infos)
+        meth = infos[i].results
+        nothrow &= !meth.ambig
+        nmatch = Core.Compiler.length(meth)
+        if nmatch == 0 # no applicable methods
+            # mark that this call may potentially throw, and try the next union split
+            nothrow = false
+            continue
+        end
+        for i = 1:nmatch
+            match = Core.Compiler.getindex(meth, i)
+            mi = analyze_match(match, length(sig.argtypes))
+            mi === nothing && return missing
+            push!(linfos, mi)
+            nothrow &= match.fully_covers
+        end
+    end
+    return CallInfo(linfos, nothrow)
+end
+
+function analyze_match(match::MethodMatch, npassedargs::Int)
+    method = match.method
+    na = Int(method.nargs)
+    if na != npassedargs && !(na > 0 && method.isva)
+        # we have a method match only because an earlier
+        # inference step shortened our call args list, even
+        # though we have too many arguments to actually
+        # call this function
+        return nothing
+    end
+
+    # Bail out if any static parameters are left as TypeVar
+    # COMBAK is this needed for escape analysis?
+    validate_sparams(match.sparams) || return nothing
+
+    # See if there exists a specialization for this method signature
+    mi = specialize_method(match; preexisting=true) # Union{Nothing, MethodInstance}
+    return mi
+end
diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl
index e54a09fe351b3..7329dafcb1121 100644
--- a/base/compiler/ssair/driver.jl
+++ b/base/compiler/ssair/driver.jl
@@ -14,8 +14,10 @@ include("compiler/ssair/basicblock.jl")
 include("compiler/ssair/domtree.jl")
 include("compiler/ssair/ir.jl")
 include("compiler/ssair/slot2ssa.jl")
-include("compiler/ssair/passes.jl")
 include("compiler/ssair/inlining.jl")
 include("compiler/ssair/verify.jl")
 include("compiler/ssair/legacy.jl")
-#@isdefined(Base) && include("compiler/ssair/show.jl")
+function try_compute_field end # imported by EscapeAnalysis
+include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl")
+include("compiler/ssair/passes.jl")
+# @isdefined(Base) && include("compiler/ssair/show.jl")
diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl
index 2238d43d65b27..e67594f196c90 100644
--- a/base/compiler/tfuncs.jl
+++ b/base/compiler/tfuncs.jl
@@ -1288,7 +1288,7 @@ function apply_type_nothrow(argtypes::Array{Any, 1}, @nospecialize(rt))
                 return false
             end
         elseif (isa(ai, Const) && isa(ai.val, Type)) || isconstType(ai)
-            ai = isa(ai, Const) ? ai.val : ai.parameters[1]
+            ai = isa(ai, Const) ?
ai.val : (ai::DataType).parameters[1] if has_free_typevars(u.var.lb) || has_free_typevars(u.var.ub) return false end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 03ba383de4f61..d600df1dbb0a1 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -313,7 +313,7 @@ function CodeInstance( widenconst(result_type), rettype_const, inferred_result, const_flags, first(valid_worlds), last(valid_worlds), # TODO: Actually do something with non-IPO effects - encode_effects(result.ipo_effects), encode_effects(result.ipo_effects), + encode_effects(result.ipo_effects), encode_effects(result.ipo_effects), result.argescapes, relocatability) end diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 3b19b509f11bf..1ef92ea65598e 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -120,16 +120,18 @@ mutable struct InferenceResult linfo::MethodInstance argtypes::Vector{Any} overridden_by_const::BitVector - result # ::Type, or InferenceState if WIP - src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available + result # ::Type, or InferenceState if WIP + src # ::Union{CodeInfo, OptimizationState} if inferred copy is available, nothing otherwise valid_worlds::WorldRange # if inference and optimization is finished - ipo_effects::Effects # if inference is finished - effects::Effects # if optimization is finished + ipo_effects::Effects # if inference is finished + effects::Effects # if optimization is finished + argescapes # ::ArgEscapeCache if optimized, nothing otherwise function InferenceResult(linfo::MethodInstance, arginfo#=::Union{Nothing,Tuple{ArgInfo,InferenceState}}=# = nothing, va_override::Bool = false) argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo, va_override) - return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange(), Effects(), Effects()) + return new(linfo, argtypes, overridden_by_const, Any, nothing, + WorldRange(), Effects(), Effects(), nothing) end end diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index e97441495f16b..9b1106e964919 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -19,6 +19,8 @@ function _any(@nospecialize(f), a) end return false end +any(@nospecialize(f), itr) = _any(f, itr) +any(itr) = _any(identity, itr) function _all(@nospecialize(f), a) for x in a @@ -26,6 +28,8 @@ function _all(@nospecialize(f), a) end return true end +all(@nospecialize(f), itr) = _all(f, itr) +all(itr) = _all(identity, itr) function contains_is(itr, @nospecialize(x)) for y in itr diff --git a/doc/make.jl b/doc/make.jl index 8be3b807400d1..bb7ef83048178 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -148,6 +148,7 @@ DevDocs = [ "devdocs/require.md", "devdocs/inference.md", "devdocs/ssair.md", + "devdocs/EscapeAnalysis.md", "devdocs/gc-sa.md", ], "Developing/debugging Julia's C code" => [ diff --git a/doc/src/devdocs/EscapeAnalysis.md b/doc/src/devdocs/EscapeAnalysis.md new file mode 100644 index 0000000000000..de09dfec48c42 --- /dev/null +++ b/doc/src/devdocs/EscapeAnalysis.md @@ -0,0 +1,363 @@ +`Core.Compiler.EscapeAnalysis` is a compiler utility module that aims to analyze +escape information of [Julia's SSA-form IR](@ref Julia-SSA-form-IR) a.k.a. `IRCode`. 
+
+You can try the escape analysis by loading the `EAUtils.jl` utility script, which
+defines the convenience entries `code_escapes` and `@code_escapes` for testing and debugging purposes:
+```@repl EAUtils
+include(normpath(Sys.BINDIR::String, "..", "share", "julia", "test", "compiler", "EscapeAnalysis", "EAUtils.jl"))
+using EAUtils
+
+mutable struct SafeRef{T}
+    x::T
+end
+Base.getindex(x::SafeRef) = x.x;
+Base.setindex!(x::SafeRef, v) = x.x = v;
+Base.isassigned(x::SafeRef) = true;
+get′(x) = isassigned(x) ? x[] : throw(x);
+
+result = code_escapes((String,String,String,String)) do s1, s2, s3, s4
+    r1 = Ref(s1)
+    r2 = Ref(s2)
+    r3 = SafeRef(s3)
+    try
+        s1 = get′(r1)
+        ret = sizeof(s1)
+    catch err
+        global GV = err # will definitely escape `r1`
+    end
+    s2 = get′(r2)   # still `r2` doesn't escape fully
+    s3 = get′(r3)   # still `r3` doesn't escape fully
+    s4 = sizeof(s4) # the argument `s4` doesn't escape here
+    return s2, s3, s4
+end
+```
+
+The symbols shown next to each call argument and SSA statement have the following meanings:
+- `◌` (plain): this value is not analyzed because its escape information won't be used anyway (e.g. when the object is `isbitstype`)
+- `✓` (green or cyan): this value never escapes (`has_no_escape(result.state[x])` holds), colored cyan if it also has arg escape (`has_arg_escape(result.state[x])` holds)
+- `↑` (blue or yellow): this value can escape to the caller via return (`has_return_escape(result.state[x])` holds), colored yellow if it also has unhandled thrown escape (`has_thrown_escape(result.state[x])` holds)
+- `X` (red): this value can escape to somewhere the escape analysis can't reason about, e.g. an escape to global memory (`has_all_escape(result.state[x])` holds)
+- `*` (bold): this value's escape state lies between `ReturnEscape` and `AllEscape` in the partial order of [`EscapeInfo`](@ref Core.Compiler.EscapeAnalysis.EscapeInfo), colored yellow if it also has unhandled thrown escape (`has_thrown_escape(result.state[x])` holds)
+- `′`: this value has additional object field / array element information in its `AliasInfo` property
+
+Escape information of each call argument and SSA value can also be inspected programmatically:
+```@repl EAUtils
+result.state[Core.Argument(3)] # get EscapeInfo of `s2`
+
+result.state[Core.SSAValue(3)] # get EscapeInfo of `r3`
+```
+
+## Analysis Design
+
+### Lattice Design
+
+`EscapeAnalysis` is implemented as a [data-flow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis)
+that works on a lattice of `x::EscapeInfo`, which is composed of the following properties:
+- `x.Analyzed::Bool`: not formally part of the lattice, only indicates whether `x` has been analyzed
+- `x.ReturnEscape::BitSet`: records SSA statements where `x` can escape to the caller via return
+- `x.ThrownEscape::BitSet`: records SSA statements where `x` can be thrown as an exception
+  (used for the [exception handling](@ref EA-Exception-Handling) described below)
+- `x.AliasInfo`: maintains all possible values that can be aliased to fields or array elements of `x`
+  (used for the [alias analysis](@ref EA-Alias-Analysis) described below)
+- `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through
+  `setfield!` on argument(s)
+
+These attributes can be combined to create a partial lattice that has a finite height, given
+the invariant that an input program has a finite number of statements, which is assured by Julia's semantics.
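+
+For example, because each property forms its own small lattice, the join of two `EscapeInfo`s
+can be computed property-wise. The following is a minimal sketch of this idea (using a
+simplified three-property struct, *not* the real `EscapeInfo` definition):
+```julia
+# hypothetical, simplified stand-in for `EscapeInfo`; the real struct also
+# carries `AliasInfo` and liveness information
+struct SimpleEscapeInfo
+    Analyzed::Bool
+    ReturnEscape::BitSet
+    ThrownEscape::BitSet
+end
+
+# the lattice join handles each property separately
+⊔(x::SimpleEscapeInfo, y::SimpleEscapeInfo) = SimpleEscapeInfo(
+    x.Analyzed | y.Analyzed,
+    union(x.ReturnEscape, y.ReturnEscape),
+    union(x.ThrownEscape, y.ThrownEscape))
+```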
+
+The clever part of this lattice design is that it enables a simpler implementation of
+lattice operations by allowing them to handle each lattice property separately[^LatticeDesign].
+
+### Backward Escape Propagation
+
+This escape analysis implementation is based on the data-flow algorithm described in the paper[^MM02].
+The analysis works on the lattice of `EscapeInfo` and transitions lattice elements from the
+bottom to the top until every lattice element converges to a fixed point, by maintaining
+a (conceptual) working set that contains program counters corresponding to remaining SSA
+statements to be analyzed. The analysis manages a single global state that tracks
+`EscapeInfo` of each argument and SSA statement, but also note that some flow-sensitivity
+is encoded as program counters recorded in `EscapeInfo`'s `ReturnEscape` property,
+which can be combined with domination analysis later to reason about flow-sensitivity if necessary.
+
+One distinctive design of this escape analysis is that it is fully _backward_,
+i.e. escape information flows _from usages to definitions_.
+For example, in the code snippet below, EA first analyzes the statement `return %1` and
+imposes `ReturnEscape` on `%1` (corresponding to `obj`), and then it analyzes
+`%1 = %new(Base.RefValue{String}, _2)` and propagates the `ReturnEscape` imposed on `%1`
+to the call argument `_2` (corresponding to `s`):
+```@repl EAUtils
+code_escapes((String,)) do s
+    obj = Ref(s)
+    return obj
+end
+```
+
+The key observation here is that this backward analysis allows escape information to flow
+naturally along the use-def chain rather than along control-flow[^BackandForth].
+As a result this scheme enables a simple implementation of escape analysis;
+a `PhiNode`, for example, can be handled simply by propagating escape information
+imposed on the `PhiNode` to its predecessor values:
+```@repl EAUtils
+code_escapes((Bool, String, String)) do cnd, s, t
+    if cnd
+        obj = Ref(s)
+    else
+        obj = Ref(t)
+    end
+    return obj
+end
+```
+
+### [Alias Analysis](@id EA-Alias-Analysis)
+
+`EscapeAnalysis` implements a backward field analysis in order to reason about escapes
+imposed on object fields with a certain degree of accuracy,
+and `x::EscapeInfo`'s `x.AliasInfo` property exists for this purpose.
+It records all possible values that can be aliased to fields of `x` at "usage" sites,
+and the escape information of those recorded values is then propagated to the actual field values at "definition" sites.
+More specifically, the analysis records a value that may be aliased to a field of an object by analyzing a `getfield` call,
+and it then propagates its escape information to the field when analyzing a `%new(...)` expression or `setfield!` call[^Dynamism].
+```@repl EAUtils
+code_escapes((String,)) do s
+    obj = SafeRef("init")
+    obj[] = s
+    v = obj[]
+    return v
+end
+```
+In the example above, `ReturnEscape` imposed on `%3` (corresponding to `v`) is _not_ directly
+propagated to `%1` (corresponding to `obj`); rather, that `ReturnEscape` is only propagated
+to `_2` (corresponding to `s`). Here `%3` is recorded in `%1`'s `AliasInfo` property as
+it can be aliased to the first field of `%1`, and then when analyzing `Base.setfield!(%1, :x, _2)::String`,
+that escape information is propagated to `_2` but not to `%1`.
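+
+The following is a minimal sketch of this record-then-propagate protocol (with hypothetical,
+simplified data structures; the real analysis works on `IRCode` statements and `EscapeInfo`
+lattice elements):
+```julia
+# hypothetical per-object record of which SSA values may alias each field
+aliases_per_field = Dict{Int,Vector{Int}}() # field index => recorded SSA uses
+
+# at a usage site such as `%3 = getfield(obj, :x)`: record the use
+push!(get!(Vector{Int}, aliases_per_field, 1), 3)
+
+# at a definition site such as `setfield!(obj, :x, _2)`: the escape
+# information already imposed on the recorded uses (here `%3`) is
+# propagated to the assigned value `_2`, not to `obj` itself
+for use in get(aliases_per_field, 1, Int[])
+    println("propagate the escape information of %", use, " to `_2`")
+end
+```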
+
+So `EscapeAnalysis` tracks which IR elements can be aliased across a `getfield`-`%new`/`setfield!` chain
+in order to analyze escapes of object fields, but actually this alias analysis needs to be
+generalized to handle other IR elements as well. This is because in Julia IR the same
+object is sometimes represented by different IR elements, and so we should make sure that those
+different IR elements that can actually represent the same object share the same escape information.
+IR elements that return the same object as their operand(s), such as `PiNode` and `typeassert`,
+can cause that IR-level aliasing and thus require escape information imposed on any such
+aliased values to be shared between them.
+More interestingly, it is also needed for correctly reasoning about mutations on `PhiNode`.
+Let's consider the following example:
+```@repl EAUtils
+code_escapes((Bool, String,)) do cond, x
+    if cond
+        ϕ2 = ϕ1 = SafeRef("foo")
+    else
+        ϕ2 = ϕ1 = SafeRef("bar")
+    end
+    ϕ2[] = x
+    y = ϕ1[]
+    return y
+end
+```
+`ϕ1 = %5` and `ϕ2 = %6` are aliased and thus `ReturnEscape` imposed on `%8 = Base.getfield(%6, :x)::String` (corresponding to `y = ϕ1[]`)
+needs to be propagated to `Base.setfield!(%5, :x, _3)::String` (corresponding to `ϕ2[] = x`).
+In order for such escape information to be propagated correctly, the analysis should recognize that
+the _predecessors_ of `ϕ1` and `ϕ2` can be aliased as well and equalize their escape information.
+
+One interesting property of such aliasing information is that it is not known at the "usage" site
+but can only be derived at the "definition" site (as aliasing is conceptually equivalent to assignment),
+and thus it doesn't naturally fit in a backward analysis. In order to efficiently propagate escape
+information between related values, EscapeAnalysis.jl uses an approach inspired by the escape
+analysis algorithm explained in an old JVM paper[^JVM05]. That is, in addition to managing
+escape lattice elements, the analysis also maintains an "equi"-alias set, a disjoint set of
+aliased arguments and SSA statements. The alias set manages values that can be aliased to
+each other and allows escape information imposed on any of such aliased values to be equalized
+between them.
+
+### [Array Analysis](@id EA-Array-Analysis)
+
+The alias analysis for object fields described above can also be generalized to analyze array operations.
+`EscapeAnalysis` implements handling for various primitive array operations so that it can propagate
+escapes via the `arrayref`-`arrayset` use-def chain without escaping allocated arrays too conservatively:
+```@repl EAUtils
+code_escapes((String,)) do s
+    ary = Any[]
+    push!(ary, SafeRef(s))
+    return ary[1], length(ary)
+end
+```
+In the above example `EscapeAnalysis` understands that `%20` and `%2` (corresponding to the allocated object `SafeRef(s)`)
+are aliased via the `arrayset`-`arrayref` chain and imposes `ReturnEscape` on them,
+but it does not impose it on the allocated array `%1` (corresponding to `ary`).
+`EscapeAnalysis` still imposes `ThrownEscape` on `ary` since it also needs to account for
+potential escapes via `BoundsError`, but also note that such unhandled `ThrownEscape` can
+often be ignored when optimizing the `ary` allocation.
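+
+The "per-element" accuracy described next relies on folding a constant multi-dimensional index
+into a single linear element identity. The sketch below loosely models that computation under
+the assumption that all indices are constant `Int`s and the dimensions are known (the real
+`array_nd_index` mirrors the bounds semantics of the runtime's `jl_array_nd_index` and bails
+out whenever anything is non-constant):
+```julia
+# fold constant 1-based indices `idxs` into a 0-based linear index,
+# returning `nothing` when the access would be a `BoundsError`
+function linear_index(dims::Dims, idxs::Int...)
+    i, stride = 0, 1
+    for (k, idx) in enumerate(idxs)
+        d = k > length(dims) ? 1 : dims[k]
+        ii = idx - 1
+        k < length(idxs) && ii ≥ d && return nothing # BoundsError
+        i += ii * stride
+        stride *= d
+    end
+    for k in length(idxs)+1:length(dims)
+        stride *= dims[k] # remaining dimensions bound the total extent
+    end
+    return i < stride ? i : nothing # final BoundsError check
+end
+
+linear_index((2, 3), 2, 3) # == 5, i.e. the last element of a 2×3 array
+```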
+
+Furthermore, in cases where array index information as well as array dimensions can be known _precisely_,
+`EscapeAnalysis` can even reason about "per-element" aliasing via the `arrayref`-`arrayset` chain,
+just as it does "per-field" alias analysis for objects:
+```@repl EAUtils
+code_escapes((String,String)) do s, t
+    ary = Vector{Any}(undef, 2)
+    ary[1] = SafeRef(s)
+    ary[2] = SafeRef(t)
+    return ary[1], length(ary)
+end
+```
+Note that `ReturnEscape` is only imposed on `%2` (corresponding to `SafeRef(s)`) but not on `%4` (corresponding to `SafeRef(t)`).
+This is because the allocated array's dimension and the indices involved in all `arrayref`/`arrayset`
+operations are available as constant information, and `EscapeAnalysis` can understand that
+`%6` is aliased to `%2` but never to `%4`.
+In this kind of case, the succeeding optimization passes will be able to
+replace `Base.arrayref(true, %1, 1)::Any` with `%2` (a.k.a. "load forwarding") and
+eventually eliminate the allocation of array `%1` entirely (a.k.a. "scalar replacement").
+
+Compared to object field analysis, where an access to an object field can be analyzed trivially
+using type information derived by inference, array dimensions aren't encoded as type information,
+and so we need an additional analysis to derive that information. `EscapeAnalysis` at this moment
+first does an additional simple linear scan to analyze the dimensions of allocated arrays before
+firing up the main analysis routine, so that the succeeding escape analysis can precisely
+analyze operations on those arrays.
+
+However, such precise "per-element" alias analysis is often hard.
+Essentially, the main difficulty inherent to arrays is that array dimensions and indices are often non-constant:
+- loops often produce loop-variant, non-constant array indices
+- (specific to vectors) array resizing changes the array dimension and invalidates its constantness
+
+Let's discuss those difficulties with concrete examples.
+
+In the following example, `EscapeAnalysis` fails to do the precise alias analysis since the index
+at the `Base.arrayset(false, %4, %8, %6)::Vector{Any}` is not (trivially) constant.
+In particular, `Any[nothing, nothing]` forms a loop and calls that `arrayset` operation in a loop,
+where `%6` is represented as a ϕ-node value (whose value is control-flow dependent).
+As a result, `ReturnEscape` ends up imposed on both `%23` (corresponding to `SafeRef(s)`) and
+`%25` (corresponding to `SafeRef(t)`), although ideally we want it to be imposed only on `%23` but not on `%25`:
+```@repl EAUtils
+code_escapes((String,String)) do s, t
+    ary = Any[nothing, nothing]
+    ary[1] = SafeRef(s)
+    ary[2] = SafeRef(t)
+    return ary[1], length(ary)
+end
+```
+
+The next example illustrates how vector resizing makes precise alias analysis hard.
+The essential difficulty is that the dimension of the allocated array `%1` is first initialized as `0`,
+but is then changed by the two `:jl_array_grow_end` calls afterwards.
+`EscapeAnalysis` currently simply gives up precise alias analysis whenever it encounters any
+array resizing operations, and so `ReturnEscape` is imposed on both `%2` (corresponding to `SafeRef(s)`)
+and `%20` (corresponding to `SafeRef(t)`):
+```@repl EAUtils
+code_escapes((String,String)) do s, t
+    ary = Any[]
+    push!(ary, SafeRef(s))
+    push!(ary, SafeRef(t))
+    ary[1], length(ary)
+end
+```
+
+In order to address these difficulties, we need inference to be aware of array dimensions
+and propagate array dimensions in a flow-sensitive way[^ArrayDimension], as well as to come
+up with a nice representation of loop-variant values.
+
+`EscapeAnalysis` at this moment quickly switches to the more imprecise analysis that doesn't
+track precise index information in cases where array dimensions or indices are trivially
+non-constant. The switch can naturally be implemented as a lattice join operation of the
+`EscapeInfo.AliasInfo` property in the data-flow analysis framework.
+
+### [Exception Handling](@id EA-Exception-Handling)
+
+It is also worth noting how `EscapeAnalysis` handles possible escapes via exceptions.
+Naively it seems enough to propagate the escape information imposed on the `:the_exception` object to
+all values that may be thrown in a corresponding `try` block.
+But there are actually several other ways to access the exception object in Julia,
+such as `Base.current_exceptions` and `rethrow`.
+For example, escape analysis needs to account for the potential escape of `r` in the example below:
+```@repl EAUtils
+const GR = Ref{Any}();
+@noinline function rethrow_escape!()
+    try
+        rethrow()
+    catch err
+        GR[] = err
+    end
+end;
+get′(x) = isassigned(x) ? x[] : throw(x);
+
+code_escapes() do
+    r = Ref{String}()
+    local t
+    try
+        t = get′(r)
+    catch err
+        t = typeof(err)   # `err` (which `r` aliases to) doesn't escape here
+        rethrow_escape!() # but `r` escapes here
+    end
+    return t
+end
+```
+
+Correctly reasoning about all possible escapes via existing exception interfaces requires
+a global analysis. For now we always conservatively propagate the topmost escape information to
+all potentially thrown objects, since such an additional analysis might not be
+worthwhile given that exception handling and error paths usually don't need to be
+very performance sensitive, and also optimizations of error paths might be very ineffective anyway
+since they are often even "unoptimized" intentionally for latency reasons.
+
+`x::EscapeInfo`'s `x.ThrownEscape` property records SSA statements where `x` can be thrown as an exception.
+Using this information `EscapeAnalysis` can propagate possible escapes via exceptions only to
+the values that may be thrown in each `try` region:
+```@repl EAUtils
+result = code_escapes((String,String)) do s1, s2
+    r1 = Ref(s1)
+    r2 = Ref(s2)
+    local ret
+    try
+        s1 = get′(r1)
+        ret = sizeof(s1)
+    catch err
+        global GV = err # will definitely escape `r1`
+    end
+    s2 = get′(r2) # still `r2` doesn't escape fully
+    return s2
+end
+```
+
+## Analysis Usage
+
+When using `EscapeAnalysis` in Julia's high-level compilation pipeline, we can run
+`analyze_escapes(ir::IRCode) -> estate::EscapeState` to analyze escape information of each SSA-IR element in `ir`.
+
+Note that it should be most effective if `analyze_escapes` runs after inlining,
+as `EscapeAnalysis`'s interprocedural escape information handling is limited at this moment.
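+
+As a rough sketch, a caller might drive the analysis as follows (`get_postinlining_ir` is a
+hypothetical helper standing in for whatever produces post-inlining `IRCode` for a method instance):
+```julia
+import Core.Compiler.EscapeAnalysis: analyze_escapes
+
+ir = get_postinlining_ir(mi) # assumption: yields `IRCode` for some `MethodInstance`
+estate = analyze_escapes(ir) # returns an `EscapeState`
+estate[Core.Argument(2)]     # `EscapeInfo` of the first (non-function) argument
+estate[Core.SSAValue(1)]     # `EscapeInfo` of the first SSA statement
+```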
+
+Since the computational cost of `analyze_escapes` is not cheap,
+it is ideal if it runs only once, with succeeding optimization passes incrementally updating
+the escape information upon each IR transformation.
+
+```@docs
+Core.Compiler.EscapeAnalysis.analyze_escapes
+Core.Compiler.EscapeAnalysis.EscapeState
+Core.Compiler.EscapeAnalysis.EscapeInfo
+```
+
+[^LatticeDesign]: Our type inference implementation takes the alternative approach,
+    where each lattice property is represented by a special lattice element type object.
+    It turned out to complicate the implementation of the lattice operations,
+    mainly because it often requires conversion rules between each lattice element type object.
+    And we are working on [overhauling our type inference lattice implementation](https://github.com/JuliaLang/julia/pull/42596)
+    with an `EscapeInfo`-like lattice design.
+
+[^MM02]: _A Graph-Free Approach to Data-Flow Analysis_.
+    Markus Mohnen, 2002, April.
+    .
+
+[^BackandForth]: Our type inference algorithm, in contrast, is implemented as a forward analysis,
+    because type information usually flows from "definition" to "usage" and it is more
+    natural and effective to propagate such information in a forward way.
+
+[^Dynamism]: In some cases, however, object fields can't be analyzed precisely.
+    For example, an object may escape to somewhere `EscapeAnalysis` can't account for possible memory effects on it,
+    or the fields of an object simply can't be known because of a lack of type information.
+    In such cases the `AliasInfo` property is raised to the topmost element within its own lattice order,
+    and that causes the succeeding field analysis to be conservative and escape information imposed on
+    the fields of an unanalyzable object to be propagated to the object itself.
+
+[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_.
+    Thomas Kotzmann and Hanspeter Mössenböck, 2005, June.
+    .
+
+[^ArrayDimension]: Otherwise we will need yet another forward data-flow analysis on top of the escape analysis.
diff --git a/doc/src/devdocs/llvm.md b/doc/src/devdocs/llvm.md
index 1e983949ea0b6..840822f136004 100644
--- a/doc/src/devdocs/llvm.md
+++ b/doc/src/devdocs/llvm.md
@@ -28,7 +28,7 @@ The difference between an intrinsic and a builtin is that a builtin is a first c
 that can be used like any other Julia function. An intrinsic can operate only on unboxed data,
 and therefore its arguments must be statically typed.
 
-### Alias Analysis
+### [Alias Analysis](@id LLVM-Alias-Analysis)
 
 Julia currently uses LLVM's [Type Based Alias Analysis](https://llvm.org/docs/LangRef.html#tbaa-metadata).
To find the comments that document the inclusion relationships, look for `static MDNode*` in diff --git a/src/dump.c b/src/dump.c index 168034d89236d..f2c8629ca9c8b 100644 --- a/src/dump.c +++ b/src/dump.c @@ -524,12 +524,14 @@ static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_ jl_serialize_value(s, codeinst->inferred); jl_serialize_value(s, codeinst->rettype_const); jl_serialize_value(s, codeinst->rettype); + jl_serialize_value(s, codeinst->argescapes); } else { // skip storing useless data jl_serialize_value(s, NULL); jl_serialize_value(s, NULL); jl_serialize_value(s, jl_any_type); + jl_serialize_value(s, jl_nothing); } write_uint8(s->s, codeinst->relocatability); jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque); @@ -1667,6 +1669,8 @@ static jl_value_t *jl_deserialize_value_code_instance(jl_serializer_state *s, jl jl_gc_wb(codeinst, codeinst->rettype_const); codeinst->rettype = jl_deserialize_value(s, &codeinst->rettype); jl_gc_wb(codeinst, codeinst->rettype); + codeinst->argescapes = jl_deserialize_value(s, &codeinst->argescapes); + jl_gc_wb(codeinst, codeinst->argescapes); if (constret) codeinst->invoke = jl_fptr_const_return; if ((flags >> 3) & 1) diff --git a/src/gf.c b/src/gf.c index 7c42a9b802df3..01d03fe77394f 100644 --- a/src/gf.c +++ b/src/gf.c @@ -207,7 +207,8 @@ JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst( jl_method_instance_t *mi, jl_value_t *rettype, jl_value_t *inferred_const, jl_value_t *inferred, int32_t const_flags, size_t min_world, size_t max_world, - uint8_t ipo_effects, uint8_t effects, uint8_t relocatability); + uint8_t ipo_effects, uint8_t effects, jl_value_t *argescapes, + uint8_t relocatability); JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT, jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED); @@ -244,7 +245,7 @@ jl_datatype_t *jl_mk_builtin_func(jl_datatype_t *dt, const char *name, jl_fptr_a jl_code_instance_t *codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, jl_nothing, jl_nothing, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); jl_mi_cache_insert(mi, codeinst); codeinst->specptr.fptr1 = fptr; codeinst->invoke = jl_fptr_args; @@ -367,7 +368,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred( } codeinst = jl_new_codeinst( mi, rettype, NULL, NULL, - 0, min_world, max_world, 0, 0, 0); + 0, min_world, max_world, 0, 0, jl_nothing, 0); jl_mi_cache_insert(mi, codeinst); return codeinst; } @@ -376,7 +377,8 @@ JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst( jl_method_instance_t *mi, jl_value_t *rettype, jl_value_t *inferred_const, jl_value_t *inferred, int32_t const_flags, size_t min_world, size_t max_world, - uint8_t ipo_effects, uint8_t effects, uint8_t relocatability + uint8_t ipo_effects, uint8_t effects, jl_value_t *argescapes, + uint8_t relocatability /*, jl_array_t *edges, int absolute_max*/) { jl_task_t *ct = jl_current_task; @@ -401,9 +403,10 @@ JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst( codeinst->isspecsig = 0; codeinst->precompile = 0; codeinst->next = NULL; - codeinst->relocatability = relocatability; codeinst->ipo_purity_bits = ipo_effects; codeinst->purity_bits = effects; + codeinst->argescapes = argescapes; + codeinst->relocatability = relocatability; return codeinst; } @@ -2013,7 +2016,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t if (unspec && jl_atomic_load_relaxed(&unspec->invoke)) { jl_code_instance_t *codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, 
NULL, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); codeinst->isspecsig = 0; codeinst->specptr = unspec->specptr; codeinst->rettype_const = unspec->rettype_const; @@ -2031,7 +2034,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t if (!jl_code_requires_compiler(src)) { jl_code_instance_t *codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, NULL, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); codeinst->invoke = jl_fptr_interpret_call; jl_mi_cache_insert(mi, codeinst); record_precompile_statement(mi); @@ -2066,7 +2069,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t return ucache; } codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, NULL, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); codeinst->isspecsig = 0; codeinst->specptr = ucache->specptr; codeinst->rettype_const = ucache->rettype_const; diff --git a/src/jltypes.c b/src/jltypes.c index f6f9db0762810..86630ac39c059 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2492,7 +2492,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_code_instance_type = jl_new_datatype(jl_symbol("CodeInstance"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(14, + jl_perm_symsvec(15, "def", "next", "min_world", @@ -2502,10 +2502,11 @@ void jl_init_types(void) JL_GC_DISABLED "inferred", //"edges", //"absolute_max", - "ipo_purity_bits", "purity_bits", + "ipo_purity_bits", "purity_bits", + "argescapes", "isspecsig", "precompile", "invoke", "specptr", // function object decls "relocatability"), - jl_svec(14, + jl_svec(15, jl_method_instance_type, jl_any_type, jl_ulong_type, @@ -2515,7 +2516,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_any_type, //jl_any_type, //jl_bool_type, - jl_uint8_type, jl_uint8_type, + jl_uint8_type, jl_uint8_type, + jl_any_type, jl_bool_type, jl_bool_type, jl_any_type, jl_any_type, // fptrs @@ -2668,8 +2670,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_svecset(jl_methtable_type->types, 11, jl_uint8_type); jl_svecset(jl_method_type->types, 12, jl_method_instance_type); jl_svecset(jl_method_instance_type->types, 6, jl_code_instance_type); - jl_svecset(jl_code_instance_type->types, 11, jl_voidpointer_type); jl_svecset(jl_code_instance_type->types, 12, jl_voidpointer_type); + jl_svecset(jl_code_instance_type->types, 13, jl_voidpointer_type); jl_compute_field_offsets(jl_datatype_type); jl_compute_field_offsets(jl_typename_type); diff --git a/src/julia.h b/src/julia.h index 20edd53ad39a7..f3905897a1202 100644 --- a/src/julia.h +++ b/src/julia.h @@ -410,6 +410,7 @@ typedef struct _jl_code_instance_t { uint8_t terminates:2; } purity_flags; }; + jl_value_t *argescapes; // escape information of call arguments // compilation state cache uint8_t isspecsig; // if specptr is a specialized function signature for specTypes->rettype diff --git a/test/choosetests.jl b/test/choosetests.jl index e00aedffdd42e..f86f665bc2217 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -142,7 +142,10 @@ function choosetests(choices = []) filtertests!(tests, "subarray") filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses", "compiler/codegen", - "compiler/inline", "compiler/contextual"]) + "compiler/inline", "compiler/contextual", + "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"]) + filtertests!(tests, "compiler/EscapeAnalysis", [ + "compiler/EscapeAnalysis/local", 
"compiler/EscapeAnalysis/interprocedural"]) filtertests!(tests, "stdlib", STDLIBS) # do ambiguous first to avoid failing if ambiguities are introduced by other tests filtertests!(tests, "ambiguous") diff --git a/test/compiler/EscapeAnalysis/EAUtils.jl b/test/compiler/EscapeAnalysis/EAUtils.jl new file mode 100644 index 0000000000000..3ae9b41a0ddac --- /dev/null +++ b/test/compiler/EscapeAnalysis/EAUtils.jl @@ -0,0 +1,385 @@ +module EAUtils + +export code_escapes, @code_escapes, __clear_cache! + +const CC = Core.Compiler +const EA = CC.EscapeAnalysis + +# entries +# ------- + +import Base: unwrap_unionall, rewrap_unionall +import InteractiveUtils: gen_call_with_extracted_types_and_kwargs + +""" + @code_escapes [options...] f(args...) + +Evaluates the arguments to the function call, determines its types, and then calls +[`code_escapes`](@ref) on the resulting expression. +As with `@code_typed` and its family, any of `code_escapes` keyword arguments can be given +as the optional arguments like `@code_escapes optimize=false myfunc(myargs...)`. +""" +macro code_escapes(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :code_escapes, ex0) +end + +""" + code_escapes(f, argtypes=Tuple{}; [debuginfo::Symbol = :none], [optimize::Bool = true]) -> result::EscapeResult + +Runs the escape analysis on optimized IR of a generic function call with the given type signature. + +# Keyword Arguments + +- `optimize::Bool = true`: + if `true` returns escape information of post-inlining IR (used for local optimization), + otherwise returns escape information of pre-inlining IR (used for interprocedural escape information generation) +- `debuginfo::Symbol = :none`: + controls the amount of code metadata present in the output, possible options are `:none` or `:source`. +""" +function code_escapes(@nospecialize(f), @nospecialize(types=Base.default_tt(f)); + world::UInt = get_world_counter(), + interp::Core.Compiler.AbstractInterpreter = Core.Compiler.NativeInterpreter(world), + debuginfo::Symbol = :none, + optimize::Bool = true) + ft = Core.Typeof(f) + if isa(types, Type) + u = unwrap_unionall(types) + tt = rewrap_unionall(Tuple{ft, u.parameters...}, types) + else + tt = Tuple{ft, types...} + end + interp = EscapeAnalyzer(interp, tt, optimize) + results = Base.code_typed_by_type(tt; optimize=true, world, interp) + isone(length(results)) || throw(ArgumentError("`code_escapes` only supports single analysis result")) + return EscapeResult(interp.ir, interp.state, interp.linfo, debuginfo===:source) +end + +# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.) 
+__clear_cache!() = empty!(GLOBAL_CODE_CACHE) + +# AbstractInterpreter +# ------------------- + +# imports +import .CC: + AbstractInterpreter, NativeInterpreter, WorldView, WorldRange, + InferenceParams, OptimizationParams, get_world_counter, get_inference_cache, code_cache, + lock_mi_inference, unlock_mi_inference, add_remark!, + may_optimize, may_compress, may_discard_trees, verbose_stmt_info +# usings +import Core: + CodeInstance, MethodInstance, CodeInfo +import .CC: + InferenceResult, OptimizationState, IRCode, copy as cccopy, + @timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, sroa_pass!, + adce_pass!, type_lift_pass!, JLOptions, verify_ir, verify_linetable +import .EA: analyze_escapes, ArgEscapeCache, EscapeInfo, EscapeState, is_ipo_profitable + +# when working outside of Core.Compiler, +# cache entire escape state for later inspection and debugging +struct EscapeCache + cache::ArgEscapeCache + state::EscapeState # preserved just for debugging purpose + ir::IRCode # preserved just for debugging purpose +end + +mutable struct EscapeAnalyzer{State} <: AbstractInterpreter + native::NativeInterpreter + cache::IdDict{InferenceResult,EscapeCache} + entry_tt + optimize::Bool + ir::IRCode + state::State + linfo::MethodInstance + EscapeAnalyzer(native::NativeInterpreter, @nospecialize(tt), optimize::Bool) = + new{EscapeState}(native, IdDict{InferenceResult,EscapeCache}(), tt, optimize) +end + +CC.InferenceParams(interp::EscapeAnalyzer) = InferenceParams(interp.native) +CC.OptimizationParams(interp::EscapeAnalyzer) = OptimizationParams(interp.native) +CC.get_world_counter(interp::EscapeAnalyzer) = get_world_counter(interp.native) + +CC.lock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing +CC.unlock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing + +CC.add_remark!(interp::EscapeAnalyzer, sv, s) = add_remark!(interp.native, sv, s) + +CC.may_optimize(interp::EscapeAnalyzer) = may_optimize(interp.native) +CC.may_compress(interp::EscapeAnalyzer) = may_compress(interp.native) +CC.may_discard_trees(interp::EscapeAnalyzer) = may_discard_trees(interp.native) +CC.verbose_stmt_info(interp::EscapeAnalyzer) = verbose_stmt_info(interp.native) + +CC.get_inference_cache(interp::EscapeAnalyzer) = get_inference_cache(interp.native) + +const GLOBAL_CODE_CACHE = IdDict{MethodInstance,CodeInstance}() + +function CC.code_cache(interp::EscapeAnalyzer) + worlds = WorldRange(get_world_counter(interp)) + return WorldView(GlobalCache(), worlds) +end + +struct GlobalCache end + +CC.haskey(wvc::WorldView{GlobalCache}, mi::MethodInstance) = haskey(GLOBAL_CODE_CACHE, mi) + +CC.get(wvc::WorldView{GlobalCache}, mi::MethodInstance, default) = get(GLOBAL_CODE_CACHE, mi, default) + +CC.getindex(wvc::WorldView{GlobalCache}, mi::MethodInstance) = getindex(GLOBAL_CODE_CACHE, mi) + +function CC.setindex!(wvc::WorldView{GlobalCache}, ci::CodeInstance, mi::MethodInstance) + GLOBAL_CODE_CACHE[mi] = ci + add_callback!(mi) # register the callback on invalidation + return nothing +end + +function add_callback!(linfo) + if !isdefined(linfo, :callbacks) + linfo.callbacks = Any[invalidate_cache!] + else + if !any(@nospecialize(cb)->cb===invalidate_cache!, linfo.callbacks) + push!(linfo.callbacks, invalidate_cache!) 
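+            # NOTE `invalidate_cache!` (defined below) then evicts this entry and
+            # its backedges from `GLOBAL_CODE_CACHE` when the method is invalidated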
+        end
+    end
+    return nothing
+end
+
+function invalidate_cache!(replaced, max_world, depth = 0)
+    delete!(GLOBAL_CODE_CACHE, replaced)
+
+    if isdefined(replaced, :backedges)
+        for mi in replaced.backedges
+            mi = mi::MethodInstance
+            if !haskey(GLOBAL_CODE_CACHE, mi)
+                continue # otherwise fall into infinite loop
+            end
+            invalidate_cache!(mi, max_world, depth+1)
+        end
+    end
+    return nothing
+end
+
+function CC.optimize(interp::EscapeAnalyzer,
+    opt::OptimizationState, params::OptimizationParams, caller::InferenceResult)
+    ir = run_passes_with_ea(interp, opt.src, opt, caller)
+    return CC.finish(interp, opt, params, ir, caller)
+end
+
+function CC.cache_result!(interp::EscapeAnalyzer, caller::InferenceResult)
+    if haskey(interp.cache, caller)
+        GLOBAL_ESCAPE_CACHE[caller.linfo] = interp.cache[caller]
+    end
+    return Base.@invoke CC.cache_result!(interp::AbstractInterpreter, caller::InferenceResult)
+end
+
+const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,EscapeCache}()
+
+"""
+    cache_escapes!(interp::EscapeAnalyzer, caller::InferenceResult, estate::EscapeState, cacheir::IRCode)
+
+Transforms the escape information of `caller`'s call arguments,
+and then caches it for later interprocedural propagation.
+"""
+function cache_escapes!(interp::EscapeAnalyzer,
+    caller::InferenceResult, estate::EscapeState, cacheir::IRCode)
+    cache = ArgEscapeCache(estate)
+    ecache = EscapeCache(cache, estate, cacheir)
+    interp.cache[caller] = ecache
+    return cache
+end
+
+function get_escape_cache(interp::EscapeAnalyzer)
+    return function (linfo::Union{InferenceResult,MethodInstance})
+        if isa(linfo, InferenceResult)
+            ecache = get(interp.cache, linfo, nothing)
+        else
+            ecache = get(GLOBAL_ESCAPE_CACHE, linfo, nothing)
+        end
+        return ecache !== nothing ? ecache.cache : nothing
+    end
+end
+
+function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::OptimizationState,
+    caller::InferenceResult)
+    @timeit "convert" ir = convert_to_ircode(ci, sv)
+    @timeit "slot2reg" ir = slot2reg(ir, ci, sv)
+    # TODO: Domsorting can produce an updated domtree - no need to recompute here
+    @timeit "compact 1" ir = compact!(ir)
+    nargs = let def = sv.linfo.def; isa(def, Method) ?
Int(def.nargs) : 0; end
+    local state
+    if is_ipo_profitable(ir, nargs) || caller.linfo.specTypes === interp.entry_tt
+        try
+            @timeit "[IPO EA]" begin
+                state = analyze_escapes(ir, nargs, false, get_escape_cache(interp))
+                cache_escapes!(interp, caller, state, cccopy(ir))
+            end
+        catch err
+            @error "error happened within [IPO EA], inspect `Main.ir` and `Main.nargs`"
+            @eval Main (ir = $ir; nargs = $nargs)
+            rethrow(err)
+        end
+    end
+    if caller.linfo.specTypes === interp.entry_tt && !interp.optimize
+        # hand the result back to `code_escapes`
+        interp.ir = cccopy(ir)
+        interp.state = state
+        interp.linfo = sv.linfo
+    end
+    @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
+    # @timeit "verify 2" verify_ir(ir)
+    @timeit "compact 2" ir = compact!(ir)
+    if caller.linfo.specTypes === interp.entry_tt && interp.optimize
+        try
+            @timeit "[Local EA]" state = analyze_escapes(ir, nargs, true, get_escape_cache(interp))
+        catch err
+            @error "error happened within [Local EA], inspect `Main.ir` and `Main.nargs`"
+            @eval Main (ir = $ir; nargs = $nargs)
+            rethrow(err)
+        end
+        # hand the result back to `code_escapes`
+        interp.ir = cccopy(ir)
+        interp.state = state
+        interp.linfo = sv.linfo
+    end
+    @timeit "SROA" ir = sroa_pass!(ir)
+    @timeit "ADCE" ir = adce_pass!(ir)
+    @timeit "type lift" ir = type_lift_pass!(ir)
+    @timeit "compact 3" ir = compact!(ir)
+    if JLOptions().debug_level == 2
+        @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
+    end
+    return ir
+end
+
+# printing
+# --------
+
+import Core: Argument, SSAValue
+import .CC: widenconst, singleton_type
+
+Base.getindex(estate::EscapeState, @nospecialize(x)) = CC.getindex(estate, x)
+
+function get_name_color(x::EscapeInfo, symbol::Bool = false)
+    getname(x) = string(nameof(x))
+    if x === EA.⊥
+        name, color = (getname(EA.NotAnalyzed), "◌"), :plain
+    elseif EA.has_no_escape(EA.ignore_argescape(x))
+        if EA.has_arg_escape(x)
+            name, color = (getname(EA.ArgEscape), "✓"), :cyan
+        else
+            name, color = (getname(EA.NoEscape), "✓"), :green
+        end
+    elseif EA.has_all_escape(x)
+        name, color = (getname(EA.AllEscape), "X"), :red
+    elseif EA.has_return_escape(x)
+        name = (getname(EA.ReturnEscape), "↑")
+        color = EA.has_thrown_escape(x) ? :yellow : :blue
+    else
+        name = (nothing, "*")
+        color = EA.has_thrown_escape(x) ? :yellow : :bold
+    end
+    name = symbol ? last(name) : first(name)
+    if name !== nothing && !isa(x.AliasInfo, Bool)
+        name = string(name, "′")
+    end
+    return name, color
+end
+
+# pcs = sprint(show, collect(x.EscapeSites); context=:limit=>true)
+function Base.show(io::IO, x::EscapeInfo)
+    name, color = get_name_color(x)
+    if isnothing(name)
+        Base.@invoke show(io::IO, x::Any)
+    else
+        printstyled(io, name; color)
+    end
+end
+function Base.show(io::IO, ::MIME"application/prs.juno.inline", x::EscapeInfo)
+    name, color = get_name_color(x)
+    if isnothing(name)
+        return x # use fancy tree-view
+    else
+        printstyled(io, name; color)
+    end
+end
+
+struct EscapeResult
+    ir::IRCode
+    state::EscapeState
+    linfo::Union{Nothing,MethodInstance}
+    source::Bool
+    function EscapeResult(ir::IRCode, state::EscapeState,
+                          linfo::Union{Nothing,MethodInstance} = nothing,
+                          source::Bool=false)
+        return new(ir, state, linfo, source)
+    end
+end
+Base.show(io::IO, result::EscapeResult) = print_with_info(io, result)
+@eval Base.iterate(res::EscapeResult, state=1) =
+    return state > $(fieldcount(EscapeResult)) ? nothing : (getfield(res, state), state+1)
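+
+# NOTE the `Base.iterate` overload above is a convenience so that a result can be
+# destructured positionally, e.g. `ir, state, linfo, source = result`
+# (a usage sketch; the fields are yielded in declaration order)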
+
+Base.show(io::IO, cached::EscapeCache) = show(io, EscapeResult(cached.ir, cached.state, nothing))
+
+# adapted from https://github.com/JuliaDebug/LoweredCodeUtils.jl/blob/4612349432447e868cf9285f647108f43bd0a11c/src/codeedges.jl#L881-L897
+function print_with_info(io::IO, (; ir, state, linfo, source)::EscapeResult)
+    # print escape information on call arguments
+    function preprint(io::IO)
+        ft = ir.argtypes[1]
+        f = singleton_type(ft)
+        if f === nothing
+            f = widenconst(ft)
+        end
+        print(io, f, '(')
+        for i in 1:state.nargs
+            arg = state[Argument(i)]
+            i == 1 && continue
+            c, color = get_name_color(arg, true)
+            printstyled(io, c, ' ', '_', i, "::", ir.argtypes[i]; color)
+            i ≠ state.nargs && print(io, ", ")
+        end
+        print(io, ')')
+        if !isnothing(linfo)
+            def = linfo.def
+            printstyled(io, " in ", (isa(def, Module) ? (def,) : (def.module, " at ", def.file, ':', def.line))...; color=:bold)
+        end
+        println(io)
+    end
+
+    # print escape information on SSA values
+    # nd = ndigits(length(ssavalues))
+    function preprint(io::IO, idx::Int)
+        c, color = get_name_color(state[SSAValue(idx)], true)
+        # printstyled(io, lpad(idx, nd), ' ', c, ' '; color)
+        printstyled(io, rpad(c, 2), ' '; color)
+    end
+
+    print_with_info(preprint, (args...)->nothing, io, ir, source)
+end
+
+function print_with_info(preprint, postprint, io::IO, ir::IRCode, source::Bool)
+    io = IOContext(io, :displaysize=>displaysize(io))
+    used = Base.IRShow.stmts_used(io, ir)
+    if source
+        line_info_preprinter = function (io::IO, indent::String, idx::Int)
+            r = Base.IRShow.inline_linfo_printer(ir)(io, indent, idx)
+            idx ≠ 0 && preprint(io, idx)
+            return r
+        end
+    else
+        line_info_preprinter = Base.IRShow.lineinfo_disabled
+    end
+    line_info_postprinter = Base.IRShow.default_expr_type_printer
+    preprint(io)
+    bb_idx_prev = bb_idx = 1
+    for idx = 1:length(ir.stmts)
+        preprint(io, idx)
+        bb_idx = Base.IRShow.show_ir_stmt(io, ir, idx, line_info_preprinter, line_info_postprinter, used, ir.cfg, bb_idx)
+        postprint(io, idx, bb_idx != bb_idx_prev)
+        bb_idx_prev = bb_idx
+    end
+    max_bb_idx_size = ndigits(length(ir.cfg.blocks))
+    line_info_preprinter(io, " "^(max_bb_idx_size + 2), 0)
+    postprint(io)
+    return nothing
+end
+
+end # module EAUtils
diff --git a/test/compiler/EscapeAnalysis/interprocedural.jl b/test/compiler/EscapeAnalysis/interprocedural.jl
new file mode 100644
index 0000000000000..eccdc710a6c12
--- /dev/null
+++ b/test/compiler/EscapeAnalysis/interprocedural.jl
@@ -0,0 +1,264 @@
+# IPO EA Test
+# ===========
+# EA works on pre-inlining IR
+
+include(normpath(@__DIR__, "setup.jl"))
+
+# callsites
+# ---------
+
+import .EA: ignore_argescape
+
+noescape(a) = nothing
+noescape(a, b) = nothing
+function global_escape!(x)
+    GR[] = x
+    return nothing
+end
+union_escape!(x) = global_escape!(x)
+union_escape!(x::SafeRef) = nothing
+union_escape!(x::SafeRefs) = nothing
+Base.@constprop :aggressive function conditional_escape!(cnd, x)
+    cnd && global_escape!(x)
+    return nothing
+end
+
+# MethodMatchInfo -- global cache
+let result = code_escapes((SafeRef{String},); optimize=false) do x
+        return noescape(x)
+    end
+    @test has_no_escape(ignore_argescape(result.state[Argument(2)]))
+end
+let result = code_escapes((SafeRef{String},); optimize=false) do x
+        identity(x)
+        return nothing
+    end
+    @test has_no_escape(ignore_argescape(result.state[Argument(2)]))
+end
+let result = code_escapes((SafeRef{String},); optimize=false) do x
+        return identity(x)
+    end
+    r = only(findall(isreturn, result.ir.stmts.inst))
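+    # `identity` merely forwards its argument: the return value aliases `x`, so
+    # IPO EA should report ReturnEscape for `x` at the return site `r`
+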
@test has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + return Ref(x) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + r = Ref{SafeRef{String}}() + r[] = x + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + global_escape!(x) + end + @test has_all_escape(result.state[Argument(2)]) +end +# UnionSplitInfo +let result = code_escapes((Bool,Vector{Any}); optimize=false) do c, s + x = c ? s : SafeRef(s) + union_escape!(x) + end + @test has_all_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Bool,Vector{Any}); optimize=false) do c, s + x = c ? SafeRef(s) : SafeRefs(s, s) + union_escape!(x) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end +# ConstCallInfo -- local cache +let result = code_escapes((SafeRef{String},); optimize=false) do x + return conditional_escape!(false, x) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end +# InvokeCallInfo +let result = code_escapes((SafeRef{String},); optimize=false) do x + return Base.@invoke noescape(x::Any) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + return Base.@invoke conditional_escape!(false::Any, x::Any) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end + +# MethodError +# ----------- +# accounts for ThrownEscape via potential MethodError + +# no method error +identity_if_string(x::SafeRef) = nothing +let result = code_escapes((SafeRef{String},); optimize=false) do x + identity_if_string(x) + end + i = only(findall(iscall((result.ir, identity_if_string)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], i) + @test !has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((Union{SafeRef{String},Vector{String}},); optimize=false) do x + identity_if_string(x) + end + i = only(findall(iscall((result.ir, identity_if_string)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], i) + @test !has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + try + identity_if_string(x) + catch err + global GV = err + end + return nothing + end + @test !has_all_escape(result.state[Argument(2)]) +end +let result = code_escapes((Union{SafeRef{String},Vector{String}},); optimize=false) do x + try + identity_if_string(x) + catch err + global GV = err + end + return nothing + end + @test has_all_escape(result.state[Argument(2)]) +end +# method ambiguity error +ambig_error_test(a::SafeRef, b) = nothing +ambig_error_test(a, b::SafeRef) = nothing +ambig_error_test(a, b) = nothing +let result = code_escapes((SafeRef{String},Any); optimize=false) do x, y + ambig_error_test(x, y) + end + i = only(findall(iscall((result.ir, ambig_error_test)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], i) # x + @test has_thrown_escape(result.state[Argument(3)], i) # y + @test 
!has_return_escape(result.state[Argument(2)], r) # x + @test !has_return_escape(result.state[Argument(3)], r) # y +end +let result = code_escapes((SafeRef{String},Any); optimize=false) do x, y + try + ambig_error_test(x, y) + catch err + global GV = err + end + end + @test has_all_escape(result.state[Argument(2)]) # x + @test has_all_escape(result.state[Argument(3)]) # y +end + +# Local EA integration +# -------------------- + +# propagate escapes imposed on call arguments + +# FIXME handle _apply_iterate +# FIXME currently we can't prove the effect-freeness of `getfield(RefValue{String}, :x)` +# because of this check https://github.com/JuliaLang/julia/blob/94b9d66b10e8e3ebdb268e4be5f7e1f43079ad4e/base/compiler/tfuncs.jl#L745 +# and thus it leads to the following two broken tests + +@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) +let result = code_escapes() do + broadcast_noescape1(Ref("Hi")) + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + @test_broken !has_thrown_escape(result.state[SSAValue(i)]) +end +@noinline broadcast_noescape2(b) = broadcast(identity, b) +let result = code_escapes() do + broadcast_noescape2(Ref("Hi")) + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + @test_broken !has_thrown_escape(result.state[SSAValue(i)]) +end +@noinline allescape_argument(a) = (global GV = a) # obvious escape +let result = code_escapes() do + allescape_argument(Ref("Hi")) + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) +end +# if we can't determine the matching method statically, we should be conservative +let result = code_escapes((Ref{Any},)) do a + may_exist(a) + end + @test has_all_escape(result.state[Argument(2)]) +end +let result = code_escapes((Ref{Any},)) do a + Base.@invokelatest broadcast_noescape1(a) + end + @test has_all_escape(result.state[Argument(2)]) +end + +# handling of simple union-split (just exploit the inliner's effort) +@noinline unionsplit_noescape(a) = string(nothing) +@noinline unionsplit_noescape(a::Int) = a + 10 +let result = code_escapes((Union{Int,Nothing},)) do x + s = SafeRef{Union{Int,Nothing}}(x) + unionsplit_noescape(s[]) + return nothing + end + inds = findall(isnew, result.ir.stmts.inst) # find allocation statement + @assert !isempty(inds) + for i in inds + @test has_no_escape(result.state[SSAValue(i)]) + end +end + +@noinline function unused_argument(a) + println("prevent inlining") + return Base.inferencebarrier(nothing) +end +let result = code_escapes() do + a = Ref("foo") # shouldn't be "return escape" + b = unused_argument(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + + result = code_escapes() do + a = Ref("foo") # still should be "return escape" + b = unused_argument(a) + return a + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) +end + +# should propagate escape information imposed on return value to the aliased call argument +@noinline returnescape_argument(a) = (println("prevent inlining"); a) +let result = code_escapes() do + obj = Ref("foo") # should be "return escape" + ret = returnescape_argument(obj) + return ret # alias of `obj` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = 
only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) +end +@noinline noreturnescape_argument(a) = (println("prevent inlining"); identity("hi")) +let result = code_escapes() do + obj = Ref("foo") # better to not be "return escape" + ret = noreturnescape_argument(obj) + return ret # must not alias to `obj` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) +end diff --git a/test/compiler/EscapeAnalysis/local.jl b/test/compiler/EscapeAnalysis/local.jl new file mode 100644 index 0000000000000..e5d8f1bf2c940 --- /dev/null +++ b/test/compiler/EscapeAnalysis/local.jl @@ -0,0 +1,2206 @@ +# Local EA Test +# ============= +# EA works on post-inlining IR + +include(normpath(@__DIR__, "setup.jl")) + +@testset "basics" begin + let # arg return + result = code_escapes((Any,)) do a # return to caller + return nothing + end + @test has_arg_escape(result.state[Argument(2)]) + # return + result = code_escapes((Any,)) do a + return a + end + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_arg_escape(result.state[Argument(1)]) # self + @test !has_return_escape(result.state[Argument(1)], i) # self + @test has_arg_escape(result.state[Argument(2)]) # a + @test has_return_escape(result.state[Argument(2)], i) # a + end + let # global store + result = code_escapes((Any,)) do a + global GV = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + let # global load + result = code_escapes() do + global GV + return GV + end + i = only(findall(has_return_escape, map(i->result.state[SSAValue(i)], 1:length(result.ir.stmts)))) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # global store / load (https://github.com/aviatesk/EscapeAnalysis.jl/issues/56) + result = code_escapes((Any,)) do s + global GV + GV = s + return GV + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + end + let # :gc_preserve_begin / :gc_preserve_end + result = code_escapes((String,)) do s + m = SafeRef(s) + GC.@preserve m begin + return nothing + end + end + i = findfirst(isT(SafeRef{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # :isdefined + result = code_escapes((String, Bool, )) do a, b + if b + s = Ref(a) + end + return @isdefined(s) + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # ϕ-node + result = code_escapes((Bool,Any,Any)) do cond, a, b + c = cond ? 
a : b # ϕ(a, b) + return c + end + @assert any(@nospecialize(x)->isa(x, Core.PhiNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], i) # a + @test has_return_escape(result.state[Argument(4)], i) # b + end + let # π-node + result = code_escapes((Any,)) do a + if isa(a, Regex) # a::π(Regex) + return a + end + return nothing + end + @assert any(@nospecialize(x)->isa(x, Core.PiNode), result.ir.stmts.inst) + @test any(findall(isreturn, result.ir.stmts.inst)) do i + has_return_escape(result.state[Argument(2)], i) + end + end + let # φᶜ-node / ϒ-node + result = code_escapes((Any,String)) do a, b + local x::String + try + x = a + catch err + x = b + end + return x + end + @assert any(@nospecialize(x)->isa(x, Core.PhiCNode), result.ir.stmts.inst) + @assert any(@nospecialize(x)->isa(x, Core.UpsilonNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], i) + @test has_return_escape(result.state[Argument(3)], i) + end + let # branching + result = code_escapes((Any,Bool,)) do a, c + if c + return nothing # a doesn't escape in this branch + else + return a # a escapes to a caller + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # loop + result = code_escapes((Int,)) do n + c = SafeRef{Bool}(false) + while n > 0 + rand(Bool) && return c + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + let # try/catch + result = code_escapes((Any,)) do a + try + nothing + catch err + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + try + nothing + finally + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # :foreigncall + result = code_escapes((Any,)) do x + ccall(:some_ccall, Any, (Any,), x) + end + @test has_all_escape(result.state[Argument(2)]) + end +end + +let # simple allocation + result = code_escapes((Bool,)) do c + mm = SafeRef{Bool}(c) # just allocated, never escapes + return mm[] ? nothing : 1 + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) +end + +@testset "builtins" begin + let # throw + r = code_escapes((Any,)) do a + throw(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # implicit throws + r = code_escapes((Any,)) do a + getfield(a, :may_not_field) + end + @test has_thrown_escape(r.state[Argument(2)]) + + r = code_escapes((Any,)) do a + sizeof(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # :=== + result = code_escapes((Bool, SafeRef{String})) do cond, s + m = cond ? 
s : nothing + c = m === nothing + return c + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) + end + + let # sizeof + result = code_escapes((Vector{Any},)) do xs + sizeof(xs) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) + end + + let # ifelse + result = code_escapes((Bool,)) do c + r = ifelse(c, Ref("yes"), Ref("no")) + return r + end + inds = findall(isnew, result.ir.stmts.inst) + @assert !isempty(inds) + for i in inds + @test has_return_escape(result.state[SSAValue(i)]) + end + end + let # ifelse (with constant condition) + result = code_escapes() do + r = ifelse(true, Ref("yes"), Ref(nothing)) + return r + end + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)]) + elseif isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{Nothing})(result.ir.stmts.type[i]) + @test has_no_escape(result.state[SSAValue(i)]) + end + end + end + + let # typeassert + result = code_escapes((Any,)) do x + y = x::String + return y + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + end + + let # isdefined + result = code_escapes((Any,)) do x + isdefined(x, :foo) ? x : throw("undefined") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + + result = code_escapes((Module,)) do m + isdefined(m, 10) # throws + end + @test has_thrown_escape(result.state[Argument(2)]) + end +end + +@testset "flow-sensitivity" begin + # ReturnEscape + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + if cond + return cond + end + return r + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) + @assert length(rts) == 2 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 1 + end + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + cnt = 0 + while rand(Bool) + cnt += 1 + rand(Bool) && return r + end + rand(Bool) && return r + return cnt + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) # return statement + @assert length(rts) == 3 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 2 + end +end + +@testset "escape through exceptions" begin + M = @eval Module() begin + unsafeget(x) = isassigned(x) ? 
x[] : throw(x) + @noinline function escape_rethrow!() + try + rethrow() + catch err + GR[] = err + end + end + @noinline function escape_current_exceptions!() + excs = Base.current_exceptions() + GR[] = excs + end + const GR = Ref{Any}() + @__MODULE__ + end + + let # simple: return escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err + ret = err + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + + let # simple: global escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret # prevent DCE + try + s = unsafeget(r) + ret = sizeof(s) + catch err + global GV = err + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # account for possible escapes via nested throws + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + throw(err1) + end + catch err2 + GR[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + rethrow(err1) + end + catch err2 + GR[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + escape_rethrow!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + local t + try + r = Ref{String}() + t = unsafeget(r) + catch err + t = typeof(err) + escape_rethrow!() + end + return t + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + GR[] = Base.current_exceptions() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + escape_current_exceptions!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # contextual: escape information imposed on `err` shouldn't propagate to `r2`, but only to `r1` + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err + global GV = err + end + s2 = unsafeget(r2) + return s2, r2 + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test !has_all_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + end + + # XXX test cases below are currently broken because of the technical reason described in `escape_exception!` + + let # limited propagation: exception is caught within a frame => doesn't escape to a caller + result 
= @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            catch
+                ret = nothing
+            end
+            return ret
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test_broken !has_return_escape(result.state[SSAValue(i)], r)
+    end
+    let # sequential: escape information imposed on `err1` and `err2` should propagate separately
+        result = @eval M $code_escapes() do
+            r1 = Ref{String}()
+            r2 = Ref{String}()
+            local ret
+            try
+                s1 = unsafeget(r1)
+                ret = sizeof(s1)
+            catch err1
+                global GV = err1
+            end
+            try
+                s2 = unsafeget(r2)
+                ret = sizeof(s2)
+            catch err2
+                ret = err2
+            end
+            return ret
+        end
+        is = findall(isnew, result.ir.stmts.inst)
+        @test length(is) == 2
+        i1, i2 = is
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test has_all_escape(result.state[SSAValue(i1)])
+        @test has_return_escape(result.state[SSAValue(i2)], r)
+        @test_broken !has_all_escape(result.state[SSAValue(i2)])
+    end
+    let # nested: escape information imposed on `inner` shouldn't propagate to `s`
+        result = @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                try
+                    ret = sizeof(s)
+                catch inner
+                    return inner
+                end
+            catch outer
+                ret = nothing
+            end
+            return ret
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        @test_broken !has_return_escape(result.state[SSAValue(i)])
+    end
+    let # merge: escape information imposed on `err1` and `err2` should be merged
+        result = @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            catch err1
+                return err1
+            end
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            catch err2
+                return err2
+            end
+            nothing
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        rs = findall(isreturn, result.ir.stmts.inst)
+        @test_broken !has_all_escape(result.state[SSAValue(i)])
+        for r in rs
+            @test has_return_escape(result.state[SSAValue(i)], r)
+        end
+    end
+    let # no exception handling: should keep propagating the escape
+        result = @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            finally
+                if !@isdefined(ret)
+                    ret = 42
+                end
+            end
+            return ret
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test_broken !has_return_escape(result.state[SSAValue(i)], r)
+    end
+end
+
+@testset "field analysis / alias analysis" begin
+    # escaped allocations
+    # -------------------
+
+    # escaped object should escape its fields as well
+    let result = code_escapes((Any,)) do a
+            global GV = SafeRef{Any}(a)
+            nothing
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        @test has_all_escape(result.state[SSAValue(i)])
+        @test has_all_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do a
+            global GV = (a,)
+            nothing
+        end
+        i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst))
+        @test has_all_escape(result.state[SSAValue(i)])
+        @test has_all_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do a
+            o0 = SafeRef{Any}(a)
+            global GV = SafeRef(o0)
+            nothing
+        end
+        is = findall(isnew, result.ir.stmts.inst)
+        @test length(is) == 2
+        i0, i1 = is
+        @test has_all_escape(result.state[SSAValue(i0)])
+        @test has_all_escape(result.state[SSAValue(i1)])
+        @test has_all_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do a
+            t0 = (a,)
+            global GV = (t0,)
+            nothing
+        end
+        inds = findall(iscall((result.ir, tuple)), result.ir.stmts.inst)
+        @assert
length(inds) == 2 + for i in inds; @test has_all_escape(result.state[SSAValue(i)]); end + @test has_all_escape(result.state[Argument(2)]) + end + # global escape through `setfield!` + let result = code_escapes((Any,)) do a + r = SafeRef{Any}(:init) + global GV = r + r[] = a + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + global GV = r + r[] = b + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) # a + @test has_all_escape(result.state[Argument(3)]) # b + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + Rx[] = s + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + setfield!(Rx, :x, s) + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let M = EATModule() + @eval M module ___xxx___ + import ..SafeRef + const Rx = SafeRef("Rx") + end + result = @eval M begin + $code_escapes((String,)) do s + rx = getfield(___xxx___, :Rx) + rx[] = s + nothing + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # field escape + # ------------ + + # field escape should propagate to :new arguments + let result = code_escapes((String,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String,)) do a + t = SafeRef((a,)) + f = t[][1] + return f + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + result.state[SSAValue(i)].AliasInfo + end + let result = code_escapes((String, String)) do a, b + obj = SafeRefs(a, b) + fld1 = obj[1] + fld2 = obj[2] + return (fld1, fld2) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # field escape should propagate to `setfield!` argument + let result = code_escapes((String,)) do a + o = SafeRef("foo") + o[] = a + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # propagate escape information imposed on return value of `setfield!` call + let result = code_escapes((String,)) do a + obj = SafeRef("foo") + return (obj[] = a) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # nested allocations + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + return o2[] + end + 
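        # the inner `SafeRef` escapes through the returned field load, while the
+        # outer wrapper `o2` can still be load-forwarded, as asserted below
+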
r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = (a,) + o2 = (o1,) + return o2[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + o1′ = o2[] + a′ = o1′[] + return a′ + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2 = SafeRef(o1) + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isnew, result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2′ = SafeRef(nothing) + o2 = SafeRef{SafeRef}(o2′) + o2[] = o1 + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + findall(1:length(result.ir.stmts)) do i + if isnew(result.ir.stmts[i][:inst]) + t = result.ir.stmts[i][:type] + return t === SafeRef{String} || # o1 + t === SafeRef{SafeRef} # o2 + end + return false + end |> x->foreach(x) do i + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes((String,)) do x + broadcast(identity, Ref(x)) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # ϕ-node allocations + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ = SafeRef{Any}(x) + else + ϕ = SafeRef{Any}(y) + end + return ϕ[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + i = only(findall(isϕ, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = SafeRef{Any}(x) + else + ϕ2 = ϕ1 = SafeRef{Any}(y) + end + return ϕ1[], ϕ2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # when ϕ-node merges values with different 
types + let result = code_escapes((Bool,String,String,String)) do cond, x, y, z + local out + if cond + ϕ = SafeRef(x) + out = ϕ[] + else + ϕ = SafeRefs(z, y) + end + return @isdefined(out) ? out : throw(ϕ) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + ϕ = only(findall(isT(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test !has_return_escape(result.state[Argument(4)], r) # y + @test has_return_escape(result.state[Argument(5)], r) # z + @test has_thrown_escape(result.state[SSAValue(ϕ)], t) + end + + # alias analysis + # -------------- + + # alias via getfield & Expr(:new) + let result = code_escapes((String,)) do s + r = SafeRef(s) + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((String,)) do s + r = SafeRef(Rx) + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(2)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via getfield & setfield! 
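+    # (a value stored via `setfield!` and loaded back via `getfield` should be
+    # recognized as the very same object, so the tests below assert `isaliased`
+    # between the stored argument and the returned load)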
+ let result = code_escapes((String,)) do s + r = Ref{String}() + r[] = s + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref(s) + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + r1[] = s + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r1[] = s + r2[] = r1 + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((SafeRef{String}, String,)) do _rx, s + r = SafeRef(_rx) + r[] = Rx + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(3)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via typeassert + let result = code_escapes((Any,)) do a + r = a::String + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # a + @test isaliased(Argument(2), val, result.state) # a <-> r + end + let result = code_escapes((Any,)) do a + global GV + (g::SafeRef{Any})[] = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + # alias via ifelse + let result = code_escapes((Bool,Any,Any)) do c, a, b + r = ifelse(c, a, b) + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # a + @test has_return_escape(result.state[Argument(4)], r) # b + @test !isaliased(Argument(2), val, result.state) # c r + @test isaliased(Argument(3), val, result.state) # a <-> r + @test isaliased(Argument(4), val, result.state) # b <-> r + end + let result = @eval EATModule() begin + const Lx, Rx = SafeRef("Lx"), SafeRef("Rx") + $code_escapes((Bool,String,)) do c, a + r = ifelse(c, Lx, Rx) + r[] = a + nothing + end + end + @test has_all_escape(result.state[Argument(3)]) # a + end + # alias via ϕ-node + let result = code_escapes((Bool,String)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = 
(result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # x + @test isaliased(Argument(3), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x + if cond1 + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + cond2 && (ϕ2[] = x) + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(4)], r) # x + @test isaliased(Argument(4), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # alias via π-node + let result = code_escapes((Any,)) do x + if isa(x, String) + return x + end + throw("error!") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + rval = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # x + @test isaliased(Argument(2), rval, result.state) + end + let result = code_escapes((String,)) do x + global GV + l = g + if isa(l, SafeRef{String}) + l[] = x + end + nothing + end + @test has_all_escape(result.state[Argument(2)]) # x + end + # circular reference + let result = code_escapes() do + x = Ref{Any}() + x[] = x + return x[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + const Rx = Ref{Any}() + Rx[] = Rx + $code_escapes() do + r = Rx[]::Base.RefValue{Any} + return r[] + end + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(iscall((result.ir, getfield)), result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = @eval Module() begin + @noinline function genr() + r = Ref{Any}() + r[] = r + return r + end + $code_escapes() do + x = genr() + return x[] + end + end + i = only(findall(isinvoke(:genr), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + + # dynamic semantics + # ----------------- + + # conservatively handle untyped objects + let result = @eval code_escapes((Any,Any,)) do T, x + obj = $(Expr(:new, :T, :x)) + end + t = only(findall(isnew, result.ir.stmts.inst)) + @test #=T=# has_thrown_escape(result.state[Argument(2)], t) # T + @test #=x=# has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x, :y)) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x)) + setfield!(obj, :x, y) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, 
result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + + # conservatively handle unknown field: + # all fields should be escaped, but the allocation itself doesn't need to be escaped + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRef(a) + return getfield(obj, fld) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs(a, b) + return getfield(obj, fld) # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs(a, b) + return obj[idx] # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[1] # this should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs("a", "b") + obj[idx] = a + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + + # interprocedural + # --------------- + + let result = @eval EATModule() begin + @noinline getx(obj) = obj[] + $code_escapes((String,)) do a + obj = SafeRef(a) + fld = getx(obj) + return fld + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + end + + # TODO interprocedural alias analysis + let result = code_escapes((SafeRef{String},)) do s + s[] = "bar" + global GV = s[] + nothing + end + @test_broken !has_all_escape(result.state[Argument(2)]) + 
end + + # aliasing between arguments + let result = @eval EATModule() begin + @noinline setxy!(x, y) = x[] = y + $code_escapes((String,)) do y + x = SafeRef("init") + setxy!(x, y) + return x + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test has_return_escape(result.state[Argument(2)], r) # y + end + let result = @eval EATModule() begin + @noinline setxy!(x, y) = x[] = y + $code_escapes((String,)) do y + x1 = SafeRef("init") + x2 = SafeRef(y) + setxy!(x1, x2[]) + return x1 + end + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i1)], r) + @test !has_return_escape(result.state[SSAValue(i2)], r) + @test has_return_escape(result.state[Argument(2)], r) # y + end + let result = @eval EATModule() begin + @noinline mysetindex!(x, a) = x[1] = a + const Ax = Vector{Any}(undef, 1) + $code_escapes((String,)) do s + mysetindex!(Ax, s) + end + end + @test has_all_escape(result.state[Argument(2)]) # s + end + + # TODO flow-sensitivity? + # ---------------------- + + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(:init) + r[] = a + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any,Bool)) do a, b, cond + r = SafeRef{Any}(:init) + if cond + r[] = a + return r[] + else + r[] = b + return nothing + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + r = only(findall(result.ir.stmts.inst) do @nospecialize x + isreturn(x) && isa(x.val, Core.SSAValue) + end) + @test has_return_escape(result.state[Argument(2)], r) # a + @test_broken !has_return_escape(result.state[Argument(3)], r) # b + end + + # handle conflicting field information correctly + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRef("foo") + else + o = SafeRefs("bar", baz) + r = getfield(o, 2) + end + if cnd + o = o::SafeRef + setfield!(o, 1, qux) + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + for new in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(new)]) + end + end + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRefs("foo", "bar") + r = setfield!(o, 2, baz) + else + o = SafeRef(qux) + end + if !cnd + o = o::SafeRef + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + end + + # foreigncall should disable field analysis + let 
result = code_escapes((Any,Nothing,Int,UInt)) do t, mt, lim, world + ambig = false + min = Ref{UInt}(typemin(UInt)) + max = Ref{UInt}(typemax(UInt)) + has_ambig = Ref{Int32}(0) + mt = ccall(:jl_matching_methods, Any, + (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ref{Int32}), + t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} + return mt, has_ambig[] + end + for i in findall(isnew, result.ir.stmts.inst) + @test !is_load_forwardable(result.state[SSAValue(i)]) + end + end +end + +# demonstrate the power of our field / alias analysis with a realistic end to end example +abstract type AbstractPoint{T} end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute(T, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(100000000-1) + c = add(a, b) # replaceable + a = add(c, b) # replaceable + end + a.x, a.y +end +let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + for i in findall(1:length(result.ir.stmts)) do idx + inst = EscapeAnalysis.getinst(result.ir, idx) + stmt = inst[:inst] + return (isnew(stmt) || isϕ(stmt)) && inst[:type] <: MPoint + end + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end +function compute(a, b) + for i in 0:(100000000-1) + c = add(a, b) # replaceable + a = add(c, b) # unreplaceable (aliased to the call argument `a`) + end + a.x, a.y +end +let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + idxs = findall(1:length(result.ir.stmts)) do idx + inst = EscapeAnalysis.getinst(result.ir, idx) + stmt = inst[:inst] + return isnew(stmt) && inst[:type] <: MPoint + end + @assert length(idxs) == 2 + @test count(i->is_load_forwardable(result.state[SSAValue(i)]), idxs) == 1 +end +function compute!(a, b) + for i in 0:(100000000-1) + c = add(a, b) # replaceable + a′ = add(c, b) # replaceable + a.x = a′.x + a.y = a′.y + end +end +let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(1:length(result.ir.stmts)) do idx + inst = EscapeAnalysis.getinst(result.ir, idx) + stmt = inst[:inst] + return isnew(stmt) && inst[:type] <: MPoint + end + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end + +@testset "array primitives" begin + inbounds = Base.JLOptions().check_bounds == 0 + + # arrayref + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(true, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(false, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + inbounds && let result = code_escapes((Vector{String},Int)) do xs, i + s = @inbounds xs[i] + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Bool)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will 
happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((AbstractVector{String},Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{String},Any)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # arrayset + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(false, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + inbounds && let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + @inbounds xs[i] = x + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((String,String,String,)) do s, t, u + xs = Vector{String}(undef, 3) + Base.arrayset(true, xs, s, 1) + Base.arrayset(true, xs, t, 2) + Base.arrayset(true, xs, u, 3) + return xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + for i in 2:result.state.nargs + @test has_return_escape(result.state[Argument(i)], r) + end + end + let result = code_escapes((Vector{String},String,Bool,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((String,String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs::String + @test has_thrown_escape(result.state[Argument(3)], t) # x::String + end + let result = code_escapes((AbstractVector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test 
has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{String},AbstractString,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + + # arrayref and arrayset + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test !has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes((Vector{Any},String,Int,Int)) do xs, s, i, j + x = SafeRef(s) + xs[i] = x + xs[j] # potential error + end + i = only(findall(isnew, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(3)], t) # s + @test has_thrown_escape(result.state[SSAValue(i)], t) # x + end + + # arraysize + let result = code_escapes((Vector{Any},)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Vector{Any},Int,)) do xs, dim + Core.arraysize(xs, dim) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Any,)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) + end + + # arraylen + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs, 1) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # array resizing + # without BoundsErrors + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_beg(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test 
!has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_end(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + # with possible BoundsErrors + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[3] = x + @ccall jl_array_del_beg(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[1] = x + @ccall jl_array_del_end(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_grow_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + inbounds && let result = code_escapes((String,)) do x + xs = @inbounds Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + + # array copy + let result = code_escapes((Vector{Any},)) do xs + return copy(xs) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test_broken !has_return_escape(result.state[Argument(2)], r) + end + let result = code_escapes((String,)) do s + xs = String[s] + xs′ = copy(xs) + return xs′[1] + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i1)]) + @test !has_return_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Vector{Any},)) do xs + xs′ = copy(xs) + return xs′[1] # may potentially throw BoundsError, should escape `xs` conservatively (i.e. 
escape its elements) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + ref = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + ret = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i)], ref) + @test_broken !has_return_escape(result.state[SSAValue(i)], ret) + @test has_thrown_escape(result.state[Argument(2)], ref) + @test has_return_escape(result.state[Argument(2)], ret) + end + let result = code_escapes((String,)) do s + xs = Vector{String}(undef, 1) + xs[1] = s + xs′ = copy(xs) + length(xs′) > 2 && throw(xs′) + return xs′ + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i1)], t) + @test_broken !has_return_escape(result.state[SSAValue(i1)], r) + @test has_thrown_escape(result.state[SSAValue(i2)], t) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test has_thrown_escape(result.state[Argument(2)], t) + @test has_return_escape(result.state[Argument(2)], r) + end + + # isassigned + let result = code_escapes((Vector{Any},Int)) do xs, i + return isassigned(xs, i) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test !has_thrown_escape(result.state[Argument(2)]) + end + + # indexing analysis + # ----------------- + + # safe case + let result = code_escapes((String,String)) do s, t + a = Vector{Any}(undef, 2) + a[1] = s + a[2] = t + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test !has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((String,String)) do s, t + a = Matrix{Any}(undef, 1, 2) + a[1, 1] = s + a[1, 2] = t + return a[1, 1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test !has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((Bool,String,String,String)) do c, s, t, u + a = Vector{Any}(undef, 2) + if c + a[1] = s + a[2] = u + else + a[1] = t + a[2] = u + end + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_return_escape(result.state[Argument(4)], r) # t + @test !has_return_escape(result.state[Argument(5)], r) # u + end + let result = code_escapes((Bool,String,String,String)) do c, s, t, u + a = Any[nothing, nothing] # TODO how to deal with loop indexing? 
+ if c + a[1] = s + a[2] = u + else + a[1] = t + a[2] = u + end + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_return_escape(result.state[Argument(4)], r) # t + @test_broken !has_return_escape(result.state[Argument(5)], r) # u + end + let result = code_escapes((String,)) do s + a = Vector{Vector{Any}}(undef, 1) + b = Any[s] + a[1] = b + return a[1][1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + is = findall(isarrayalloc, result.ir.stmts.inst) + @assert length(is) == 2 + ia, ib = is + @test !has_return_escape(result.state[SSAValue(ia)], r) + @test is_load_forwardable(result.state[SSAValue(ia)]) + @test !has_return_escape(result.state[SSAValue(ib)], r) + @test_broken is_load_forwardable(result.state[SSAValue(ib)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Bool,String,String,Regex,Regex,)) do c, s1, s2, t1, t2 + if c + a = Vector{String}(undef, 2) + a[1] = s1 + a[2] = s2 + else + a = Vector{Regex}(undef, 2) + a[1] = t1 + a[2] = t2 + end + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isarrayalloc, result.ir.stmts.inst) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + @test has_return_escape(result.state[Argument(3)], r) # s1 + @test !has_return_escape(result.state[Argument(4)], r) # s2 + @test has_return_escape(result.state[Argument(5)], r) # t1 + @test !has_return_escape(result.state[Argument(6)], r) # t2 + end + let result = code_escapes((String,String,Int)) do s, t, i + a = Any[s] + push!(a, t) + return a[2] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + @test_broken !has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + end + # unsafe cases + let result = code_escapes((String,String,Int)) do s, t, i + a = Vector{Any}(undef, 2) + a[1] = s + a[2] = t + return a[i] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test !is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((String,String,Int)) do s, t, i + a = Vector{Any}(undef, 2) + a[1] = s + a[i] = t + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test !is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((String,String,Int,Int,Int)) do s, t, i, j, k + a = Vector{Any}(undef, 2) + a[3] = s # BoundsError + a[1] = t + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + 
@test !is_load_forwardable(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline some_resize!(a) = pushfirst!(a, nothing) + $code_escapes((String,String,Int)) do s, t, i + a = Vector{Any}(undef, 2) + a[1] = s + some_resize!(a) + return a[2] + end + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + @test !is_load_forwardable(result.state[SSAValue(i)]) + end + + # circular reference + let result = code_escapes() do + xs = Vector{Any}(undef, 1) + xs[1] = xs + return xs[1] + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + const Ax = Vector{Any}(undef, 1) + Ax[1] = Ax + $code_escapes() do + xs = Ax[1]::Vector{Any} + return xs[1] + end + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(iscall((result.ir, Core.arrayref)), result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = @eval Module() begin + @noinline function genxs() + xs = Vector{Any}(undef, 1) + xs[1] = xs + return xs + end + $code_escapes() do + xs = genxs() + return xs[1] + end + end + i = only(findall(isinvoke(:genxs), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end +end + +# demonstrate array primitive support with a realistic end to end example +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + push!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + pushfirst!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) # xs + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((String,String,String)) do s, t, u + xs = String[] + resize!(xs, 3) + xs[1] = s + xs[1] = t + xs[1] = u + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + @test has_return_escape(result.state[Argument(4)], r) # u +end + +@static if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +@testset "ImmutableArray" begin + # arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + arrayfreeze(xs) + end + @test 
!has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + # mutating_arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + mutating_arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + # arraythaw + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray,)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector{Any},)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = ImmutableArray(Any[]) + arraythaw(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) + end +end + +# demonstrate some arrayfreeze optimizations +# !has_return_escape(ary) means ary is eligible for arrayfreeze to mutating_arrayfreeze optimization +let result = code_escapes((Int,)) do n + xs = collect(1:n) + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Float64},)) do xs + ys = sin.(xs) + ImmutableArray(ys) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Pair{Int,String}},)) do xs + n = maximum(first, xs) + ys = Vector{String}(undef, n) + for (i, s) in xs + ys[i] = s + end + ImmutableArray(ys) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end + +end # @static if isdefined(Core, :ImmutableArray) + +# demonstrate that a simple type-level analysis can sometimes improve the analysis accuracy +# by compensating for the lack of yet-unimplemented analyses +@testset "special-casing bitstype" begin + let result = code_escapes((Nothing,)) do a + global GV = a + end + @test !(has_all_escape(result.state[Argument(2)])) + end + + let result = code_escapes((Int,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end + + # an escaped tuple stmt will not 
propagate to its Int argument (since `Int` is of bitstype) + let result = code_escapes((Int,Any,)) do a, b + t = tuple(a, b) + return t + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test has_return_escape(result.state[Argument(3)], r) + end +end + +# # TODO implement a finalizer elision pass +# mutable struct WithFinalizer +# v +# function WithFinalizer(v) +# x = new(v) +# f(t) = @async println("Finalizing $t.") +# return finalizer(x, x) +# end +# end +# make_m(v = 10) = MyMutable(v) +# function simple(cond) +# m = make_m() +# if cond +# # println(m.v) +# return nothing # <= insert `finalize` call here +# end +# return m +# end diff --git a/test/compiler/EscapeAnalysis/setup.jl b/test/compiler/EscapeAnalysis/setup.jl new file mode 100644 index 0000000000000..5123b18e2dfdd --- /dev/null +++ b/test/compiler/EscapeAnalysis/setup.jl @@ -0,0 +1,71 @@ +include(normpath(@__DIR__, "EAUtils.jl")) +using Test, Core.Compiler.EscapeAnalysis, .EAUtils +import Core: Argument, SSAValue, ReturnNode +const EA = Core.Compiler.EscapeAnalysis + +isT(T) = (@nospecialize x) -> x === T +isreturn(@nospecialize x) = isa(x, Core.ReturnNode) && isdefined(x, :val) +isthrow(@nospecialize x) = Meta.isexpr(x, :call) && Core.Compiler.is_throw_call(x) +isnew(@nospecialize x) = Meta.isexpr(x, :new) +isϕ(@nospecialize x) = isa(x, Core.PhiNode) +function with_normalized_name(@nospecialize(f), @nospecialize(x)) + if Meta.isexpr(x, :foreigncall) + name = x.args[1] + nn = EA.normalize(name) + return isa(nn, Symbol) && f(nn) + end + return false +end +isarrayalloc(@nospecialize x) = with_normalized_name(nn->!isnothing(Core.Compiler.alloc_array_ndims(nn)), x) +isarrayresize(@nospecialize x) = with_normalized_name(nn->!isnothing(EA.array_resize_info(nn)), x) +isarraycopy(@nospecialize x) = with_normalized_name(nn->EA.is_array_copy(nn), x) +import Core.Compiler: argextype, singleton_type +iscall(y) = @nospecialize(x) -> iscall(y, x) +function iscall((ir, f), @nospecialize(x)) + return iscall(x) do @nospecialize x + singleton_type(Core.Compiler.argextype(x, ir, Any[])) === f + end +end +iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) + +# check if `x` is a statically-resolved call of a function whose name is `sym` +isinvoke(y) = @nospecialize(x) -> isinvoke(y, x) +isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x) +isinvoke(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :invoke) && pred(x.args[1]::Core.MethodInstance) + +""" + is_load_forwardable(x::EscapeInfo) -> Bool + +Queries if `x` is eligible for store-to-load forwarding optimization. 
+""" +function is_load_forwardable(x::EA.EscapeInfo) + AliasInfo = x.AliasInfo + # NOTE technically we also need to check `!has_thrown_escape(x)` here as well, + # but we can also do equivalent check during forwarding + return isa(AliasInfo, EA.IndexableFields) || isa(AliasInfo, EA.IndexableElements) +end + +let setup_ex = quote + mutable struct SafeRef{T} + x::T + end + Base.getindex(s::SafeRef) = getfield(s, 1) + Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + + mutable struct SafeRefs{S,T} + x1::S + x2::T + end + Base.getindex(s::SafeRefs, idx::Int) = getfield(s, idx) + Base.setindex!(s::SafeRefs, x, idx::Int) = setfield!(s, idx, x) + + global GV::Any + const global GR = Ref{Any}() + end + global function EATModule(setup_ex = setup_ex) + M = Module() + Core.eval(M, setup_ex) + return M + end + Core.eval(@__MODULE__, setup_ex) +end From 325f41471792dfb5e2b53b96c3388ad8c7b6564a Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sat, 22 Jan 2022 03:12:25 +0900 Subject: [PATCH 2/5] optimizer: alias-aware SROA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enhances SROA of mutables using the novel Julia-level escape analysis (on top of #43800): 1. alias-aware SROA, mutable ϕ-node elimination 2. `isdefined` check elimination 3. load-forwarding for non-eliminable but analyzable mutables --- 1. alias-aware SROA, mutable ϕ-node elimination EA's alias analysis allows this new SROA to handle nested mutables allocations pretty well. Now we can eliminate the heap allocations completely from this insanely nested examples by the single analysis/optimization pass: ```julia julia> function refs(x) (Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][] end refs (generic function with 1 method) julia> refs("julia"); @allocated refs("julia") 0 ``` EA can also analyze escape of ϕ-node as well as its aliasing. 
Mutable ϕ-nodes can be eliminated even in a very tricky case like: ```julia julia> code_typed((Bool,String,)) do cond, x # these allocations form multiple ϕ-nodes if cond ϕ2 = ϕ1 = Ref{Any}("foo") else ϕ2 = ϕ1 = Ref{Any}("bar") end ϕ2[] = x y = ϕ1[] # => x return y end 1-element Vector{Any}: CodeInfo( 1 ─ goto #3 if not cond 2 ─ goto #4 3 ─ nothing::Nothing 4 ┄ return x ) => Any ``` Combined with the alias analysis and ϕ-node handling above, allocations in the following "realistic" examples will be optimized: ```julia julia> # demonstrate the power of our field / alias analysis with realistic end to end examples # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B abstract type AbstractPoint{T} end julia> struct Point{T} <: AbstractPoint{T} x::T y::T end julia> mutable struct MPoint{T} <: AbstractPoint{T} x::T y::T end julia> add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y); julia> function compute_point(T, n, ax, ay, bx, by) a = T(ax, ay) b = T(bx, by) for i in 0:(n-1) a = add(add(a, b), b) end a.x, a.y end; julia> function compute_point(n, a, b) for i in 0:(n-1) a = add(add(a, b), b) end a.x, a.y end; julia> function compute_point!(n, a, b) for i in 0:(n-1) a′ = add(add(a, b), b) a.x = a′.x a.y = a′.y end end; julia> compute_point(MPoint, 10, 1+.5, 2+.5, 2+.25, 4+.75); julia> compute_point(MPoint, 10, 1+.5im, 2+.5im, 2+.25im, 4+.75im); julia> @allocated compute_point(MPoint, 10000, 1+.5, 2+.5, 2+.25, 4+.75) 0 julia> @allocated compute_point(MPoint, 10000, 1+.5im, 2+.5im, 2+.25im, 4+.75im) 0 julia> compute_point(10, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)); julia> compute_point(10, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)); julia> @allocated compute_point(10000, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) 0 julia> @allocated compute_point(10000, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) 0 julia> af, bf = MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75); julia> ac, bc = MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im); julia> compute_point!(10, af, bf); julia> compute_point!(10, ac, bc); julia> @allocated compute_point!(10000, af, bf) 0 julia> @allocated compute_point!(10000, ac, bc) 0 ``` 2. `isdefined` check elimination This commit also implements a simple optimization to eliminate `isdefined` calls by checking load-forwardability. This optimization may be especially useful for eliminating the extra allocation involved with a capturing closure, e.g.: ```julia julia> callit(f, args...) = f(args...); julia> function isdefined_elim() local arr::Vector{Any} callit() do arr = Any[] end return arr end; julia> code_typed(isdefined_elim) 1-element Vector{Any}: CodeInfo( 1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Any}, svec(Any, Int64), 0, :(:ccall), Vector{Any}, 0, 0))::Vector{Any} └── goto #3 if not true 2 ─ goto #4 3 ─ $(Expr(:throw_undef_if_not, :arr, false))::Any 4 ┄ return %1 ) => Vector{Any} ``` 3. load-forwarding for non-eliminable but analyzable mutables EA also allows us to forward loads even when the mutable allocation can't be eliminated but its fields are still known precisely. The load forwarding might be useful since it may derive new type information that succeeding optimization passes can use (or just because it allows simpler code transformations down the road): ```julia julia> code_typed((Bool,String,)) do c, s r = Ref{Any}(s) if c return r[]::String # adce_pass! 
will also eliminate this typeassert call else return r end end 1-element Vector{Any}: CodeInfo( 1 ─ %1 = %new(Base.RefValue{Any}, s)::Base.RefValue{Any} └── goto #3 if not c 2 ─ return s 3 ─ return %1 ) => Union{Base.RefValue{Any}, String} ``` --- Please refer to the newly added test cases for more examples. Also, EA's alias analysis already succeeds in reasoning about arrays, and so this EA-based SROA will hopefully be generalized for array SROA as well. --- base/compiler/bootstrap.jl | 6 +- base/compiler/optimize.jl | 26 +- base/compiler/ssair/passes.jl | 852 ++++++++++-------- test/compiler/EscapeAnalysis/EAUtils.jl | 6 +- .../EscapeAnalysis/interprocedural.jl | 2 - test/compiler/EscapeAnalysis/setup.jl | 1 + test/compiler/irpasses.jl | 781 ++++++++++++++-- 7 files changed, 1252 insertions(+), 422 deletions(-) diff --git a/base/compiler/bootstrap.jl b/base/compiler/bootstrap.jl index 1989d8aa57393..487ddf2ccdd1b 100644 --- a/base/compiler/bootstrap.jl +++ b/base/compiler/bootstrap.jl @@ -11,7 +11,11 @@ let world = get_world_counter() interp = NativeInterpreter(world) - analyze_escapes_tt = Tuple{typeof(analyze_escapes), IRCode, Int, Bool, typeof(get_escape_cache(code_cache(interp)))} + analyze_escapes_tt = Any[typeof(analyze_escapes), IRCode, Int, Bool, + # typeof(get_escape_cache(code_cache(interp))) # once we enable IPO EA + typeof(null_escape_cache) + ] + analyze_escapes_tt = Tuple{analyze_escapes_tt...} fs = Any[ # we first create caches for the optimizer, because they contain many loop constructions # and they're better to not run in interpreter even during bootstrapping diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 635e53a9e1f1d..e84f77ae1ea48 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -98,7 +98,7 @@ and then caches it into a global cache for later interprocedural propagation. cache_escapes!(caller::InferenceResult, estate::EscapeState) = caller.argescapes = ArgEscapeCache(estate) -function get_escape_cache(mi_cache::MICache) where MICache +function ipo_escape_cache(mi_cache::MICache) where MICache return function (linfo::Union{InferenceResult,MethodInstance}) if isa(linfo, InferenceResult) argescapes = linfo.argescapes @@ -110,6 +110,7 @@ function get_escape_cache(mi_cache::MICache) where MICache return argescapes !== nothing ? argescapes::ArgEscapeCache : nothing end end +null_escape_cache(linfo::Union{InferenceResult,MethodInstance}) = nothing mutable struct OptimizationState linfo::MethodInstance @@ -540,17 +541,24 @@ function run_passes(ci::CodeInfo, sv::OptimizationState, caller::InferenceResult # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) nargs = let def = sv.linfo.def; isa(def, Method) ? 
Int(def.nargs) : 0; end - get_escape_cache = (@__MODULE__).get_escape_cache(sv.inlining.mi_cache) - if is_ipo_profitable(ir, nargs) - @timeit "IPO EA" begin - state = analyze_escapes(ir, nargs, false, get_escape_cache) - cache_escapes!(caller, state) - end - end + # if is_ipo_profitable(ir, nargs) + # @timeit "IPO EA" begin + # state = analyze_escapes(ir, + # nargs, #=call_resolved=#false, ipo_escape_cache(sv.inlining.mi_cache)) + # cache_escapes!(caller, state) + # end + # end @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) - @timeit "SROA" ir = sroa_pass!(ir) + @timeit "SROA" ir, memory_opt = linear_pass!(ir) + if memory_opt + @timeit "memory_opt_pass!" begin + @timeit "Local EA" estate = analyze_escapes(ir, + nargs, #=call_resolved=#true, null_escape_cache) + @timeit "memory_opt_pass!" ir = memory_opt_pass!(ir, estate) + end + end @timeit "ADCE" ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 67610f0c1df60..4c25654f83f1b 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -6,29 +6,6 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I return singleton_type(ft) === func end -""" - du::SSADefUse - -This struct keeps track of all uses of some mutable struct allocated in the current function: -- `du.uses::Vector{Int}` are all instances of `getfield` on the struct -- `du.defs::Vector{Int}` are all instances of `setfield!` on the struct -The terminology refers to the uses/defs of the "slot bundle" that the mutable struct represents. - -In addition we keep track of all instances of a `:foreigncall` that preserves of this mutable -struct in `du.ccall_preserve_uses`. Somewhat counterintuitively, we don't actually need to -make sure that the struct itself is live (or even allocated) at a `ccall` site. -If there are no other places where the struct escapes (and thus e.g. where its address is taken), -it need not be allocated. We do however, need to make sure to preserve any elements of this struct. 
-""" -struct SSADefUse - uses::Vector{Int} - defs::Vector{Int} - ccall_preserve_uses::Vector{Int} -end -SSADefUse() = SSADefUse(Int[], Int[], Int[]) - -compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses) - # assume `stmt == getfield(obj, field, ...)` or `stmt == setfield!(obj, field, val, ...)` try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) = try_compute_field(ir, stmt.args[3]) @@ -55,112 +32,6 @@ function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::E return try_compute_fieldidx(typ, field) end -function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) - # TODO: This can be much faster by looking at current level and only - # searching for those blocks in a sorted order - while !(curblock in allblocks) - curblock = domtree.idoms_bb[curblock] - end - return curblock -end - -function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) - ex = ir[SSAValue(def)][:inst] - if isexpr(ex, :new) - return ex.args[1+fidx] - else - @assert isa(ex, Expr) - # The use is whatever the setfield was - return ex.args[4] - end -end - -function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) - curblock = find_curblock(domtree, allblocks, curblock) - def = 0 - for stmt in du.defs - if block_for_inst(ir.cfg, stmt) == curblock - def = max(def, stmt) - end - end - def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx) -end - -function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) - def, useblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use) - if def == 0 - if !haskey(phinodes, curblock) - # If this happens, we need to search the predecessors for defs. 
Which - # one doesn't matter - if it did, we'd have had a phinode - return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) - end - # The use is the phinode - return phinodes[curblock] - else - return val_for_def_expr(ir, def, fidx) - end -end - -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field -function has_safe_def( - ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, - newidx::Int, idx::Int) - def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) - # will throw since we already checked this `:new` site doesn't define this field - def == newidx && return false - # found a "safe" definition - def ≠ 0 && return true - # we may still be able to replace this load with `PhiNode` - # examine if all predecessors of `block` have any "safe" definition - block = block_for_inst(ir, idx) - seen = BitSet(block) - worklist = BitSet(ir.cfg.blocks[block].preds) - isempty(worklist) && return false - while !isempty(worklist) - pred = pop!(worklist) - # if this block has already been examined, bail out to avoid infinite cycles - pred in seen && return false - idx = last(ir.cfg.blocks[pred].stmts) - # NOTE `idx` isn't a load, thus we can use inclusive coondition within the `find_def_for_use` - def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx, true) - # will throw since we already checked this `:new` site doesn't define this field - def == newidx && return false - push!(seen, pred) - # found a "safe" definition for this predecessor - def ≠ 0 && continue - # check for the predecessors of this predecessor - for newpred in ir.cfg.blocks[pred].preds - push!(worklist, newpred) - end - end - return true -end - -# find the first dominating def for the given use -function find_def_for_use( - ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use::Int, inclusive::Bool=false) - useblock = block_for_inst(ir.cfg, use) - curblock = find_curblock(domtree, allblocks, useblock) - local def = 0 - for idx in du.defs - if block_for_inst(ir.cfg, idx) == curblock - if curblock != useblock - # Find the last def in this block - def = max(def, idx) - else - # Find the last def before our use - if inclusive - def = max(def, idx ≤ use ? idx : 0) - else - def = max(def, idx < use ? idx : 0) - end - end - end - end - return def, useblock, curblock -end - function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint)) if isa(val, Union{OldSSAValue, SSAValue}) val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) @@ -657,38 +528,35 @@ end const SPCSet = IdSet{Int} """ - sroa_pass!(ir::IRCode) -> newir::IRCode - -`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization. - -This pass is based on a local field analysis by def-use chain walking. -It looks for struct allocation sites ("definitions"), and `getfield` calls as well as -`:foreigncall`s that preserve the structs ("usages"). If "definitions" have enough information, -then this pass will replace corresponding usages with forwarded values. -`mutable struct`s require additional cares and need to be handled separately from immutables. -For `mutable struct`s, `setfield!` calls account for "definitions" also, and the pass should -give up the lifting conservatively when there are any "intermediate usages" that may escape -the mutable struct (e.g. 
non-inlined generic function call that takes the mutable struct as -its argument). - -In a case when all usages are fully eliminated, `struct` allocation may also be erased as -a result of succeeding dead code elimination. + linear_pass!(ir::IRCode) -> (newir::IRCode, memory_opt::Bool) + +This pass consists of the following optimizations that can be performed by +a single linear traversal over IR statements: +- load forwarding of immutables (`getfield` elimination): immutable allocations whose + loads are all eliminated by this pass may be erased entirely as a result of succeeding + dead code elimination (this allocation elimination is called "SROA", Scalar Replacements of Aggregates) +- lifting of builtin comparisons: see [`lift_comparison!`](@ref) +- canonicalization of `typeassert` calls: see [`canonicalize_typeassert!`](@ref) + +In addition to performing the optimizations above, the linear traversal also examines each +statement and checks whether it may be profitable to run the [`memory_opt_pass!`](@ref) pass. +In such cases `memory_opt` is set, indicating that `ir` may be further optimized by +running `memory_opt_pass!(ir, estate::EscapeState)`. """ -function sroa_pass!(ir::IRCode) +function linear_pass!(ir::IRCode) compact = IncrementalCompact(ir) - defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + local memory_opt = false # whether or not to run the memory_opt_pass! pass later for ((_, idx), stmt) in compact - # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue - is_setfield = false field_ordering = :unspecified - if is_known_call(stmt, setfield!, compact) - 4 <= length(stmt.args) <= 5 || continue - is_setfield = true - if length(stmt.args) == 5 - field_ordering = argextype(stmt.args[5], compact) + if isexpr(stmt, :new) + typ = unwrap_unionall(widenconst(argextype(SSAValue(idx), compact))) + if ismutabletype(typ) + # mutable SROA may eliminate this allocation, so mark it now + memory_opt = true end + continue elseif is_known_call(stmt, getfield, compact) 3 <= length(stmt.args) <= 5 || continue if length(stmt.args) == 5 @@ -704,40 +572,21 @@ function sroa_pass!(ir::IRCode) for pidx in (6+nccallargs):length(stmt.args) preserved_arg = stmt.args[pidx] isa(preserved_arg, SSAValue) || continue - let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) - push!(intermediaries, ssa.id) - return false - end - def = simple_walk(compact, preserved_arg, callback) - isa(def, SSAValue) || continue - defidx = def.id - def = compact[defidx] - if is_known_call(def, tuple, compact) + def = simple_walk(compact, preserved_arg) + isa(def, SSAValue) || continue + defidx = def.id + def = compact[defidx] + if is_known_call(def, tuple, compact) + record_immutable_preserve!(new_preserves, def, compact) + push!(preserved, preserved_arg.id) + elseif isexpr(def, :new) + typ = unwrap_unionall(widenconst(argextype(SSAValue(defidx), compact))) + if typ isa DataType + ismutabletype(typ) && continue # mutable SROA is performed later record_immutable_preserve!(new_preserves, def, compact) 
- push!(preserved, preserved_arg.id) - continue - end - else - continue end - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() - end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) - push!(defuse.ccall_preserve_uses, idx) - union!(mid, intermediaries) end - continue end if !isempty(new_preserves) compact[idx] = form_new_preserves(stmt, preserved, new_preserves) @@ -756,7 +605,7 @@ function sroa_pass!(ir::IRCode) continue end - # analyze this `getfield` / `setfield!` call + # analyze this `getfield` call field = try_compute_field_stmt(compact, stmt) field === nothing && continue @@ -774,32 +623,7 @@ function sroa_pass!(ir::IRCode) continue end - # analyze this mutable struct here for the later pass - if ismutabletype(struct_typ) - isa(val, SSAValue) || continue - let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) - push!(intermediaries, ssa.id) - return false - end - def = simple_walk(compact, val, callback) - # Mutable stuff here - isa(def, SSAValue) || continue - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() - end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) - if is_setfield - push!(defuse.defs, idx) - else - push!(defuse.uses, idx) - end - union!(mid, intermediaries) - end - continue - elseif is_setfield - continue # invalid `setfield!` call, but just ignore here - end + ismutabletype(struct_typ) && continue # mutable SROA is performed later # perform SROA on immutable structs here on @@ -837,177 +661,503 @@ function sroa_pass!(ir::IRCode) end non_dce_finish!(compact) - if defuses !== nothing - # now go through analyzed mutable structs and see which ones we can eliminate - # NOTE copy the use count here, because `simple_dce!` may modify it and we need it - # consistent with the state of the IR here (after tracking `PhiNode` arguments, - # but before the DCE) for our predicate within `sroa_mutables!`, but we also - # try an extra effort using a callback so that reference counts are updated - used_ssas = copy(compact.used_ssas) - simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) - ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) - return ir - else - simple_dce!(compact) - return complete(compact) - end + simple_dce!(compact) + ir = complete(compact) + return ir, memory_opt end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) - # initialization of domtree is delayed to avoid the expensive computation in many cases - local domtree = nothing - for (idx, (intermediaries, defuse)) in defuses - intermediaries = collect(intermediaries) - # Check if there are any uses we did not account for. If so, the variable - # escapes and we cannot eliminate the allocation. This works, because we're guaranteed - # not to include any intermediaries that have dead uses. As a result, missing uses will only ever - # show up in the nuses_total count. 
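+# Illustrative sketch of how the two passes are expected to be chained by an
+# optimization driver (the `nargs`, `interp` and `get_escape_cache` names follow
+# the test driver in test/compiler/EscapeAnalysis/EAUtils.jl; this is a sketch,
+# not an actual driver definition):
+#
+#     ir, memory_opt = linear_pass!(ir)
+#     if memory_opt
+#         estate = analyze_escapes(ir, nargs, true, get_escape_cache(interp))
+#         ir = memory_opt_pass!(ir, estate)
+#     end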
-            nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
-            nuses = 0
-            for idx in intermediaries
-                nuses += used_ssas[idx]
+function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
+    newex = Expr(:foreigncall)
+    nccallargs = length(origex.args[3]::SimpleVector)
+    for i in 1:(6+nccallargs-1)
+        push!(newex.args, origex.args[i])
+    end
+    for i in (6+nccallargs):length(origex.args)
+        x = origex.args[i]
+        # don't need to preserve intermediaries
+        if isa(x, SSAValue) && x.id in intermediates
+            continue
         end
-            nuses_total = used_ssas[idx] + nuses - length(intermediaries)
-            nleaves == nuses_total || continue
-            # Find the type for this allocation
-            defexpr = ir[SSAValue(idx)][:inst]
-            isexpr(defexpr, :new) || continue
-            newidx = idx
-            typ = ir.stmts[newidx][:type]
-            if isa(typ, UnionAll)
-                typ = unwrap_unionall(typ)
+        push!(newex.args, x)
+    end
+    for i in 1:length(new_preserves)
+        push!(newex.args, new_preserves[i])
+    end
+    return newex
+end
+
+import .EscapeAnalysis:
+    EscapeState, EscapeInfo, IndexableFields, LivenessSet, getaliases, LocalUse, LocalDef
+
+"""
+    memory_opt_pass!(ir::IRCode, estate::EscapeState) -> newir::IRCode
+
+Performs memory optimizations using escape information analyzed by `EscapeAnalysis`.
+Specifically, this optimization pass does SROA of mutable allocations.
+
+`estate::EscapeState` is expected to be a result of `analyze_escapes(ir, ...)`.
+Since running `analyze_escapes` can be relatively expensive, it is recommended to
+run this pass "selectively", i.e. only when the memory optimizations seem likely
+to be profitable.
+"""
+function memory_opt_pass!(ir::IRCode, estate::EscapeState)
+    # Compute domtree now, needed below, now that we have finished compacting the IR.
+    # This needs to be after we iterate through the IR with `IncrementalCompact`
+    # because removing dead blocks can invalidate the domtree.
+    # TODO initialization of the domtree can be delayed to avoid the expensive computation
+    # in cases when there are no loads to be forwarded
+    @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
+    workingset = BitSet(1:length(ir.stmts)+length(ir.new_nodes.stmts))
+    eliminated = BitSet()
+    revisit = Tuple{#=related=#Vector{SSAValue}, #=Liveness=#LivenessSet}[]
+    allpreserved = true
+    newpreserves = nothing
+    while !isempty(workingset)
+        idx = pop!(workingset)
+        ssa = SSAValue(idx)
+        stmt = ir[ssa][:inst]
+        # NOTE `linear_pass!` can't eliminate immutables wrapped by mutables,
+        # but the EA-based alias analysis may be able to eliminate them as well
+        isexpr(stmt, :new) || is_known_call(stmt, tuple, ir) || continue
+        einfo = estate[ssa]
+        is_load_forwardable(einfo) || continue
+        aliases = getaliases(ssa, estate)
+        if aliases === nothing
+            related = SSAValue[ssa]
+        else
+            related = SSAValue[]
+            for alias in aliases
+                @assert isa(alias, SSAValue) "invalid escape analysis"
+                push!(related, alias)
+                delete!(workingset, alias.id)
+            end
         end
-            # Could still end up here if we tried to setfield! on an immutable, which would
-            # error at runtime, but is not illegal to have in the IR.
-            ismutabletype(typ) || continue
-            typ = typ::DataType
-            # Partition defuses by field
-            fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
-            all_forwarded = true
-            for use in defuse.uses
-                stmt = ir[SSAValue(use)][:inst] # == `getfield` call
-                # We may have discovered above that this use is dead
-                # after the getfield elim of immutables.
In that case, - # it would have been deleted. That's fine, just ignore - # the use in that case. - if stmt === nothing - all_forwarded = false - continue + finfos = (einfo.AliasInfo::IndexableFields).infos + nflds = length(finfos) + + # Partition defuses by field, and object identity + fdefuses = IdDict{Tuple{Int,SSAValue},FieldDefUse}() + for fidx = 1:nflds + finfo = finfos[fidx] + for fx in finfo + if isa(fx, LocalUse) + use = fx.idx + stmt = ir[SSAValue(use)][:inst] # use (getfield call) + @assert is_known_call(stmt, getfield, ir) + obj = stmt.args[2] + @assert isa(obj, SSAValue) + fdu = get!(()->FieldDefUse(), fdefuses, (fidx, obj)) + push!(fdu.uses, GetfieldLoad(use)) + elseif isa(fx, LocalDef) + def = fx.idx + obj = SSAValue(def) + stmt = ir[obj][:inst] # def (setfield! call, tuple call or :new expression) + for rel in related + if isexpr(stmt, :new) || is_known_call(stmt, tuple, ir) + relstmt = ir[rel][:inst] + if isexpr(relstmt, :new) || is_known_call(relstmt, tuple, ir) + rel !== obj && continue + end + end + fdu = get!(()->FieldDefUse(), fdefuses, (fidx, rel)) + push!(fdu.defs, def) + end + end end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) end - for def in defuse.defs - stmt = ir[SSAValue(def)][:inst]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + + Liveness = einfo.Liveness + for livepc in Liveness + livestmt = ir[SSAValue(livepc)][:inst] + if is_known_call(livestmt, Core.ifelse, ir) || + is_known_call(livestmt, tuple, ir) || + is_known_call(livestmt, arrayset, ir) + # TODO the succeeding domination analysis doesn't account for flow sensitivity + # introduced by those program constructs, just give up SROA for now + @goto next_itr + elseif is_known_call(livestmt, isdefined, ir) + args = livestmt.args + length(args) ≥ 3 || continue + obj = args[2] + isa(obj, SSAValue) || continue + obj in related || continue + fld = args[3] + fldval = try_compute_field(ir, fld) + fldval === nothing && continue + typ = unwrap_unionall(widenconst(argextype(obj, ir))) + isa(typ, DataType) || continue + fidx = try_compute_fieldidx(typ, fldval) + fidx === nothing && continue + fdu = get!(()->FieldDefUse(), fdefuses, (fidx, obj)) + push!(fdu.uses, IsdefinedUse(livepc)) + elseif isexpr(livestmt, :foreigncall) + # we shouldn't eliminate this use if it's used as a direct argument + args = livestmt.args + nccallargs = length(args[3]::SimpleVector) + for i = 6:(5+nccallargs) + arg = args[i] + isa(arg, SSAValue) && arg in related && @goto next_liveness + end + # this use is preserve, and may be eliminable + for i = (6+nccallargs):length(args) + arg = args[i] + if isa(arg, SSAValue) && arg in related + for fidx in 1:nflds + fdu = get!(()->FieldDefUse(), fdefuses, (fidx, arg)) + push!(fdu.uses, PreserveUse(livepc)) + end + end + end + end + @label next_liveness end - # Check that the defexpr has defined values for all the fields - # we're accessing. In the future, we may want to relax this, - # but we should come up with semantics for well defined semantics - # for uninitialized fields first. 
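+        # Worked example (illustration only): given IR like
+        #     %1 = %new(Base.RefValue{Any})   # field left uninitialized
+        #     %2 = setfield!(%1, :x, a)
+        #     %3 = getfield(%1, :x)
+        # the partition above records `fdefuses[(1, %1)]` with `defs = [2]` and
+        # `uses = [GetfieldLoad(3)]`; the loop below then consults `has_safe_def`
+        # to check that the `setfield!` at %2 covers the load at %3 before
+        # forwarding `a` into it.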
-            ndefuse = length(fielddefuse)
-            blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse)
-            for fidx in 1:ndefuse
-                du = fielddefuse[fidx]
-                isempty(du.uses) && continue
-                push!(du.defs, newidx)
-                ldu = compute_live_ins(ir.cfg, du)
+
+        for ((fidx, objssa), fdu) in fdefuses
+            isempty(fdu.uses) && @goto next_field
+            # check if all uses have safe definitions first, otherwise we should bail out
+            # since then we may fail to form new ϕ-nodes
+            ldu = compute_live_ins(ir.cfg, fdu)
             if isempty(ldu.live_in_bbs)
                 phiblocks = Int[]
             else
-                domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
                 phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree)
             end
-                allblocks = sort(vcat(phiblocks, ldu.def_bbs))
-                blocks[fidx] = phiblocks, allblocks
-                if fidx + 1 > length(defexpr.args)
-                    for use in du.uses
-                        domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
-                        has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
-                    end
+            obj = ir[objssa][:inst]
+            if isa(obj, PhiNode)
+                push!(phiblocks, block_for_inst(ir, objssa.id))
             end
-            end
-            # Everything accounted for. Go field by field and perform idf:
-            # Compute domtree now, needed below, now that we have finished compacting the IR.
-            # This needs to be after we iterate through the IR with `IncrementalCompact`
-            # because removing dead blocks can invalidate the domtree.
-            domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
-            preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing :
-                IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses)))
-            for fidx in 1:ndefuse
-                du = fielddefuse[fidx]
-                ftyp = fieldtype(typ, fidx)
-                if !isempty(du.uses)
-                    phiblocks, allblocks = blocks[fidx]
-                    phinodes = IdDict{Int, SSAValue}()
-                    for b in phiblocks
-                        phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
-                            NewInstruction(PhiNode(), ftyp))
+            allblocks = sort!(vcat(phiblocks, ldu.def_bbs))
+            for use in fdu.uses
+                isa(use, IsdefinedUse) && continue
+                if isa(use, PreserveUse) && isempty(fdu.defs)
+                    # nothing to preserve, just ignore this use (may happen when there are uninitialized fields)
+                    continue
                 end
-                    # Now go through all uses and rewrite them
-                    for stmt in du.uses
-                        ir[SSAValue(stmt)][:inst] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
+                if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use))
+                    allpreserved = false
+                    @goto next_field
                 end
-                    if !isbitstype(ftyp)
-                        if preserve_uses !== nothing
-                            for (use, list) in preserve_uses
-                                push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
+            end
+            phinodes = IdDict{Int, SSAValue}()
+            for b in phiblocks
+                phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
+                    NewInstruction(PhiNode(), Any))
+            end
+            # Now go through all uses and rewrite them
+            for use in fdu.uses
+                if isa(use, GetfieldLoad)
+                    use = getuseidx(use)
+                    ir[SSAValue(use)][:inst] = compute_value_for_use(
+                        ir, domtree, allblocks, fdu, phinodes, fidx, use)
+                    push!(eliminated, use)
+                elseif isa(use, PreserveUse)
+                    allpreserved || continue
+                    isempty(fdu.defs) && continue # nothing to preserve (may happen when there are uninitialized fields)
+                    # record this `use` as replaceable no matter whether we preserve the new value or not
+                    use = getuseidx(use)
+                    newval = compute_value_for_use(
+                        ir, domtree, allblocks, fdu, phinodes, fidx, use)
+                    if !isbitstype(widenconst(argextype(newval, ir)))
+                        if newpreserves === nothing
+                            newpreserves = IdDict{Int,Vector{Any}}()
                         end
+                        newvalues = get!(()->Any[], newpreserves, use)
+                        push!(newvalues, newval)
+                    end
+                elseif isa(use, IsdefinedUse)
+                    use = getuseidx(use)
+                    if has_safe_def(ir, domtree, allblocks, fdu, use)
+                        ir[SSAValue(use)][:inst] = true
+                        push!(eliminated, use)
                     end
+                else
+                    throw("unexpected use")
                 end
-                for b in phiblocks
-                    n = ir[phinodes[b]][:inst]::PhiNode
+            end
+            for b in phiblocks
+                ϕssa = phinodes[b]
+                n = ir[ϕssa][:inst]::PhiNode
+                t = Bottom
+                if isa(obj, PhiNode)
+                    for i = 1:length(obj.edges)
+                        isassigned(obj.edges, i) || continue
+                        p = obj.edges[i]
+                        push!(n.edges, p)
+                        v = compute_value_for_block(ir, domtree, allblocks,
+                            fdefuses[(fidx, obj.values[i]::SSAValue)], phinodes, fidx, Int(p))
+                        push!(n.values, v)
+                        if t !== Any
+                            t = tmerge(t, argextype(v, ir))
+                        end
+                    end
+                else
                     for p in ir.cfg.blocks[b].preds
                         push!(n.edges, p)
-                        push!(n.values, compute_value_for_block(ir, domtree,
-                            allblocks, du, phinodes, fidx, p))
+                        v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p)
+                        push!(n.values, v)
+                        if t !== Any
+                            t = tmerge(t, argextype(v, ir))
+                        end
                     end
                 end
+                ir[ϕssa][:type] = t
             end
-                for stmt in du.defs
-                    stmt == newidx && continue
-                    ir[SSAValue(stmt)][:inst] = nothing
-                end
+            @label next_field
         end
-            preserve_uses === nothing && continue
-            if all_forwarded
-                # this means all ccall preserves have been replaced with forwarded loads
-                # so we can potentially eliminate the allocation, otherwise we must preserve
-                # the whole allocation.
-                push!(intermediaries, newidx)
+        push!(revisit, (related, Liveness))
+        @label next_itr
+    end
+
+    # remove dead setfield! and :new allocs
+    deadssas = IdSet{SSAValue}()
+    if allpreserved && newpreserves !== nothing
+        preserved = keys(newpreserves)
+    else
+        preserved = EMPTY_PRESERVED_SSAS
+    end
+    mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved)
+    for ssa in deadssas
+        ir[ssa][:inst] = nothing
+    end
+    if allpreserved && newpreserves !== nothing
+        deadssas = Int[ssa.id for ssa in deadssas]
+        for (idx, newuses) in newpreserves
+            ir[SSAValue(idx)][:inst] = form_new_preserves(
+                ir[SSAValue(idx)][:inst]::Expr, deadssas, newuses)
         end
-            # Insert the new preserves
-            for (use, new_preserves) in preserve_uses
-                ir[SSAValue(use)][:inst] = form_new_preserves(ir[SSAValue(use)][:inst]::Expr, intermediaries, new_preserves)
+    end
+
+    return ir
+end
+
+const EMPTY_PRESERVED_SSAS = keys(IdDict{Int,Vector{Any}}())
+const PreservedSets = typeof(EMPTY_PRESERVED_SSAS)
+
+function is_load_forwardable(x::EscapeInfo)
+    AliasInfo = x.AliasInfo
+    return isa(AliasInfo, IndexableFields)
+end
+
+struct FieldDefUse
+    uses::Vector{Any}
+    defs::Vector{Int}
+end
+FieldDefUse() = FieldDefUse(Any[], Int[])
+struct GetfieldLoad
+    idx::Int
+end
+struct PreserveUse
+    idx::Int
+end
+struct IsdefinedUse
+    idx::Int
+end
+function getuseidx(@nospecialize use)
+    if isa(use, GetfieldLoad)
+        return use.idx
+    elseif isa(use, PreserveUse)
+        return use.idx
+    elseif isa(use, IsdefinedUse)
+        return use.idx
+    end
+    throw("getuseidx: unexpected use")
+end
+
+function compute_live_ins(cfg::CFG, fdu::FieldDefUse)
+    uses = Int[]
+    for use in fdu.uses
+        isa(use, IsdefinedUse) && continue
+        push!(uses, getuseidx(use))
+    end
+    return compute_live_ins(cfg, fdu.defs, uses)
+end
+
+# even when the allocation contains an uninitialized field, we make an extra effort to
+# check whether the load at `idx` has any "safe" `setfield!` call that defines the field
+function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
+    fdu::FieldDefUse,
use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + dfu === nothing && return false + def = dfu[1] + def ≠ 0 && return true # found a "safe" definition + # we may still be able to replace this load with `PhiNode` -- examine if all predecessors of + # this `block` have any "safe" definition + block = block_for_inst(ir, use) + seen = BitSet(block) + worklist = BitSet(ir.cfg.blocks[block].preds) + isempty(worklist) && return false + while !isempty(worklist) + pred = pop!(worklist) + # if this block has already been examined, bail out to avoid infinite cycles + pred in seen && return false + use = last(ir.cfg.blocks[pred].stmts) + # NOTE this `use` isn't a load, and so the inclusive condition can be used + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true) + dfu === nothing && return false + def = dfu[1] + push!(seen, pred) + def ≠ 0 && continue # found a "safe" definition for this predecessor + # if not, check for the predecessors of this predecessor + for newpred in ir.cfg.blocks[pred].preds + push!(worklist, newpred) end + end + return true +end - @label skip +# find the first dominating def for the given use +function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, use::Int, inclusive::Bool=false) + useblock = block_for_inst(ir.cfg, use) + curblock = find_curblock(domtree, allblocks, useblock) + curblock === nothing && return nothing + local def = 0 + for idx in fdu.defs + if block_for_inst(ir.cfg, idx) == curblock + if curblock != useblock + # Find the last def in this block + def = max(def, idx) + else + # Find the last def before our use + if inclusive + def = max(def, idx ≤ use ? idx : 0) + else + def = max(def, idx < use ? idx : 0) + end + end + end end + return def, useblock, curblock end -function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) - newex = Expr(:foreigncall) - nccallargs = length(origex.args[3]::SimpleVector) - for i in 1:(6+nccallargs-1) - push!(newex.args, origex.args[i]) +function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) + # TODO: This can be much faster by looking at current level and only + # searching for those blocks in a sorted order + while !(curblock in allblocks) + curblock = domtree.idoms_bb[curblock] + curblock == 0 && return nothing end - for i in (6+nccallargs):length(origex.args) - x = origex.args[i] - # don't need to preserve intermediaries - if isa(x, SSAValue) && x.id in intermediates - continue + return curblock +end + +function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + @assert dfu !== nothing "has_safe_def condition unsatisfied" + def, useblock, curblock = dfu + if def == 0 + if !haskey(phinodes, curblock) + # If this happens, we need to search the predecessors for defs. 
Which + # one doesn't matter - if it did, we'd have had a phinode + return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) end - push!(newex.args, x) + # The use is the phinode + return phinodes[curblock] + else + return val_for_def_expr(ir, def, fidx) end - for i in 1:length(new_preserves) - push!(newex.args, new_preserves[i]) +end + +function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) + curblock = find_curblock(domtree, allblocks, curblock) + @assert curblock !== nothing "has_safe_def condition unsatisfied" + def = 0 + for stmt in fdu.defs + if block_for_inst(ir.cfg, stmt) == curblock + def = max(def, stmt) + end end - return newex + return def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx) +end + +function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) + ex = ir[SSAValue(def)][:inst] + if isexpr(ex, :new) || is_known_call(ex, tuple, ir) + return ex.args[1+fidx] + else + @assert is_known_call(ex, setfield!, ir) "invalid load forwarding" + return ex.args[4] + end +end + +function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, + revisit::Vector{Tuple{Vector{SSAValue},LivenessSet}}, eliminated::BitSet, + preserved::PreservedSets) + workingset = BitSet(1:length(revisit)) + while !isempty(workingset) + revisit_idx = pop!(workingset) + mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved, workingset, revisit_idx) + end +end + +function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, + revisit::Vector{Tuple{Vector{SSAValue},LivenessSet}}, eliminated::BitSet, + preserved::PreservedSets, workingset::BitSet, revisit_idx::Int) + related, Liveness = revisit[revisit_idx] + eliminable = SSAValue[] + for livepc in Liveness + livepc in eliminated && @goto next_live + ssa = SSAValue(livepc) + stmt = ir[ssa][:inst] + if isexpr(stmt, :new) + ssa in deadssas && @goto next_live + for new_revisit_idx in workingset + if ssa in revisit[new_revisit_idx][1] + delete!(workingset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, + revisit, eliminated, + preserved, workingset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + return false + elseif is_known_call(stmt, setfield!, ir) + @assert length(stmt.args) ≥ 4 "invalid escape analysis" + obj = stmt.args[2] + val = stmt.args[4] + if isa(obj, SSAValue) + if obj in related + push!(eliminable, ssa) + @goto next_live + end + if isa(val, SSAValue) && val in related + if obj in deadssas + push!(eliminable, ssa) + @goto next_live + end + for new_revisit_idx in workingset + if obj in revisit[new_revisit_idx][1] + delete!(workingset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, + revisit, eliminated, + preserved, workingset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + end + end + return false + elseif isexpr(stmt, :foreigncall) + livepc in preserved && @goto next_live + return false + else + return false + end + @label next_live + end + for ssa in related; push!(deadssas, ssa); end + for ssa in eliminable; push!(deadssas, ssa); end + return true end """ @@ -1084,15 +1234,15 @@ In addition to a simple DCE for unused values and allocations, this pass also nullifies `typeassert` calls that can be proved to be no-op, in order to allow LLVM to emit simpler code down the road. -Note that this pass is more effective after SROA optimization (i.e. 
`sroa_pass!`), +Note that this pass is more effective after SROA optimization (i.e. `linear_pass!`), since SROA often allows this pass to: - eliminate allocation of object whose field references are all replaced with scalar values, and - nullify `typeassert` call whose first operand has been replaced with a scalar value (, which may have introduced new type information that inference did not understand) -Also note that currently this pass _needs_ to run after `sroa_pass!`, because +Also note that currently this pass _needs_ to run after `linear_pass!`, because the `typeassert` elimination depends on the transformation by `canonicalize_typeassert!` done -within `sroa_pass!` which redirects references of `typeassert`ed value to the corresponding `PiNode`. +within `linear_pass!` which redirects references of `typeassert`ed value to the corresponding `PiNode`. """ function adce_pass!(ir::IRCode) phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) diff --git a/test/compiler/EscapeAnalysis/EAUtils.jl b/test/compiler/EscapeAnalysis/EAUtils.jl index 3ae9b41a0ddac..7ef50d5434932 100644 --- a/test/compiler/EscapeAnalysis/EAUtils.jl +++ b/test/compiler/EscapeAnalysis/EAUtils.jl @@ -71,8 +71,8 @@ import Core: CodeInstance, MethodInstance, CodeInfo import .CC: InferenceResult, OptimizationState, IRCode, copy as cccopy, - @timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, sroa_pass!, - adce_pass!, type_lift_pass!, JLOptions, verify_ir, verify_linetable + @timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, linear_pass!, + memory_opt_pass!, adce_pass!, type_lift_pass!, JLOptions, verify_ir, verify_linetable import .EA: analyze_escapes, ArgEscapeCache, EscapeInfo, EscapeState, is_ipo_profitable # when working outside of Core.Compiler, @@ -227,6 +227,7 @@ function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::Optimizati @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) + @timeit "SROA" ir, _ = linear_pass!(ir) if caller.linfo.specTypes === interp.entry_tt && interp.optimize try @timeit "[Local EA]" state = analyze_escapes(ir, nargs, true, get_escape_cache(interp)) @@ -240,7 +241,6 @@ function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::Optimizati interp.state = state interp.linfo = sv.linfo end - @timeit "SROA" ir = sroa_pass!(ir) @timeit "ADCE" ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) diff --git a/test/compiler/EscapeAnalysis/interprocedural.jl b/test/compiler/EscapeAnalysis/interprocedural.jl index eccdc710a6c12..42a2505e03c08 100644 --- a/test/compiler/EscapeAnalysis/interprocedural.jl +++ b/test/compiler/EscapeAnalysis/interprocedural.jl @@ -7,8 +7,6 @@ include(normpath(@__DIR__, "setup.jl")) # callsites # --------- -import .EA: ignore_argescape - noescape(a) = nothing noescape(a, b) = nothing function global_escape!(x) diff --git a/test/compiler/EscapeAnalysis/setup.jl b/test/compiler/EscapeAnalysis/setup.jl index 5123b18e2dfdd..4e7d6fb5159aa 100644 --- a/test/compiler/EscapeAnalysis/setup.jl +++ b/test/compiler/EscapeAnalysis/setup.jl @@ -2,6 +2,7 @@ include(normpath(@__DIR__, "EAUtils.jl")) using Test, Core.Compiler.EscapeAnalysis, .EAUtils import Core: Argument, SSAValue, ReturnNode const EA = Core.Compiler.EscapeAnalysis +import .EA: ignore_argescape isT(T) = (@nospecialize x) -> x === T isreturn(@nospecialize x) = isa(x, Core.ReturnNode) && 
isdefined(x, :val) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 128fd6cc84b7b..b10fe94c1d125 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -2,7 +2,9 @@ using Test using Base.Meta -using Core: PhiNode, SSAValue, GotoNode, PiNode, QuoteNode, ReturnNode, GotoIfNot +import Core: + CodeInfo, Argument, SSAValue, GotoNode, GotoIfNot, PiNode, PhiNode, + QuoteNode, ReturnNode include(normpath(@__DIR__, "irutils.jl")) @@ -12,7 +14,7 @@ include(normpath(@__DIR__, "irutils.jl")) ## Test that domsort doesn't mangle single-argument phis (#29262) let m = Meta.@lower 1 + 1 @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ # block 1 Expr(:call, :opaque), @@ -47,7 +49,7 @@ end # test that we don't stack-overflow in SNCA with large functions. let m = Meta.@lower 1 + 1 @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo code = Any[] N = 2^15 for i in 1:2:N @@ -73,30 +75,87 @@ end # SROA # ==== +import Core.Compiler: widenconst + +is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code) +is_scalar_replaced(src::CodeInfo) = + is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code) + +function is_load_forwarded(@nospecialize(T), src::CodeInfo) + for i in 1:length(src.code) + x = src.code[i] + if iscall((src, getfield), x) + widenconst(argextype(x.args[1], src)) <: T && return false + end + end + return true +end +function is_scalar_replaced(@nospecialize(T), src::CodeInfo) + is_load_forwarded(T, src) || return false + for i in 1:length(src.code) + x = src.code[i] + if iscall((src, setfield!), x) + widenconst(argextype(x.args[1], src)) <: T && return false + elseif isnew(x) + widenconst(argextype(SSAValue(i), src)) <: T && return false + end + end + return true +end + struct ImmutableXYZ; x; y; z; end mutable struct MutableXYZ; x; y; z; end +struct ImmutableOuter{T}; x::T; y::T; z::T; end +mutable struct MutableOuter{T}; x::T; y::T; z::T; end +struct ImmutableRef{T}; x::T; end +Base.getindex(r::ImmutableRef) = r.x +mutable struct SafeRef{T}; x::T; end +Base.getindex(s::SafeRef) = getfield(s, 1) +Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + +# simple immutability +# ------------------- -# should optimize away very basic cases let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = (x, y, z) + xyz[1], xyz[2], xyz[3] + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end + +# simple mutability +# ----------------- + let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end - -# should handle simple mutabilities let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) xyz.y = 42 xyz.x, xyz.y, xyz.z end - @test !any(isnew, 
src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=x=# Core.Argument(2), 42, #=x=# Core.Argument(4)] @@ -107,19 +166,23 @@ let src = code_typed1((Any,Any,Any)) do x, y, z xyz.x, xyz.z = xyz.z, xyz.x xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] end end -# circumvent uninitialized fields as far as there is a solid `setfield!` definition + +# uninitialized fields +# -------------------- + +# safe cases let src = code_typed1() do r = Ref{Any}() r[] = 42 return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -131,7 +194,7 @@ let src = code_typed1((Bool,)) do cond return r[] end end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -142,7 +205,7 @@ let src = code_typed1((Bool,)) do cond end return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z r = Ref{Any}() @@ -157,7 +220,16 @@ let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z end return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) +end + +# unsafe cases +let src = code_typed1() do + r = Ref{Any}() + return r[] + end + @test count(isnew, src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -167,7 +239,9 @@ let src = code_typed1((Bool,)) do cond return r[] end # N.B. `r` should be allocated since `cond` might be `false` and then it will be thrown - @test any(isnew, src.code) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y r = Ref{Any}() @@ -181,12 +255,119 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y return r[] end # N.B. `r` should be allocated since `c2` might be `false` and then it will be thrown - @test any(isnew, src.code) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 end -# should include a simple alias analysis -struct ImmutableOuter{T}; x::T; y::T; z::T; end -mutable struct MutableOuter{T}; x::T; y::T; z::T; end +# load forwarding +# --------------- +# even if allocation can't be eliminated + +# safe cases +for T in (ImmutableRef{Any}, Ref{Any}) + let src = @eval code_typed1((Bool,Any,)) do c, a + r = $T(a) + if c + return r[] + else + return r + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + end + let src = @eval code_typed1((Bool,String,)) do c, a + r = $T(a) + if c + return r[]::String # adce_pass! 
will also eliminate this typeassert call
+            else
+                return r
+            end
+        end
+        @test is_load_forwarded(src)
+        @test count(isnew, src.code) == 1
+        @test !any(iscall((src, typeassert)), src.code)
+    end
+    let src = @eval code_typed1((Bool,Any,)) do c, a
+            r = $T(a)
+            if c
+                return r[]
+            else
+                throw(r)
+            end
+        end
+        @test is_load_forwarded(src)
+        @test count(isnew, src.code) == 1
+    end
+end
+let src = code_typed1((Bool,Any,Any)) do c, a, b
+        r = Ref{Any}(a)
+        if c
+            return r[]
+        end
+        r[] = b
+        return r
+    end
+    @test is_load_forwarded(src)
+    @test count(isnew, src.code) == 1
+    @test count(iscall((src, setfield!)), src.code) == 1
+    @test count(src.code) do @nospecialize x
+        isreturn(x) && x.val === Argument(3) # a
+    end == 1
+end
+
+# unsafe cases
+let src = code_typed1((Bool,Any,Any)) do c, a, b
+        r = Ref{Any}(a)
+        r[] = b
+        @noinline some_escape!(r)
+        return r[]
+    end
+    @test !is_load_forwarded(src)
+    @test count(isnew, src.code) == 1
+    @test count(iscall((src, setfield!)), src.code) == 1
+end
+let src = code_typed1((Bool,String,Regex)) do c, a, b
+        r1 = Ref{Any}(a)
+        r2 = Ref{Any}(b)
+        return ifelse(c, r1, r2)[]
+    end
+    r = only(findall(isreturn, src.code))
+    v = (src.code[r]::Core.ReturnNode).val
+    @test v !== Argument(3) # a
+    @test v !== Argument(4) # b
+    @test_broken is_scalar_replaced(src) # ideally
+end
+let src = code_typed1((Bool,String,Regex)) do c, a, b
+        r1 = Ref{Any}(a)
+        r2 = Ref{Any}(b)
+        t = (r1, r2)
+        return t[c ? 1 : 2][]
+    end
+    r = only(findall(isreturn, src.code))
+    v = (src.code[r]::Core.ReturnNode).val
+    @test v !== Argument(3) # a
+    @test v !== Argument(4) # b
+    @test_broken is_scalar_replaced(src) # ideally
+end
+let src = code_typed1((Bool,String,Regex)) do c, a, b
+        r1 = Ref{Any}(a)
+        r2 = Ref{Any}(b)
+        a = [r1, r2]
+        return a[c ?
1 : 2][] + end + r = only(findall(isreturn, src.code)) + v = (src.code[r]::Core.ReturnNode).val + @test v !== Argument(3) # a + @test v !== Argument(4) # b + @test_broken is_scalar_replaced(src) # ideally +end + +# aliased load forwarding +# ----------------------- + +# OK: immutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) @@ -214,22 +395,21 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end -# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well -# OK: mutable(immutable(...)) case +# OK: immutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) t = (xyz,) v = t[1].x v, v, v end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] @@ -240,32 +420,541 @@ let # this is a simple end to end test case, which demonstrates allocation elimi # NOTE this test case isn't so robust and might be subject to future changes of the broadcasting implementation, # in that case you don't really need to stick to keeping this test case around simple_sroa(s) = broadcast(identity, Ref(s)) + let src = code_typed1(simple_sroa, (String,)) + @test is_scalar_replaced(src) + end s = Base.inferencebarrier("julia")::String simple_sroa(s) # NOTE don't hard-code `"julia"` in `@allocated` clause and make sure to execute the # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA @test @allocated(simple_sroa(s)) == 0 end -# FIXME: immutable(mutable(...)) case +let # some insanely nested example + src = code_typed1((Int,)) do x + (Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][] + end + @test is_scalar_replaced(src) +end + +# OK: mutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((String,String,String)) do x, y, z + xyz = (x, y, z) + r = Ref(xyz) + return r[][3], r[][2], r[][1] + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end + +# OK: mutable(mutable(...)) case +# new chain +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = MutableOuter(xyz, xyz, xyz) + outer.x.x, outer.y.y, outer.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end -# FIXME: mutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = z, y, x outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken 
!any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = xyz.z, xyz.y, xyz.x + outer = MutableOuter(xyz, xyz, xyz) + outer.x.x, outer.y.y, outer.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = z, y, x + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +# setfield! chain +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = Ref{MutableXYZ}() + outer[] = xyz + return outer[].x, outer[].y, outer[].z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = Ref{MutableXYZ}() + outer[] = xyz + xyz.z = 42 + return outer[].x, outer[].y, outer[].z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), 42] + end +end + +# ϕ-allocation elimination +# ------------------------ -let # should work with constant globals - # immutable case - # -------------- +# safe cases +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 +end +let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = Ref{Any}(x) + elseif cond2 + ϕ = Ref{Any}(y) + else + ϕ = Ref{Any}(z) + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + ϕ[] = z + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.ReturnNode) && + #=z=# Core.Argument(5) === x.val + end == 1 +end +let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + out1 = ϕ[] + else + ϕ = Ref{Any}(y) 
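+            # (illustration: the `ϕ[]` loads in this test are forwarded through a
+            # value ϕ-node over `x` and `y` rather than through the two `Ref`
+            # allocations; the `@test count` assertions below check exactly that)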
+            out1 = ϕ[]
+        end
+        out2 = ϕ[]
+        out1, out2
+    end
+    @test is_scalar_replaced(src)
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.PhiNode) &&
+        #=x=# Core.Argument(3) in x.values &&
+        #=y=# Core.Argument(4) in x.values
+    end == 2
+end
+let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z
+        if cond
+            ϕ = Ref{Any}(x)
+        else
+            ϕ = Ref{Any}(y)
+            ϕ[] = z
+        end
+        ϕ[]
+    end
+    @test is_scalar_replaced(src)
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.PhiNode) &&
+        #=x=# Core.Argument(3) in x.values &&
+        #=z=# Core.Argument(5) in x.values
+    end == 1
+end
+let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z
+        if cond
+            ϕ = Ref{Any}(x)
+            out1 = ϕ[]
+        else
+            ϕ = Ref{Any}(y)
+            out1 = ϕ[]
+            ϕ[] = z
+        end
+        out2 = ϕ[]
+        out1, out2
+    end
+    @test is_scalar_replaced(src)
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.PhiNode) &&
+        #=x=# Core.Argument(3) in x.values &&
+        #=y=# Core.Argument(4) in x.values
+    end == 1
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.PhiNode) &&
+        #=x=# Core.Argument(3) in x.values &&
+        #=z=# Core.Argument(5) in x.values
+    end == 1
+end
+let src = code_typed1((Bool,Any,Any)) do cond, x, y
+        # these allocations form multiple ϕ-nodes
+        if cond
+            ϕ2 = ϕ1 = Ref{Any}(x)
+        else
+            ϕ2 = ϕ1 = Ref{Any}(y)
+        end
+        ϕ1[], ϕ2[]
+    end
+    @test is_scalar_replaced(src)
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.PhiNode) &&
+        #=x=# Core.Argument(3) in x.values &&
+        #=y=# Core.Argument(4) in x.values
+    end == 2
+end
+let src = code_typed1((Bool,String,)) do cond, x
+        # these allocations form multiple ϕ-nodes
+        if cond
+            ϕ2 = ϕ1 = Ref{Any}("foo")
+        else
+            ϕ2 = ϕ1 = Ref{Any}("bar")
+        end
+        ϕ2[] = x
+        y = ϕ1[] # => x
+        return y
+    end
+    @test is_scalar_replaced(src)
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.ReturnNode) &&
+        #=x=# x.val === Core.Argument(3)
+    end == 1
+end
+let src = code_typed1((Bool,Any,Any,)) do cond, x, y
+        x′ = Ref{Any}(x)
+        y′ = Ref{Any}(y)
+        if cond
+            ϕ = x′
+        else
+            ϕ = y′
+        end
+        ϕ[]
+    end
+    @test is_scalar_replaced(src)
+    @test count(src.code) do @nospecialize x
+        isa(x, Core.PhiNode) &&
+        #=x=# Core.Argument(3) in x.values &&
+        #=y=# Core.Argument(4) in x.values
+    end == 1
+end
+
+# unsafe cases
+let src = code_typed1((Bool,Any,Any)) do cond, x, y
+        if cond
+            ϕ = Ref{Any}(x)
+        else
+            ϕ = Ref{Any}(y)
+        end
+        some_escape!(ϕ)
+        ϕ[]
+    end
+    @test count(isnew, src.code) == 2
+    @test count(iscall((src, getfield)), src.code) == 1
+end
+let src = code_typed1((Bool,Any,Any)) do cond, x, y
+        if cond
+            ϕ = Ref{Any}(x)
+            some_escape!(ϕ)
+        else
+            ϕ = Ref{Any}(y)
+        end
+        ϕ[]
+    end
+    @test count(isnew, src.code) == 2
+    @test count(iscall((src, getfield)), src.code) == 1
+end
+let src = code_typed1((Bool,Any,)) do cond, x
+        if cond
+            ϕ = Ref{Any}(x)
+        else
+            ϕ = Ref{Any}()
+        end
+        ϕ[]
+    end
+    @test count(isnew, src.code) == 2
+    @test count(iscall((src, getfield)), src.code) == 1
+end
+let src = code_typed1((Bool,Any)) do c, a
+        local r
+        if c
+            r = Ref{Any}(a)
+        end
+        (r::Base.RefValue{Any})[]
+    end
+    @test count(isnew, src.code) == 1
+    @test count(iscall((src, getfield)), src.code) == 1
+end
+
+function mutable_ϕ_elim(x, xs)
+    r = Ref(x)
+    for x in xs
+        r = Ref(x)
+    end
+    return r[]
+end
+let src = code_typed1(mutable_ϕ_elim, (String, Vector{String}))
+    @test is_scalar_replaced(src)
+
+    xs = String[string(gensym()) for _ in 1:100]
+    mutable_ϕ_elim("init", xs)
+    @test @allocated(mutable_ϕ_elim("init", xs)) == 0
+end
+
+@noinline mightaliase_noinline(a, b) = Base.mightalias(a, b)
+function assert_no_alias!(a, b,
c) + x = Ref(a) + y = Ref(b) + @assert !mightaliase_noinline(x[], y[]) # shouldn't be transformed to `mightaliase_noinline(b, b)` + z = c ? x : y + z +end +let src = code_typed1(assert_no_alias!, (Vector{Any}, Vector{Any}, Bool,)) + @test count(src.code) do @nospecialize x + if isinvoke(:mightaliase_noinline, x) + if x.args[3] === Argument(2) # a + if x.args[4] === Argument(3) # b + return true + end + end + end + return false + end == 1 + a = Any[1,2,3] + b = Any[1,2,3] + @test assert_no_alias!(a, b, true)[] === a +end + +# demonstrate the power of our field / alias analysis with realistic end to end examples +# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B +abstract type AbstractPoint{T} end +struct Point{T} <: AbstractPoint{T} + x::T + y::T +end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute_point(T, n, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(n-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute_point(n, a, b) + for i in 0:(n-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute_point!(n, a, b) + for i in 0:(n-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end + +let # immutable case + src = code_typed1((Int,)) do n + compute_point(Point, n, 1+.5, 2+.5, 2+.25, 4+.75) + end + @test is_scalar_replaced(Point, src) + src = code_typed1((Int,)) do n + compute_point(Point, n, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + end + @test is_scalar_replaced(Point, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) + + # mutable case + src = code_typed1((Int,)) do n + compute_point(MPoint, n, 1+.5, 2+.5, 2+.25, 4+.75) + end + @test is_scalar_replaced(MPoint, src) + src = code_typed1((Int,)) do n + compute_point(MPoint, n, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + end + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +compute_point(MPoint, 10, 1+.5, 2+.5, 2+.25, 4+.75) +compute_point(MPoint, 10, 1+.5im, 2+.5im, 2+.25im, 4+.75im) +@test @allocated(compute_point(MPoint, 10000, 1+.5, 2+.5, 2+.25, 4+.75)) == 0 +@test @allocated(compute_point(MPoint, 10000, 1+.5im, 2+.5im, 2+.25im, 4+.75im)) == 0 + +let # immutable case + src = code_typed1((Int,)) do n + compute_point(n, Point(1+.5, 2+.5), Point(2+.25, 4+.75)) + end + @test is_scalar_replaced(Point, src) + src = code_typed1((Int,)) do n + compute_point(n, Point(1+.5im, 2+.5im), Point(2+.25im, 4+.75im)) + end + @test is_scalar_replaced(Point, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) + + # mutable case + src = code_typed1((Int,)) do n + compute_point(n, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) + end + @test is_scalar_replaced(MPoint, src) + src = code_typed1((Int,)) do n + compute_point(n, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + end + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +compute_point(10, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) +compute_point(10, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) +@test @allocated(compute_point(10000, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75))) == 0 +@test @allocated(compute_point(10000, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im))) == 0 + +let # mutable case + src = code_typed1(compute_point!, (Int,MPoint{Float64},MPoint{Float64})) + @test 
is_scalar_replaced(MPoint, src) + src = code_typed1(compute_point!, (Int,MPoint{ComplexF64},MPoint{ComplexF64})) + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +let + af, bf = MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75) + ac, bc = MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im) + compute_point!(10, af, bf) + compute_point!(10, ac, bc) + @test @allocated(compute_point!(10000, af, bf)) == 0 + @test @allocated(compute_point!(10000, ac, bc)) == 0 +end + +# isdefined elimination +# --------------------- + +let src = code_typed1((Any,)) do a + r = Ref{Any}() + r[] = a + if isassigned(r) + return r[] + end + return nothing + end + @test is_scalar_replaced(src) +end + +callit(f, args...) = f(args...) +function isdefined_elim() + local arr::Vector{Any} + callit() do + arr = Any[] + end + return arr +end +let src = code_typed1(isdefined_elim) + @test is_scalar_replaced(src) +end +@test isdefined_elim() == Any[] + +# preserve elimination +# -------------------- + +let src = code_typed1((String,)) do s + ccall(:some_ccall, Cint, (Ptr{String},), Ref(s)) + end + @test count(isnew, src.code) == 0 +end + +# if the mutable struct is directly used, we shouldn't eliminate it +let src = code_typed1() do + a = MutableXYZ(-512275808,882558299,-2133022131) + b = Int32(42) + ccall(:some_ccall, Cvoid, (MutableXYZ, Int32), a, b) + return a.x + end + @test count(isnew, src.code) == 1 +end + +# constant globals +# ---------------- + +let # immutable case src = @eval Module() begin const REF_FLD = :x struct ImmutableRef{T} @@ -282,7 +971,6 @@ let # should work with constant globals @test count(isnew, src.code) == 0 # mutable case - # ------------ src = @eval Module() begin const REF_FLD = :x code_typed() do @@ -295,25 +983,6 @@ let # should work with constant globals @test count(isnew, src.code) == 0 end -# should work nicely with inlining to optimize away a complicated case -# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B -struct Point - x::Float64 - y::Float64 -end -#=@inline=# add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) -function compute_points() - a = Point(1.5, 2.5) - b = Point(2.25, 4.75) - for i in 0:(100000000-1) - a = add(add(a, b), b) - end - a.x, a.y -end -let src = code_typed1(compute_points) - @test !any(isnew, src.code) -end - # comparison lifting # ================== @@ -454,7 +1123,7 @@ end # A SSAValue after the compaction line let m = Meta.@lower 1 + 1 @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ # block 1 nothing, @@ -492,7 +1161,7 @@ let m = Meta.@lower 1 + 1 src.ssaflags = fill(Int32(0), nstmts) ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any]) @test Core.Compiler.verify_ir(ir) === nothing - ir = @test_nowarn Core.Compiler.sroa_pass!(ir) + ir, = @test_nowarn Core.Compiler.linear_pass!(ir) @test Core.Compiler.verify_ir(ir) === nothing end @@ -517,7 +1186,7 @@ end let m = Meta.@lower 1 + 1 # Test that CFG simplify combines redundant basic blocks @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ Core.Compiler.GotoNode(2), Core.Compiler.GotoNode(3), @@ -542,7 +1211,7 @@ end let m = Meta.@lower 1 + 1 # Test that CFG simplify doesn't mess up when chaining past return blocks @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ Core.Compiler.GotoIfNot(Core.Compiler.Argument(2), 3), 
Core.Compiler.GotoNode(4), @@ -572,7 +1241,7 @@ let m = Meta.@lower 1 + 1 # Test that CFG simplify doesn't try to merge every block in a loop into # its predecessor @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ # Block 1 Core.Compiler.GotoNode(2), From a4561b0614dacb87139f271cba8a4f644b8ce0c1 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 5 Apr 2019 15:23:27 -0400 Subject: [PATCH 3/5] Begin rebase of 42465 onto 43888 --- base/array.jl | 30 ++++++++++--- base/compiler/ssair/ir.jl | 7 ++++ base/compiler/ssair/passes.jl | 74 +++++++++++++++++++++++++++++++++ base/compiler/tfuncs.jl | 15 +++++++ base/dict.jl | 2 +- base/experimental.jl | 2 + src/builtin_proto.h | 3 ++ src/builtins.c | 61 ++++++++++++++++++++++++++- src/cgutils.cpp | 2 +- src/codegen.cpp | 31 +++++++++++++- src/datatype.c | 7 +++- src/gc.c | 6 ++- src/intrinsics.cpp | 2 +- src/jl_exported_data.inc | 2 + src/jltypes.c | 10 +++++ src/julia.h | 29 ++++++++++--- src/llvm-late-gc-lowering.cpp | 24 +++++++++++ src/llvm-pass-helpers.cpp | 4 +- src/llvm-pass-helpers.h | 1 + src/rtutils.c | 4 +- src/staticdata.c | 8 +++- test/choosetests.jl | 3 +- test/compiler/immutablearray.jl | 12 ++++++ 23 files changed, 311 insertions(+), 28 deletions(-) create mode 100644 test/compiler/immutablearray.jl diff --git a/base/array.jl b/base/array.jl index cf5bbc05e412a..dfcaed3725a1a 100644 --- a/base/array.jl +++ b/base/array.jl @@ -147,12 +147,20 @@ function vect(X...) return copyto!(Vector{T}(undef, length(X)), X) end -size(a::Array, d::Integer) = arraysize(a, convert(Int, d)) -size(a::Vector) = (arraysize(a,1),) -size(a::Matrix) = (arraysize(a,1), arraysize(a,2)) -size(a::Array{<:Any,N}) where {N} = (@inline; ntuple(M -> size(a, M), Val(N))::Dims) +const ImmutableArray = Core.ImmutableArray +const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}} +const IMVector{T} = IMArray{T, 1} +const IMMatrix{T} = IMArray{T, 2} -asize_from(a::Array, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...) +ImmutableArray(a::Array) = Core.arrayfreeze(a) +Array(a::ImmutableArray) = Core.arraythaw(a) + +size(a::IMArray, d::Integer) = arraysize(a, convert(Int, d)) +size(a::IMVector) = (arraysize(a,1),) +size(a::IMMatrix) = (arraysize(a,1), arraysize(a,2)) +size(a::IMArray{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val(N))::Dims) + +asize_from(a::IMArray, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...) allocatedinline(T::Type) = (@_pure_meta; ccall(:jl_stored_inline, Cint, (Any,), T) != Cint(0)) @@ -223,6 +231,13 @@ function isassigned(a::Array, i::Int...) ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1 end +function isassigned(a::ImmutableArray, i::Int...) + @_inline_meta + ii = (_sub2ind(size(a), i...) % UInt) - 1 + @boundscheck ii < length(a) % UInt || return false + ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1 +end + ## copy ## """ @@ -921,7 +936,10 @@ function getindex end @eval getindex(A::Array, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1) @eval getindex(A::Array, i1::Int, i2::Int, I::Int...) = (@inline; arrayref($(Expr(:boundscheck)), A, i1, i2, I...)) -# Faster contiguous indexing using copyto! for AbstractUnitRange and Colon +@eval getindex(A::ImmutableArray, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1) +@eval getindex(A::ImmutableArray, i1::Int, i2::Int, I::Int...) = (@_inline_meta; arrayref($(Expr(:boundscheck)), A, i1, i2, I...)) + +# Faster contiguous indexing using copyto! 
for UnitRange and Colon function getindex(A::Array, I::AbstractUnitRange{<:Integer}) @inline @boundscheck checkbounds(A, I) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index a86e125fcb307..41a69b6d25dd4 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -319,6 +319,13 @@ function setindex!(x::IRCode, repl::Instruction, s::SSAValue) return x end +function ssadominates(ir::IRCode, domtree::DomTree, ssa1::Int, ssa2::Int) + bb1 = block_for_inst(ir.cfg, ssa1) + bb2 = block_for_inst(ir.cfg, ssa2) + bb1 == bb2 && return ssa1 < ssa2 + return dominates(domtree, bb1, bb2) +end + # SSA values that need renaming struct OldSSAValue id::Int diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 4c25654f83f1b..1843de2f395f2 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1760,3 +1760,77 @@ function cfg_simplify!(ir::IRCode) compact.active_result_bb = length(bb_starts) return finish(compact) end + +function is_allocation(stmt) + isexpr(stmt, :foreigncall) || return false + s = stmt.args[1] + isa(s, QuoteNode) && (s = s.value) + return s === :jl_alloc_array_1d +end + +function memory_opt!(ir::IRCode) + compact = IncrementalCompact(ir, false) + uses = IdDict{Int, Vector{Int}}() + relevant = IdSet{Int}() + revisit = Int[] + function mark_val(val) + isa(val, SSAValue) || return + val.id in relevant && pop!(relevant, val.id) + end + for ((_, idx), stmt) in compact + if isa(stmt, ReturnNode) + isdefined(stmt, :val) || continue + val = stmt.val + if isa(val, SSAValue) && val.id in relevant + (haskey(uses, val.id)) || (uses[val.id] = Int[]) + push!(uses[val.id], idx) + end + continue + end + (isexpr(stmt, :call) || isexpr(stmt, :foreigncall)) || continue + if is_allocation(stmt) + push!(relevant, idx) + # TODO: Mark everything else here + continue + end + # TODO: Replace this by interprocedural escape analysis + if is_known_call(stmt, arrayset, compact) + # The value being set escapes, everything else doesn't + mark_val(stmt.args[4]) + arr = stmt.args[3] + if isa(arr, SSAValue) && arr.id in relevant + (haskey(uses, arr.id)) || (uses[arr.id] = Int[]) + push!(uses[arr.id], idx) + end + elseif is_known_call(stmt, Core.arrayfreeze, compact) && isa(stmt.args[2], SSAValue) + push!(revisit, idx) + else + # For now we assume everything escapes + # TODO: We could handle PhiNodes specially and improve this + for ur in userefs(stmt) + mark_val(ur[]) + end + end + end + ir = finish(compact) + isempty(revisit) && return ir + domtree = construct_domtree(ir.cfg.blocks) + for idx in revisit + # Make sure that the value we reference didn't escape + id = ir.stmts[idx][:inst].args[2].id + (id in relevant) || continue + + # We're ok to steal the memory if we don't dominate any uses + ok = true + for use in uses[id] + if ssadominates(ir, domtree, idx, use) + ok = false + break + end + end + ok || continue + + ir.stmts[idx][:inst].args[1] = Core.mutating_arrayfreeze + end + return ir +end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index e67594f196c90..71845fe13b503 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1803,6 +1803,21 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp sv::Union{InferenceState,Nothing}) if f === tuple return tuple_tfunc(argtypes) + elseif f === Core.arrayfreeze || f === Core.arraythaw + if length(argtypes) != 1 + isva && return Any + return Bottom + end + a = widenconst(argtypes[1]) + at = (f === Core.arrayfreeze ? 
Array : ImmutableArray) + rt = (f === Core.arrayfreeze ? ImmutableArray : Array) + if a <: at + unw = unwrap_unionall(a) + if isa(unw, DataType) + return rewrap_unionall(rt{unw.parameters[1], unw.parameters[2]}, a) + end + end + return rt end if isa(f, IntrinsicFunction) if is_pure_intrinsic_infer(f) && _all(@nospecialize(a) -> isa(a, Const), argtypes) diff --git a/base/dict.jl b/base/dict.jl index dabdfa5c34773..83ef7f423f7e7 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -373,7 +373,7 @@ end function setindex!(h::Dict{K,V}, v0, key0) where V where K key = convert(K, key0) if !isequal(key, key0) - throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K")) + throw(KeyTypeError(K, key0)) end setindex!(h, v0, key) end diff --git a/base/experimental.jl b/base/experimental.jl index d5af876cbbb23..12fcb9273b4b6 100644 --- a/base/experimental.jl +++ b/base/experimental.jl @@ -11,6 +11,8 @@ module Experimental using Base: Threads, sync_varname using Base.Meta +using Base: ImmutableArray + """ Const(A::Array) diff --git a/src/builtin_proto.h b/src/builtin_proto.h index 7b11813e7a58b..c8daaf800040a 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -55,6 +55,9 @@ DECLARE_BUILTIN(_typebody); DECLARE_BUILTIN(typeof); DECLARE_BUILTIN(_typevar); DECLARE_BUILTIN(donotdelete); +DECLARE_BUILTIN(arrayfreeze); +DECLARE_BUILTIN(arraythaw); +DECLARE_BUILTIN(mutating_arrayfreeze); JL_CALLABLE(jl_f_invoke_kwsorter); #ifdef DEFINE_BUILTIN_GLOBALS diff --git a/src/builtins.c b/src/builtins.c index ca2f56adaf6d8..0f3129f57293c 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1373,7 +1373,9 @@ JL_CALLABLE(jl_f__typevar) JL_CALLABLE(jl_f_arraysize) { JL_NARGS(arraysize, 2, 2); - JL_TYPECHK(arraysize, array, args[0]); + if (!jl_is_arrayish(args[0])) { + jl_type_error("arraysize", (jl_value_t*)jl_array_type, args[0]); + } jl_array_t *a = (jl_array_t*)args[0]; size_t nd = jl_array_ndims(a); JL_TYPECHK(arraysize, long, args[1]); @@ -1412,7 +1414,9 @@ JL_CALLABLE(jl_f_arrayref) { JL_NARGSV(arrayref, 3); JL_TYPECHK(arrayref, bool, args[0]); - JL_TYPECHK(arrayref, array, args[1]); + if (!jl_is_arrayish(args[1])) { + jl_type_error("arrayref", (jl_value_t*)jl_array_type, args[1]); + } jl_array_t *a = (jl_array_t*)args[1]; size_t i = array_nd_index(a, &args[2], nargs - 2, "arrayref"); return jl_arrayref(a, i); @@ -1735,6 +1739,54 @@ JL_CALLABLE(jl_f_set_binding_type) return jl_nothing; } +JL_CALLABLE(jl_f_arrayfreeze) +{ + JL_NARGSV(arrayfreeze, 1); + JL_TYPECHK(arrayfreeze, array, args[0]); + jl_array_t *a = (jl_array_t*)args[0]; + jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type, + jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a))); + JL_GC_PUSH1(&it); + // The idea is to elide this copy if the compiler or runtime can prove that + // doing so is safe to do. 
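+ // For example, the `memory_opt!` pass added in base/compiler/ssair/passes.jl
+ // rewrites `arrayfreeze(a)` to `mutating_arrayfreeze(a)` when `a` is a local
+ // allocation that has not escaped and the freeze does not dominate any
+ // remaining use of `a`, eliding this copy entirely.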
+ jl_array_t *na = jl_array_copy(a);
+ jl_set_typeof(na, it);
+ JL_GC_POP();
+ return (jl_value_t*)na;
+}
+
+JL_CALLABLE(jl_f_mutating_arrayfreeze)
+{
+ // N.B.: These error checks pretend to be arrayfreeze, since this is a drop-in
+ // replacement for it in the optimizer and we don't want to change the
+ // visible error type.
+ JL_NARGSV(arrayfreeze, 1);
+ JL_TYPECHK(arrayfreeze, array, args[0]);
+ jl_array_t *a = (jl_array_t*)args[0];
+ jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type,
+ jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
+ // no copy here: the mutable array is retagged as ImmutableArray in place
+ jl_set_typeof(a, it);
+ return (jl_value_t*)a;
+}
+
+JL_CALLABLE(jl_f_arraythaw)
+{
+ JL_NARGSV(arraythaw, 1);
+ if (((jl_datatype_t*)jl_typeof(args[0]))->name != jl_immutable_array_typename) {
+ jl_type_error("arraythaw", (jl_value_t*)jl_immutable_array_type, args[0]);
+ }
+ jl_array_t *a = (jl_array_t*)args[0];
+ jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type,
+ jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
+ JL_GC_PUSH1(&it);
+ // The idea is to elide this copy if the compiler or runtime can prove that
+ // doing so is safe to do.
+ jl_array_t *na = jl_array_copy(a);
+ jl_set_typeof(na, it);
+ JL_GC_POP();
+ return (jl_value_t*)na;
+}
+
 // IntrinsicFunctions ---------------------------------------------------------
 static void (*runtime_fp[num_intrinsics])(void);
@@ -1890,6 +1942,10 @@ void jl_init_primitives(void) JL_GC_DISABLED
 jl_builtin_arrayset = add_builtin_func("arrayset", jl_f_arrayset);
 jl_builtin_arraysize = add_builtin_func("arraysize", jl_f_arraysize);
+ jl_builtin_arrayfreeze = add_builtin_func("arrayfreeze", jl_f_arrayfreeze);
+ jl_builtin_mutating_arrayfreeze = add_builtin_func("mutating_arrayfreeze", jl_f_mutating_arrayfreeze);
+ jl_builtin_arraythaw = add_builtin_func("arraythaw", jl_f_arraythaw);
+
 // method table utils
 jl_builtin_applicable = add_builtin_func("applicable", jl_f_applicable);
 jl_builtin_invoke = add_builtin_func("invoke", jl_f_invoke);
@@ -1965,6 +2021,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
 add_builtin("AbstractArray", (jl_value_t*)jl_abstractarray_type);
 add_builtin("DenseArray", (jl_value_t*)jl_densearray_type);
 add_builtin("Array", (jl_value_t*)jl_array_type);
+ add_builtin("ImmutableArray", (jl_value_t*)jl_immutable_array_type);
 add_builtin("Expr", (jl_value_t*)jl_expr_type);
 add_builtin("LineNumberNode", (jl_value_t*)jl_linenumbernode_type);
diff --git a/src/cgutils.cpp b/src/cgutils.cpp
index b219498315905..09822257db3b2 100644
--- a/src/cgutils.cpp
+++ b/src/cgutils.cpp
@@ -491,7 +491,7 @@ static Type *_julia_type_to_llvm(jl_codegen_params_t *ctx, LLVMContext &ctxt, jl
 if (isboxed) *isboxed = false;
 if (jt == (jl_value_t*)jl_bottom_type)
 return getVoidTy(ctxt);
- if (jl_is_concrete_immutable(jt)) {
+ if (jl_is_concrete_immutable(jt) && !jl_is_arrayish_type(jt)) {
 if (jl_datatype_nbits(jt) == 0)
 return getVoidTy(ctxt);
 Type *t = _julia_struct_to_llvm(ctx, ctxt, jt, isboxed);
diff --git a/src/codegen.cpp b/src/codegen.cpp
index 46f6899028f13..e748742b1ebb2 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -1054,6 +1054,15 @@ static const auto pointer_from_objref_func = new JuliaFunction{
 Attributes(C, {Attribute::NonNull}),
 None); },
};
+static const auto mutating_arrayfreeze_func = new JuliaFunction{
+ "julia.mutating_arrayfreeze",
+ [](LLVMContext &C) { return FunctionType::get(T_prjlvalue,
+ {T_prjlvalue, T_prjlvalue}, false); },
+ [](LLVMContext &C) { return AttributeList::get(C,
+ Attributes(C, {Attribute::NoUnwind,
Attribute::NoRecurse}), + Attributes(C, {Attribute::NonNull}), + None); }, +}; static const auto jltuple_func = new JuliaFunction{XSTR(jl_f_tuple), get_func_sig, get_func_attrs}; static std::map builtin_func_map; @@ -1125,7 +1134,7 @@ static bool deserves_retbox(jl_value_t* t) static bool deserves_sret(jl_value_t *dt, Type *T) { assert(jl_is_datatype(dt)); - return (size_t)jl_datatype_size(dt) > sizeof(void*) && !T->isFloatingPointTy() && !T->isVectorTy(); + return (size_t)jl_datatype_size(dt) > sizeof(void*) && !T->isFloatingPointTy() && !T->isVectorTy() && !jl_is_arrayish_type(dt); } @@ -2996,6 +3005,21 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f, } } + else if (f == jl_builtin_mutating_arrayfreeze && nargs == 1) { + const jl_cgval_t &ary = argv[1]; + jl_value_t *aty_dt = jl_unwrap_unionall(ary.typ); + if (jl_is_array_type(aty_dt)) { + jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type, + jl_tparam0(aty_dt), jl_tparam1(aty_dt)); + *ret = mark_julia_type(ctx, + ctx.builder.CreateCall(prepare_call(mutating_arrayfreeze_func), + { boxed(ctx, ary), + track_pjlvalue(ctx, literal_pointer_val(ctx, (jl_value_t*)it)) }), true, it); + return true; + } + return false; + } + else if (f == jl_builtin_arrayset && nargs >= 4) { const jl_cgval_t &ary = argv[2]; jl_cgval_t val = argv[3]; @@ -8176,7 +8200,10 @@ extern "C" void jl_init_llvm(void) { jl_f_arrayset_addr, new JuliaFunction{XSTR(jl_f_arrayset), get_func_sig, get_func_attrs} }, { jl_f_arraysize_addr, new JuliaFunction{XSTR(jl_f_arraysize), get_func_sig, get_func_attrs} }, { jl_f_apply_type_addr, new JuliaFunction{XSTR(jl_f_apply_type), get_func_sig, get_func_attrs} }, - { jl_f_donotdelete_addr, new JuliaFunction{XSTR(jl_f_donotdelete), get_func_sig, get_donotdelete_func_attrs} } + { jl_f_donotdelete_addr, new JuliaFunction{XSTR(jl_f_donotdelete), get_func_sig, get_donotdelete_func_attrs} }, + { jl_f_arrayfreeze_addr, new JuliaFunction{XSTR(jl_f_arrayfreeze), get_func_sig, get_func_attrs} }, + { jl_f_arraythaw_addr, new JuliaFunction{XSTR(jl_f_arraythaw), get_func_sig, get_func_attrs} }, + { jl_f_mutating_arrayfreeze_addr, new JuliaFunction{XSTR(jl_f_mutating_arrayfreeze), get_func_sig, get_func_attrs} } }; jl_default_debug_info_kind = (int) DICompileUnit::DebugEmissionKind::FullDebug; diff --git a/src/datatype.c b/src/datatype.c index e7f1ab22365b8..36ea0e08bced1 100644 --- a/src/datatype.c +++ b/src/datatype.c @@ -223,7 +223,8 @@ unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) STATIC_INLINE int jl_is_datatype_make_singleton(jl_datatype_t *d) JL_NOTSAFEPOINT { - return (!d->name->abstract && jl_datatype_size(d) == 0 && d != jl_symbol_type && d->name != jl_array_typename && + return (!d->name->abstract && jl_datatype_size(d) == 0 && d != jl_symbol_type && + d->name != jl_array_typename && d->name != jl_immutable_array_typename && d->isconcretetype && !d->name->mutabl); } @@ -395,7 +396,9 @@ void jl_compute_field_offsets(jl_datatype_t *st) st->layout = &opaque_byte_layout; return; } - else if (st == jl_simplevector_type || st == jl_module_type || st->name == jl_array_typename) { + else if (st == jl_simplevector_type || st == jl_module_type || + st->name == jl_array_typename || + st->name == jl_immutable_array_typename) { static const jl_datatype_layout_t opaque_ptr_layout = {0, 1, -1, sizeof(void*), 0, 0}; st->layout = &opaque_ptr_layout; return; diff --git a/src/gc.c b/src/gc.c index 609c2009bf103..ab3a28da55dc7 100644 --- a/src/gc.c +++ b/src/gc.c @@ 
-859,7 +859,8 @@ void jl_gc_force_mark_old(jl_ptls_t ptls, jl_value_t *v) JL_NOTSAFEPOINT size_t l = jl_svec_len(v); dtsz = l * sizeof(void*) + sizeof(jl_svec_t); } - else if (dt->name == jl_array_typename) { + else if (dt->name == jl_array_typename || + dt->name == jl_immutable_array_typename) { jl_array_t *a = (jl_array_t*)v; if (!a->flags.pooled) dtsz = GC_MAX_SZCLASS + 1; @@ -2560,7 +2561,8 @@ mark: { objary = (gc_mark_objarray_t*)sp.data; goto objarray_loaded; } - else if (vt->name == jl_array_typename) { + else if (vt->name == jl_array_typename || + vt->name == jl_immutable_array_typename) { jl_array_t *a = (jl_array_t*)new_obj; jl_array_flags_t flags = a->flags; if (update_meta) { diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index 4ca4794ab7733..1d889827b73d6 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -1081,7 +1081,7 @@ static jl_cgval_t emit_intrinsic(jl_codectx_t &ctx, intrinsic f, jl_value_t **ar assert(nargs == 1); const jl_cgval_t &x = argv[0]; jl_value_t *typ = jl_unwrap_unionall(x.typ); - if (!jl_is_datatype(typ) || ((jl_datatype_t*)typ)->name != jl_array_typename) + if (!jl_is_arrayish_type(typ)) return emit_runtime_call(ctx, f, argv, nargs); return mark_julia_type(ctx, emit_arraylen(ctx, x), false, jl_long_type); } diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 09d2949c22489..bf2e9ac0e4bd0 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -17,6 +17,8 @@ XX(jl_array_symbol_type) \ XX(jl_array_type) \ XX(jl_array_typename) \ + XX(jl_immutable_array_type) \ + XX(jl_immutable_array_typename) \ XX(jl_array_uint8_type) \ XX(jl_array_uint64_type) \ XX(jl_atomicerror_type) \ diff --git a/src/jltypes.c b/src/jltypes.c index 86630ac39c059..92de7c10ef178 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2261,6 +2261,15 @@ void jl_init_types(void) JL_GC_DISABLED jl_atomic_store_relaxed(&jl_nonfunction_mt->leafcache, (jl_array_t*)jl_an_empty_vec_any); jl_atomic_store_relaxed(&jl_type_type_mt->leafcache, (jl_array_t*)jl_an_empty_vec_any); + tv = jl_svec2(tvar("T"), tvar("N")); + jl_immutable_array_type = (jl_unionall_t*) + jl_new_datatype(jl_symbol("ImmutableArray"), core, + (jl_datatype_t*) + jl_apply_type((jl_value_t*)jl_densearray_type, jl_svec_data(tv), 2), + tv, jl_emptysvec, jl_emptysvec, jl_emptysvec, 0, 0, 0)->name->wrapper; + jl_immutable_array_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_immutable_array_type))->name; + jl_compute_field_offsets((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_immutable_array_type)); + jl_expr_type = jl_new_datatype(jl_symbol("Expr"), core, jl_any_type, jl_emptysvec, @@ -2687,6 +2696,7 @@ void jl_init_types(void) JL_GC_DISABLED // override the preferred layout for a couple types jl_lineinfonode_type->name->mayinlinealloc = 0; // FIXME: assumed to be a pointer by codegen + jl_immutable_array_typename->mayinlinealloc = 0; // It seems like we probably usually end up needing the box for kinds (used in an Any context)--but is that true? 
jl_uniontype_type->name->mayinlinealloc = 0; jl_unionall_type->name->mayinlinealloc = 0; diff --git a/src/julia.h b/src/julia.h index f3905897a1202..641b99b766dd4 100644 --- a/src/julia.h +++ b/src/julia.h @@ -699,6 +699,8 @@ extern JL_DLLIMPORT jl_unionall_t *jl_abstractarray_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_unionall_t *jl_densearray_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_unionall_t *jl_array_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_typename_t *jl_array_typename JL_GLOBALLY_ROOTED; +extern JL_DLLEXPORT jl_unionall_t *jl_immutable_array_type JL_GLOBALLY_ROOTED; +extern JL_DLLEXPORT jl_typename_t *jl_immutable_array_typename JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_weakref_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_abstractstring_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_string_type JL_GLOBALLY_ROOTED; @@ -1252,11 +1254,25 @@ STATIC_INLINE int jl_is_primitivetype(void *v) JL_NOTSAFEPOINT jl_datatype_size(v) > 0); } +STATIC_INLINE int jl_is_array_type(void *t) JL_NOTSAFEPOINT +{ + return (jl_is_datatype(t) && + (((jl_datatype_t*)(t))->name == jl_array_typename)); +} + +STATIC_INLINE int jl_is_arrayish_type(void *t) JL_NOTSAFEPOINT +{ + return (jl_is_datatype(t) && + (((jl_datatype_t*)(t))->name == jl_array_typename || + ((jl_datatype_t*)(t))->name == jl_immutable_array_typename)); +} + STATIC_INLINE int jl_is_structtype(void *v) JL_NOTSAFEPOINT { return (jl_is_datatype(v) && !((jl_datatype_t*)(v))->name->abstract && - !jl_is_primitivetype(v)); + !jl_is_primitivetype(v) && + !jl_is_arrayish_type(v)); } STATIC_INLINE int jl_isbits(void *t) JL_NOTSAFEPOINT // corresponding to isbits() in julia @@ -1274,16 +1290,16 @@ STATIC_INLINE int jl_is_abstracttype(void *v) JL_NOTSAFEPOINT return (jl_is_datatype(v) && ((jl_datatype_t*)(v))->name->abstract); } -STATIC_INLINE int jl_is_array_type(void *t) JL_NOTSAFEPOINT +STATIC_INLINE int jl_is_array(void *v) JL_NOTSAFEPOINT { - return (jl_is_datatype(t) && - ((jl_datatype_t*)(t))->name == jl_array_typename); + jl_value_t *t = jl_typeof(v); + return jl_is_array_type(t); } -STATIC_INLINE int jl_is_array(void *v) JL_NOTSAFEPOINT +STATIC_INLINE int jl_is_arrayish(void *v) JL_NOTSAFEPOINT { jl_value_t *t = jl_typeof(v); - return jl_is_array_type(t); + return jl_is_arrayish_type(t); } @@ -1554,6 +1570,7 @@ JL_DLLEXPORT jl_value_t *jl_array_to_string(jl_array_t *a); JL_DLLEXPORT jl_array_t *jl_alloc_vec_any(size_t n); JL_DLLEXPORT jl_value_t *jl_arrayref(jl_array_t *a, size_t i); // 0-indexed JL_DLLEXPORT jl_value_t *jl_ptrarrayref(jl_array_t *a JL_PROPAGATES_ROOT, size_t i) JL_NOTSAFEPOINT; // 0-indexed +JL_DLLEXPORT jl_array_t *jl_array_copy(jl_array_t *ary); JL_DLLEXPORT void jl_arrayset(jl_array_t *a JL_ROOTING_ARGUMENT, jl_value_t *v JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED, size_t i); // 0-indexed JL_DLLEXPORT void jl_arrayunset(jl_array_t *a, size_t i); // 0-indexed JL_DLLEXPORT int jl_array_isassigned(jl_array_t *a, size_t i); // 0-indexed diff --git a/src/llvm-late-gc-lowering.cpp b/src/llvm-late-gc-lowering.cpp index 3586527668135..30f9ecd67bc3e 100644 --- a/src/llvm-late-gc-lowering.cpp +++ b/src/llvm-late-gc-lowering.cpp @@ -359,6 +359,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext { void RefineLiveSet(BitVector &LS, State &S, const std::vector &CalleeRoots); Value *EmitTagPtr(IRBuilder<> &builder, Type *T, Value *V); Value *EmitLoadTag(IRBuilder<> &builder, Value *V); + Value *EmitStoreTag(IRBuilder<> &builder, Value *V, Value *Typ); }; 
static unsigned getValueAddrSpace(Value *V) {
@@ -2184,6 +2185,16 @@ Value *LateLowerGCFrame::EmitLoadTag(IRBuilder<> &builder, Value *V)
 return load;
}
+Value *LateLowerGCFrame::EmitStoreTag(IRBuilder<> &builder, Value *V, Value *Typ)
+{
+ auto addr = EmitTagPtr(builder, T_size, V);
+ StoreInst *store = builder.CreateAlignedStore(Typ, addr, Align(sizeof(size_t)));
+ store->setOrdering(AtomicOrdering::Unordered);
+ store->setMetadata(LLVMContext::MD_tbaa, tbaa_tag);
+ return store;
+}
+
+
// Enable this optimization only on LLVM 4.0+ since this cause LLVM to optimize
// constant store loop to produce a `memset_pattern16` with a global variable
// that's initialized by `addrspacecast`. Such a global variable is not supported by the backend.
@@ -2370,6 +2381,19 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S) {
 typ->takeName(CI);
 CI->replaceAllUsesWith(typ);
 UpdatePtrNumbering(CI, typ, S);
+ } else if (mutating_arrayfreeze_func && callee == mutating_arrayfreeze_func) {
+ assert(CI->getNumArgOperands() == 2);
+ IRBuilder<> builder(CI);
+ builder.SetCurrentDebugLocation(CI->getDebugLoc());
+ auto array = CI->getArgOperand(0);
+ auto tag = EmitLoadTag(builder, array);
+ // preserve the low GC (mark) bits of the existing tag word
+ auto mark_bits = builder.CreateAnd(tag, ConstantInt::get(T_size, (uintptr_t)15));
+ auto new_typ = builder.CreateAddrSpaceCast(CI->getArgOperand(1),
+ T_pjlvalue);
+ auto new_typ_marked = builder.CreateOr(builder.CreatePtrToInt(new_typ, T_size), mark_bits);
+ // overwrite the type tag in place: the array is now typed as ImmutableArray
+ EmitStoreTag(builder, array, new_typ_marked);
+ CI->replaceAllUsesWith(array);
+ UpdatePtrNumbering(CI, array, S);
 } else if (write_barrier_func && callee == write_barrier_func) {
 // The replacement for this requires creating new BasicBlocks
 // which messes up the loop. Queue all of them to be replaced later.
diff --git a/src/llvm-pass-helpers.cpp b/src/llvm-pass-helpers.cpp
index 2821f9838a0a7..7ba412081a786 100644
--- a/src/llvm-pass-helpers.cpp
+++ b/src/llvm-pass-helpers.cpp
@@ -29,7 +29,8 @@ JuliaPassContext::JuliaPassContext()
 pgcstack_getter(nullptr), gc_flush_func(nullptr),
 gc_preserve_begin_func(nullptr), gc_preserve_end_func(nullptr),
 pointer_from_objref_func(nullptr), alloc_obj_func(nullptr),
- typeof_func(nullptr), write_barrier_func(nullptr), module(nullptr)
+ typeof_func(nullptr), mutating_arrayfreeze_func(nullptr),
+ write_barrier_func(nullptr), module(nullptr)
{
}
@@ -50,6 +51,7 @@ void JuliaPassContext::initFunctions(Module &M)
 gc_preserve_end_func = M.getFunction("llvm.julia.gc_preserve_end");
 pointer_from_objref_func = M.getFunction("julia.pointer_from_objref");
 typeof_func = M.getFunction("julia.typeof");
+ mutating_arrayfreeze_func = M.getFunction("julia.mutating_arrayfreeze");
 write_barrier_func = M.getFunction("julia.write_barrier");
 alloc_obj_func = M.getFunction("julia.gc_alloc_obj");
}
diff --git a/src/llvm-pass-helpers.h b/src/llvm-pass-helpers.h
index f80786d1e7149..9352d01e2fbe9 100644
--- a/src/llvm-pass-helpers.h
+++ b/src/llvm-pass-helpers.h
@@ -67,6 +67,7 @@ struct JuliaPassContext {
 llvm::Function *pointer_from_objref_func;
 llvm::Function *alloc_obj_func;
 llvm::Function *typeof_func;
+ llvm::Function *mutating_arrayfreeze_func;
 llvm::Function *write_barrier_func;
 // Creates a pass context.
Type and function pointers diff --git a/src/rtutils.c b/src/rtutils.c index b4432d8af3d0c..b79970ba47f47 100644 --- a/src/rtutils.c +++ b/src/rtutils.c @@ -1001,8 +1001,8 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt n += jl_printf(out, ")"); } } - else if (jl_array_type && jl_is_array_type(vt)) { - n += jl_printf(out, "Array{"); + else if (jl_array_type && jl_is_arrayish_type(vt)) { + n += jl_printf(out, jl_is_array_type(vt) ? "Array{" : "ImmutableArray{"); n += jl_static_show_x(out, (jl_value_t*)jl_tparam0(vt), depth); n += jl_printf(out, ", ("); size_t i, ndims = jl_array_ndims(v); diff --git a/src/staticdata.c b/src/staticdata.c index 28a21e9ea7c2b..5b6622b71038e 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -79,7 +79,7 @@ extern "C" { // TODO: put WeakRefs on the weak_refs list during deserialization // TODO: handle finalizers -#define NUM_TAGS 153 +#define NUM_TAGS 156 // An array of references that need to be restored from the sysimg // This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C. @@ -99,6 +99,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_slotnumber_type); INSERT_TAG(jl_simplevector_type); INSERT_TAG(jl_array_type); + INSERT_TAG(jl_immutable_array_type); INSERT_TAG(jl_typedslot_type); INSERT_TAG(jl_expr_type); INSERT_TAG(jl_globalref_type); @@ -184,6 +185,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_pointer_typename); INSERT_TAG(jl_llvmpointer_typename); INSERT_TAG(jl_array_typename); + INSERT_TAG(jl_immutable_array_typename); INSERT_TAG(jl_type_typename); INSERT_TAG(jl_namedtuple_typename); INSERT_TAG(jl_vecelement_typename); @@ -252,6 +254,9 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_builtin_ifelse); INSERT_TAG(jl_builtin__typebody); INSERT_TAG(jl_builtin_donotdelete); + INSERT_TAG(jl_builtin_arrayfreeze); + INSERT_TAG(jl_builtin_mutating_arrayfreeze); + INSERT_TAG(jl_builtin_arraythaw); // All optional tags must be placed at the end, so that we // don't accidentally have a `NULL` in the middle @@ -310,6 +315,7 @@ static const jl_fptr_args_t id_to_fptrs[] = { &jl_f_ifelse, &jl_f__structtype, &jl_f__abstracttype, &jl_f__primitivetype, &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f_get_binding_type, &jl_f_set_binding_type, &jl_f_opaque_closure_call, &jl_f_donotdelete, + &jl_f_arrayfreeze, &jl_f_mutating_arrayfreeze, &jl_f_arraythaw, NULL }; typedef struct { diff --git a/test/choosetests.jl b/test/choosetests.jl index f86f665bc2217..98eb5f6d70fbc 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -143,7 +143,8 @@ function choosetests(choices = []) filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses", "compiler/codegen", "compiler/inline", "compiler/contextual", - "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"]) + "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural", + "compiler/immutablearray"]) filtertests!(tests, "compiler/EscapeAnalysis", [ "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"]) filtertests!(tests, "stdlib", STDLIBS) diff --git a/test/compiler/immutablearray.jl b/test/compiler/immutablearray.jl new file mode 100644 index 0000000000000..474e5dfc0f657 --- /dev/null +++ b/test/compiler/immutablearray.jl @@ -0,0 +1,12 @@ +using Base.Experimental: ImmutableArray +function simple() + a = Vector{Float64}(undef, 5) + for i = 1:5 + a[i] = i + end + 
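# `a` does not escape this function, so the optimizer can turn the freeze below into `mutating_arrayfreeze` and avoid the copy; the `@allocated` test relies on this +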
ImmutableArray(a) +end +let + @allocated(simple()) + @test @allocated(simple()) < 100 +end From 572a6fea6ed19d2e61c08f858b3a3aedf966858d Mon Sep 17 00:00:00 2001 From: Ian Atol Date: Mon, 28 Feb 2022 13:48:44 -0500 Subject: [PATCH 4/5] Fixup with latest ImmutableArrays changes --- base/abstractarray.jl | 2 + base/array.jl | 55 ++- base/broadcast.jl | 13 + .../ssair/EscapeAnalysis/EscapeAnalysis.jl | 4 +- base/compiler/tfuncs.jl | 72 ++- base/compiler/types.jl | 2 +- base/exports.jl | 2 + base/indices.jl | 1 + base/pointer.jl | 1 + src/builtins.c | 6 +- src/codegen.cpp | 10 +- src/staticdata.c | 2 +- stdlib/LinearAlgebra/test/adjtrans.jl | 16 +- stdlib/LinearAlgebra/test/bidiag.jl | 16 +- stdlib/LinearAlgebra/test/diagonal.jl | 10 +- stdlib/LinearAlgebra/test/hessenberg.jl | 10 +- stdlib/LinearAlgebra/test/symmetric.jl | 14 +- stdlib/LinearAlgebra/test/triangular.jl | 14 +- stdlib/LinearAlgebra/test/tridiag.jl | 18 +- test/choosetests.jl | 2 +- test/compiler/immutablearray.jl | 427 +++++++++++++++++- test/compiler/inference.jl | 17 + test/immutablearray.jl | 183 ++++++++ test/testhelpers/ImmutableArrays.jl | 28 -- test/testhelpers/SimpleImmutableArrays.jl | 28 ++ 25 files changed, 815 insertions(+), 138 deletions(-) create mode 100644 test/immutablearray.jl delete mode 100644 test/testhelpers/ImmutableArrays.jl create mode 100644 test/testhelpers/SimpleImmutableArrays.jl diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 8bba670b96473..70f8e9875f1aa 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1085,6 +1085,8 @@ function copy(a::AbstractArray) copymutable(a) end +copy(a::Core.ImmutableArray) = a + function copyto!(B::AbstractVecOrMat{R}, ir_dest::AbstractRange{Int}, jr_dest::AbstractRange{Int}, A::AbstractVecOrMat{S}, ir_src::AbstractRange{Int}, jr_src::AbstractRange{Int}) where {R,S} if length(ir_dest) != length(ir_src) diff --git a/base/array.jl b/base/array.jl index dfcaed3725a1a..0dfd062791ac6 100644 --- a/base/array.jl +++ b/base/array.jl @@ -118,6 +118,36 @@ Union type of [`DenseVector{T}`](@ref) and [`DenseMatrix{T}`](@ref). """ const DenseVecOrMat{T} = Union{DenseVector{T}, DenseMatrix{T}} +""" + ImmutableArray{T,N} <: AbstractArray{T,N} +Dynamically allocated, immutable array. +""" +const ImmutableArray = Core.ImmutableArray + +""" + ImmutableVector{T} <: AbstractVector{T} +Dynamically allocated, immutable vector. +""" +const ImmutableVector{T} = ImmutableArray{T,1} + +""" + IMArray{T,N} +Union type of [`Array{T,N}`](@ref) and [`ImmutableArray{T,N}`](@ref) +""" +const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}} + +""" + IMVector{T} +One-dimensional [`ImmutableArray`](@ref) or [`Array`](@ref) with elements of type `T`. Alias for `IMArray{T, 1}`. +""" +const IMVector{T} = IMArray{T, 1} + +""" + IMMatrix{T} +Two-dimensional [`ImmutableArray`](@ref) or [`Array`](@ref) with elements of type `T`. Alias for `IMArray{T,2}`. +""" +const IMMatrix{T} = IMArray{T, 2} + ## Basic functions ## import Core: arraysize, arrayset, arrayref, const_arrayref @@ -147,18 +177,13 @@ function vect(X...) 
return copyto!(Vector{T}(undef, length(X)), X) end -const ImmutableArray = Core.ImmutableArray -const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}} -const IMVector{T} = IMArray{T, 1} -const IMMatrix{T} = IMArray{T, 2} - ImmutableArray(a::Array) = Core.arrayfreeze(a) Array(a::ImmutableArray) = Core.arraythaw(a) size(a::IMArray, d::Integer) = arraysize(a, convert(Int, d)) size(a::IMVector) = (arraysize(a,1),) size(a::IMMatrix) = (arraysize(a,1), arraysize(a,2)) -size(a::IMArray{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val(N))::Dims) +size(a::IMArray{<:Any,N}) where {N} = (@inline; ntuple(M -> size(a, M), Val(N))::Dims) asize_from(a::IMArray, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...) @@ -220,24 +245,17 @@ function bitsunionsize(u::Union) return sz end -length(a::Array) = arraylen(a) +length(a::IMArray) = arraylen(a) elsize(@nospecialize _::Type{A}) where {T,A<:Array{T}} = aligned_sizeof(T) -sizeof(a::Array) = Core.sizeof(a) +sizeof(a::IMArray) = Core.sizeof(a) -function isassigned(a::Array, i::Int...) +function isassigned(a::IMArray, i::Int...) @inline ii = (_sub2ind(size(a), i...) % UInt) - 1 @boundscheck ii < length(a) % UInt || return false ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1 end -function isassigned(a::ImmutableArray, i::Int...) - @_inline_meta - ii = (_sub2ind(size(a), i...) % UInt) - 1 - @boundscheck ii < length(a) % UInt || return false - ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1 -end - ## copy ## """ @@ -626,7 +644,7 @@ oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x) ## Conversions ## -convert(::Type{T}, a::AbstractArray) where {T<:Array} = a isa T ? a : T(a) +convert(T::Type{<:IMArray}, a::AbstractArray) = a isa T ? a : T(a) convert(::Type{Union{}}, a::AbstractArray) = throw(MethodError(convert, (Union{}, a))) promote_rule(a::Type{Array{T,n}}, b::Type{Array{S,n}}) where {T,n,S} = el_same(promote_type(T,S), a, b) @@ -637,6 +655,7 @@ if nameof(@__MODULE__) === :Base # avoid method overwrite # constructors should make copies Array{T,N}(x::AbstractArray{S,N}) where {T,N,S} = copyto_axcheck!(Array{T,N}(undef, size(x)), x) AbstractArray{T,N}(A::AbstractArray{S,N}) where {T,N,S} = copyto_axcheck!(similar(A,T), A) +ImmutableArray{T,N}(Ar::AbstractArray{S,N}) where {T,N,S} = Core.arrayfreeze(copyto_axcheck!(Array{T,N}(undef, size(Ar)), Ar)) end ## copying iterators to containers @@ -937,7 +956,7 @@ function getindex end @eval getindex(A::Array, i1::Int, i2::Int, I::Int...) = (@inline; arrayref($(Expr(:boundscheck)), A, i1, i2, I...)) @eval getindex(A::ImmutableArray, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1) -@eval getindex(A::ImmutableArray, i1::Int, i2::Int, I::Int...) = (@_inline_meta; arrayref($(Expr(:boundscheck)), A, i1, i2, I...)) +@eval getindex(A::ImmutableArray, i1::Int, i2::Int, I::Int...) = (@inline; arrayref($(Expr(:boundscheck)), A, i1, i2, I...)) # Faster contiguous indexing using copyto! 
for UnitRange and Colon
function getindex(A::Array, I::AbstractUnitRange{<:Integer})
 @inline
 @boundscheck checkbounds(A, I)
diff --git a/base/broadcast.jl b/base/broadcast.jl
index fb9ba9555cfd9..85b9057e8ceef 100644
--- a/base/broadcast.jl
+++ b/base/broadcast.jl
@@ -1345,4 +1345,17 @@ function Base.show(io::IO, op::BroadcastFunction)
 end
 Base.show(io::IO, ::MIME"text/plain", op::BroadcastFunction) = show(io, op)
+struct IMArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
+BroadcastStyle(::Type{<:Core.ImmutableArray}) = IMArrayStyle()
+
+# `similar` has to return a mutable array: `copy` below fills it via `copyto!`
+# and only then freezes the result
+function Base.similar(bc::Broadcasted{IMArrayStyle}, ::Type{ElType}) where ElType
+ similar(Array{ElType}, axes(bc))
+end
+
+@inline function copy(bc::Broadcasted{IMArrayStyle})
+ ElType = combine_eltypes(bc.f, bc.args)
+ return Core.ImmutableArray(copyto!(similar(bc, ElType), bc))
+end
+
 end # module
diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
index 0cb34e76c36bb..73c6c7b3d0b2d 100644
--- a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
+++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
@@ -1608,7 +1608,7 @@ function escape_builtin!(::typeof(arrayref), astate::AnalysisState, pc::Int, arg
 argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)]
 boundcheckt = argtypes[1]
 aryt = argtypes[2]
- if !array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 3)
+ if !array_builtin_common_typecheck(Array, boundcheckt, aryt, argtypes, 3)
 add_thrown_escapes!(astate, pc, args, 2)
 end
 ary = args[3]
@@ -1670,7 +1670,7 @@ function escape_builtin!(::typeof(arrayset), astate::AnalysisState, pc::Int, arg
 boundcheckt = argtypes[1]
 aryt = argtypes[2]
 valt = argtypes[3]
- if !(array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 4) &&
+ if !(array_builtin_common_typecheck(Array, boundcheckt, aryt, argtypes, 4) &&
 arrayset_typecheck(aryt, valt))
 add_thrown_escapes!(astate, pc, args, 2)
 end
diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl
index 71845fe13b503..08df099971516 100644
--- a/base/compiler/tfuncs.jl
+++ b/base/compiler/tfuncs.jl
@@ -461,8 +461,10 @@ add_tfunc(Core._typevar, 3, 3, typevar_tfunc, 100)
 add_tfunc(applicable, 1, INT_INF, (@nospecialize(f), args...)->Bool, 100)
 add_tfunc(Core.Intrinsics.arraylen, 1, 1, @nospecialize(x)->Int, 4)
+const Arrayish = Union{Array,ImmutableArray}
+
 function arraysize_tfunc(@nospecialize(ary), @nospecialize(dim))
- hasintersect(widenconst(ary), Array) || return Bottom
+ hasintersect(widenconst(ary), Arrayish) || return Bottom
 hasintersect(widenconst(dim), Int) || return Bottom
 return Int
 end
@@ -472,7 +474,7 @@ function arraysize_nothrow(argtypes::Vector{Any})
 length(argtypes) == 2 || return false
 ary = argtypes[1]
 dim = argtypes[2]
- ary ⊑ Array || return false
+ widenconst(ary) <: Arrayish || return false
 if isa(dim, Const)
 dimval = dim.val
 return isa(dimval, Int) && dimval > 0
@@ -1535,27 +1537,27 @@ function tuple_tfunc(argtypes::Vector{Any})
 end
 arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(ary), @nospecialize idxs...) =
- _arrayref_tfunc(boundscheck, ary, idxs)
-function _arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(ary),
- @nospecialize idxs::Tuple)
+ _arrayref_tfunc(Arrayish, boundscheck, ary, idxs)
+function _arrayref_tfunc(@nospecialize(Arytype),
+ @nospecialize(boundscheck), @nospecialize(ary), @nospecialize idxs::Tuple)
 isempty(idxs) && return Bottom
- array_builtin_common_errorcheck(boundscheck, ary, idxs) || return Bottom
- return array_elmtype(ary)
+ array_builtin_common_errorcheck(Arytype, boundscheck, ary, idxs) || return Bottom
+ return array_elmtype(Arytype, ary)
 end
 add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
 add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)
 function arrayset_tfunc(@nospecialize(boundscheck), @nospecialize(ary), @nospecialize(item),
 @nospecialize idxs...)
- hasintersect(widenconst(item), _arrayref_tfunc(boundscheck, ary, idxs)) || return Bottom
+ hasintersect(widenconst(item), _arrayref_tfunc(Array, boundscheck, ary, idxs)) || return Bottom
 return ary
 end
 add_tfunc(arrayset, 4, INT_INF, arrayset_tfunc, 20)
-function array_builtin_common_errorcheck(@nospecialize(boundscheck), @nospecialize(ary),
- @nospecialize idxs::Tuple)
+function array_builtin_common_errorcheck(@nospecialize(Arytype),
+ @nospecialize(boundscheck), @nospecialize(ary), @nospecialize idxs::Tuple)
 hasintersect(widenconst(boundscheck), Bool) || return false
- hasintersect(widenconst(ary), Array) || return false
+ hasintersect(widenconst(ary), Arytype) || return false
 for i = 1:length(idxs)
 idx = getfield(idxs, i)
 idx = isvarargtype(idx) ? unwrapva(idx) : widenconst(idx)
@@ -1564,9 +1566,9 @@ function array_builtin_common_errorcheck(@nospecialize(boundscheck), @nospeciali
 return true
 end
-function array_elmtype(@nospecialize ary)
+function array_elmtype(@nospecialize(Arytype), @nospecialize ary)
 a = widenconst(ary)
- if !has_free_typevars(a) && a <: Array
+ if !has_free_typevars(a) && a <: Arytype
 a0 = a
 if isa(a, UnionAll)
 a = unwrap_unionall(a0)
@@ -1580,6 +1582,32 @@ function array_elmtype(@nospecialize ary)
 return Any
 end
+# The ImmutableArray operations may involve copies, so their actual computation costs
+# can be high. Nevertheless we assign them smaller inlining costs here: the escape
+# analysis cannot yet propagate array escapes interprocedurally, so without inlining
+# it would fail to optimize most cases.
+
+arrayfreeze_tfunc(@nospecialize a) = immutable_array_tfunc(Array, ImmutableArray, a)
+add_tfunc(Core.arrayfreeze, 1, 1, arrayfreeze_tfunc, 20)
+
+mutating_arrayfreeze_tfunc(@nospecialize a) = immutable_array_tfunc(Array, ImmutableArray, a)
+add_tfunc(Core.mutating_arrayfreeze, 1, 1, mutating_arrayfreeze_tfunc, 10)
+
+arraythaw_tfunc(@nospecialize a) = immutable_array_tfunc(ImmutableArray, Array, a)
+add_tfunc(Core.arraythaw, 1, 1, arraythaw_tfunc, 20)
+
+function immutable_array_tfunc(@nospecialize(at), @nospecialize(rt), @nospecialize(a))
+ a = widenconst(a)
+ hasintersect(a, at) || return Bottom
+ if a <: at
+ unw = unwrap_unionall(a)
+ if isa(unw, DataType)
+ return rewrap_unionall(rt{unw.parameters[1], unw.parameters[2]}, a)
+ end
+ end
+ return rt
+end
+
 function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva), @nospecialize(lb),
 @nospecialize(ub), @nospecialize(source), env::Vector{Any},
 linfo::MethodInstance)
@@ -1613,11 +1641,12 @@ function array_type_undefable(@nospecialize(arytype))
 end
 end
-function array_builtin_common_nothrow(argtypes::Vector{Any}, first_idx_idx::Int)
+function array_builtin_common_nothrow(@nospecialize(Arytype),
+ argtypes::Vector{Any}, first_idx_idx::Int)
 length(argtypes) >= 4 || return false
 boundscheck = argtypes[1]
 arytype = argtypes[2]
- array_builtin_common_typecheck(boundscheck, arytype, argtypes, first_idx_idx) || return false
+ array_builtin_common_typecheck(Arytype, boundscheck, arytype, argtypes, first_idx_idx) || return false
 # If we could potentially throw undef ref errors, bail out now.
 arytype = widenconst(arytype)
 array_type_undefable(arytype) && return false
@@ -1632,12 +1661,12 @@ function array_builtin_common_nothrow(argtypes::Vector{Any}, first_idx_idx::Int)
 return false
 end
-function array_builtin_common_typecheck(
+function array_builtin_common_typecheck(@nospecialize(Arytype),
 @nospecialize(boundscheck), @nospecialize(arytype),
 argtypes::Vector{Any}, first_idx_idx::Int)
- (boundscheck ⊑ Bool && arytype ⊑ Array) || return false
+ (widenconst(boundscheck) <: Bool && widenconst(arytype) <: Arytype) || return false
 for i = first_idx_idx:length(argtypes)
- argtypes[i] ⊑ Int || return false
+ widenconst(argtypes[i]) <: Int || return false
 end
 return true
 end
@@ -1656,11 +1685,11 @@ end
 # Query whether the given builtin is guaranteed not to throw given the argtypes
 function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecialize(rt))
 if f === arrayset
- array_builtin_common_nothrow(argtypes, 4) || return true
+ array_builtin_common_nothrow(Array, argtypes, 4) || return true
 # Additionally check element type compatibility
 return arrayset_typecheck(argtypes[2], argtypes[3])
 elseif f === arrayref || f === const_arrayref
- return array_builtin_common_nothrow(argtypes, 3)
+ return array_builtin_common_nothrow(Arrayish, argtypes, 3)
 elseif f === arraysize
 return arraysize_nothrow(argtypes)
 elseif f === Core._expr
@@ -1818,8 +1847,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
 end
 end
 return rt
- end
- if isa(f, IntrinsicFunction)
+ elseif isa(f, IntrinsicFunction)
 if is_pure_intrinsic_infer(f) && _all(@nospecialize(a) -> isa(a, Const), argtypes)
 argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
 try
diff --git a/base/compiler/types.jl b/base/compiler/types.jl
index 1ef92ea65598e..b2ccb41071ad6 100644
--- a/base/compiler/types.jl
+++ b/base/compiler/types.jl
@@ -261,7 +261,7 @@ struct NativeInterpreter <: AbstractInterpreter
 end
 # If they didn't pass typemax(UInt) but passed something more subtly
- # incorrect, fail out loudly.
+ # incorrect, fail out loudly.
@assert world <= get_world_counter() diff --git a/base/exports.jl b/base/exports.jl index 2d790f16b7986..78b16ee8876c9 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -22,6 +22,7 @@ export AbstractVector, AbstractVecOrMat, Array, + ImmutableArray, AbstractMatch, AbstractPattern, AbstractDict, @@ -96,6 +97,7 @@ export Val, VecOrMat, Vector, + ImmutableVector, VersionNumber, WeakKeyDict, diff --git a/base/indices.jl b/base/indices.jl index 28028f23c72a3..0b7d7d7212940 100644 --- a/base/indices.jl +++ b/base/indices.jl @@ -95,6 +95,7 @@ IndexStyle(A::AbstractArray) = IndexStyle(typeof(A)) IndexStyle(::Type{Union{}}) = IndexLinear() IndexStyle(::Type{<:AbstractArray}) = IndexCartesian() IndexStyle(::Type{<:Array}) = IndexLinear() +IndexStyle(::Type{<:Core.ImmutableArray}) = IndexLinear() IndexStyle(::Type{<:AbstractRange}) = IndexLinear() IndexStyle(A::AbstractArray, B::AbstractArray) = IndexStyle(IndexStyle(A), IndexStyle(B)) diff --git a/base/pointer.jl b/base/pointer.jl index b9475724f7637..334d160bc92fa 100644 --- a/base/pointer.jl +++ b/base/pointer.jl @@ -63,6 +63,7 @@ cconvert(::Type{Ptr{UInt8}}, s::AbstractString) = String(s) cconvert(::Type{Ptr{Int8}}, s::AbstractString) = String(s) unsafe_convert(::Type{Ptr{T}}, a::Array{T}) where {T} = ccall(:jl_array_ptr, Ptr{T}, (Any,), a) +unsafe_convert(::Type{Ptr{T}}, a::Core.ImmutableArray{T}) where {T} = ccall(:jl_array_ptr, Ptr{T}, (Any,), a) unsafe_convert(::Type{Ptr{S}}, a::AbstractArray{T}) where {S,T} = convert(Ptr{S}, unsafe_convert(Ptr{T}, a)) unsafe_convert(::Type{Ptr{T}}, a::AbstractArray{T}) where {T} = error("conversion to pointer not defined for $(typeof(a))") diff --git a/src/builtins.c b/src/builtins.c index 0f3129f57293c..9538c0671863f 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1741,7 +1741,7 @@ JL_CALLABLE(jl_f_set_binding_type) JL_CALLABLE(jl_f_arrayfreeze) { - JL_NARGSV(arrayfreeze, 1); + JL_NARGS(arrayfreeze, 1, 1); JL_TYPECHK(arrayfreeze, array, args[0]); jl_array_t *a = (jl_array_t*)args[0]; jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type, @@ -1760,7 +1760,7 @@ JL_CALLABLE(jl_f_mutating_arrayfreeze) // N.B.: These error checks pretend to be arrayfreeze since this is a drop // in replacement and we don't want to change the visible error type in the // optimizer - JL_NARGSV(arrayfreeze, 1); + JL_NARGS(arrayfreeze, 1, 1); JL_TYPECHK(arrayfreeze, array, args[0]); jl_array_t *a = (jl_array_t*)args[0]; jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type, @@ -1771,7 +1771,7 @@ JL_CALLABLE(jl_f_mutating_arrayfreeze) JL_CALLABLE(jl_f_arraythaw) { - JL_NARGSV(arraythaw, 1); + JL_NARGS(arraythaw, 1, 1); if (((jl_datatype_t*)jl_typeof(args[0]))->name != jl_immutable_array_typename) { jl_type_error("arraythaw", (jl_value_t*)jl_immutable_array_type, args[0]); } diff --git a/src/codegen.cpp b/src/codegen.cpp index e748742b1ebb2..60132a5c2f9a7 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1056,7 +1056,9 @@ static const auto pointer_from_objref_func = new JuliaFunction{ }; static const auto mutating_arrayfreeze_func = new JuliaFunction{ "julia.mutating_arrayfreeze", - [](LLVMContext &C) { return FunctionType::get(T_prjlvalue, + [](LLVMContext &C) { + auto T_prjlvalue = JuliaType::get_prjlvalue_ty(C); + return FunctionType::get(T_prjlvalue, {T_prjlvalue, T_prjlvalue}, false); }, [](LLVMContext &C) { return AttributeList::get(C, Attributes(C, {Attribute::NoUnwind, Attribute::NoRecurse}), @@ -8201,9 +8203,9 @@ extern "C" void 
jl_init_llvm(void) { jl_f_arraysize_addr, new JuliaFunction{XSTR(jl_f_arraysize), get_func_sig, get_func_attrs} }, { jl_f_apply_type_addr, new JuliaFunction{XSTR(jl_f_apply_type), get_func_sig, get_func_attrs} }, { jl_f_donotdelete_addr, new JuliaFunction{XSTR(jl_f_donotdelete), get_func_sig, get_donotdelete_func_attrs} }, - { jl_f_arrayfreeze_addr, new JuliaFunction{XSTR(jl_f_arrayfreeze), get_func_sig, get_func_attrs} }, - { jl_f_arraythaw_addr, new JuliaFunction{XSTR(jl_f_arraythaw), get_func_sig, get_func_attrs} }, - { jl_f_mutating_arrayfreeze_addr, new JuliaFunction{XSTR(jl_f_mutating_arrayfreeze), get_func_sig, get_func_attrs} } + { jl_f_arrayfreeze_addr, new JuliaFunction{XSTR(jl_f_arrayfreeze), get_func_sig, get_func_attrs} }, + { jl_f_arraythaw_addr, new JuliaFunction{XSTR(jl_f_arraythaw), get_func_sig, get_func_attrs} }, + { jl_f_mutating_arrayfreeze_addr, new JuliaFunction{XSTR(jl_f_mutating_arrayfreeze), get_func_sig, get_func_attrs} }, }; jl_default_debug_info_kind = (int) DICompileUnit::DebugEmissionKind::FullDebug; diff --git a/src/staticdata.c b/src/staticdata.c index 5b6622b71038e..d0704349d01e2 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -79,7 +79,7 @@ extern "C" { // TODO: put WeakRefs on the weak_refs list during deserialization // TODO: handle finalizers -#define NUM_TAGS 156 +#define NUM_TAGS 158 // An array of references that need to be restored from the sysimg // This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C. diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 7b782d463768d..ae2946a68809a 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -241,22 +241,22 @@ end @test convert(Transpose{Float64,Matrix{Float64}}, Transpose(intmat))::Transpose{Float64,Matrix{Float64}} == Transpose(intmat) end -isdefined(Main, :ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "Adjoint and Transpose convert methods to AbstractArray" begin # tests corresponding to #34995 intvec, intmat = [1, 2], [1 2 3; 4 5 6] - statvec = ImmutableArray(intvec) - statmat = ImmutableArray(intmat) + statvec = SimpleImmutableArray(intvec) + statmat = SimpleImmutableArray(intmat) - @test convert(AbstractArray{Float64}, Adjoint(statvec))::Adjoint{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Adjoint(statvec) + @test convert(AbstractArray{Float64}, Adjoint(statvec))::Adjoint{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Adjoint(statvec) @test convert(AbstractArray{Float64}, Adjoint(statmat))::Array{Float64,2} == Adjoint(statmat) - @test convert(AbstractArray{Float64}, Transpose(statvec))::Transpose{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Transpose(statvec) + @test convert(AbstractArray{Float64}, Transpose(statvec))::Transpose{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Transpose(statvec) @test convert(AbstractArray{Float64}, Transpose(statmat))::Array{Float64,2} == Transpose(statmat) - @test convert(AbstractMatrix{Float64}, Adjoint(statvec))::Adjoint{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Adjoint(statvec) + @test convert(AbstractMatrix{Float64}, 
Adjoint(statvec))::Adjoint{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Adjoint(statvec) @test convert(AbstractMatrix{Float64}, Adjoint(statmat))::Array{Float64,2} == Adjoint(statmat) - @test convert(AbstractMatrix{Float64}, Transpose(statvec))::Transpose{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Transpose(statvec) + @test convert(AbstractMatrix{Float64}, Transpose(statvec))::Transpose{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Transpose(statvec) @test convert(AbstractMatrix{Float64}, Transpose(statmat))::Array{Float64,2} == Transpose(statmat) end diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 422984d84eb6b..1c88b14d5606c 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -740,20 +740,20 @@ end @test c \ A ≈ c \ Matrix(A) end -isdefined(Main, :ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "Conversion to AbstractArray" begin # tests corresponding to #34995 - dv = ImmutableArray([1, 2, 3, 4]) - ev = ImmutableArray([7, 8, 9]) + dv = SimpleImmutableArray([1, 2, 3, 4]) + ev = SimpleImmutableArray([7, 8, 9]) Bu = Bidiagonal(dv, ev, :U) Bl = Bidiagonal(dv, ev, :L) - @test convert(AbstractArray{Float64}, Bu)::Bidiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Bu - @test convert(AbstractMatrix{Float64}, Bu)::Bidiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Bu - @test convert(AbstractArray{Float64}, Bl)::Bidiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Bl - @test convert(AbstractMatrix{Float64}, Bl)::Bidiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Bl + @test convert(AbstractArray{Float64}, Bu)::Bidiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Bu + @test convert(AbstractMatrix{Float64}, Bu)::Bidiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Bu + @test convert(AbstractArray{Float64}, Bl)::Bidiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Bl + @test convert(AbstractMatrix{Float64}, Bl)::Bidiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Bl end @testset "block-bidiagonal matrix indexing" begin diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 6efed3b7d9cff..f7a40a69ec84d 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -938,16 +938,16 @@ end end const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") -isdefined(Main, :ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "Conversion to AbstractArray" begin # tests corresponding to #34995 - d = ImmutableArray([1, 2, 3, 4]) + d = SimpleImmutableArray([1, 2, 3, 4]) D = Diagonal(d) - @test convert(AbstractArray{Float64}, D)::Diagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == D - @test convert(AbstractMatrix{Float64}, D)::Diagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == D + @test convert(AbstractArray{Float64}, 
D)::Diagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == D + @test convert(AbstractMatrix{Float64}, D)::Diagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == D end @testset "divisions functionality" for elty in (Int, Float64, ComplexF64) diff --git a/stdlib/LinearAlgebra/test/hessenberg.jl b/stdlib/LinearAlgebra/test/hessenberg.jl index b2b23caac6865..04431e6f727b4 100644 --- a/stdlib/LinearAlgebra/test/hessenberg.jl +++ b/stdlib/LinearAlgebra/test/hessenberg.jl @@ -213,16 +213,16 @@ end end end -isdefined(Main, :ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "Conversion to AbstractArray" begin # tests corresponding to #34995 - A = ImmutableArray([1 2 3; 4 5 6; 7 8 9]) + A = SimpleImmutableArray([1 2 3; 4 5 6; 7 8 9]) H = UpperHessenberg(A) - @test convert(AbstractArray{Float64}, H)::UpperHessenberg{Float64,ImmutableArray{Float64,2,Array{Float64,2}}} == H - @test convert(AbstractMatrix{Float64}, H)::UpperHessenberg{Float64,ImmutableArray{Float64,2,Array{Float64,2}}} == H + @test convert(AbstractArray{Float64}, H)::UpperHessenberg{Float64,SimpleImmutableArray{Float64,2,Array{Float64,2}}} == H + @test convert(AbstractMatrix{Float64}, H)::UpperHessenberg{Float64,SimpleImmutableArray{Float64,2,Array{Float64,2}}} == H end end # module TestHessenberg diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 47a36df5e7883..55af4fe456ff2 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -544,19 +544,19 @@ end end const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") -isdefined(Main, :ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "Conversion to AbstractArray" begin # tests corresponding to #34995 - immutablemat = ImmutableArray([1 2 3; 4 5 6; 7 8 9]) + immutablemat = SimpleImmutableArray([1 2 3; 4 5 6; 7 8 9]) for SymType in (Symmetric, Hermitian) S = Float64 symmat = SymType(immutablemat) - @test convert(AbstractArray{S}, symmat).data isa ImmutableArray{S} - @test convert(AbstractMatrix{S}, symmat).data isa ImmutableArray{S} - @test AbstractArray{S}(symmat).data isa ImmutableArray{S} - @test AbstractMatrix{S}(symmat).data isa ImmutableArray{S} + @test convert(AbstractArray{S}, symmat).data isa SimpleImmutableArray{S} + @test convert(AbstractMatrix{S}, symmat).data isa SimpleImmutableArray{S} + @test AbstractArray{S}(symmat).data isa SimpleImmutableArray{S} + @test AbstractMatrix{S}(symmat).data isa SimpleImmutableArray{S} @test convert(AbstractArray{S}, symmat) == symmat @test convert(AbstractMatrix{S}, symmat) == symmat end diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index dfb4d7c8a0b95..7c0c11e0369d5 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -690,20 +690,20 @@ let A = UpperTriangular([Furlong(1) Furlong(4); Furlong(0) Furlong(1)]) @test sqrt(A) == Furlong{1//2}.(UpperTriangular([1 2; 0 1])) end -isdefined(Main, 
:ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "AbstractArray constructor should preserve underlying storage type" begin # tests corresponding to #34995 local m = 4 local T, S = Float32, Float64 - immutablemat = ImmutableArray(randn(T,m,m)) + immutablemat = SimpleImmutableArray(randn(T,m,m)) for TriType in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) trimat = TriType(immutablemat) - @test convert(AbstractArray{S}, trimat).data isa ImmutableArray{S} - @test convert(AbstractMatrix{S}, trimat).data isa ImmutableArray{S} - @test AbstractArray{S}(trimat).data isa ImmutableArray{S} - @test AbstractMatrix{S}(trimat).data isa ImmutableArray{S} + @test convert(AbstractArray{S}, trimat).data isa SimpleImmutableArray{S} + @test convert(AbstractMatrix{S}, trimat).data isa SimpleImmutableArray{S} + @test AbstractArray{S}(trimat).data isa SimpleImmutableArray{S} + @test AbstractMatrix{S}(trimat).data isa SimpleImmutableArray{S} @test convert(AbstractArray{S}, trimat) == trimat @test convert(AbstractMatrix{S}, trimat) == trimat end diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index ecdf6b416baa5..fa6ce93fc296c 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -667,21 +667,21 @@ end @test ishermitian(S) end -isdefined(Main, :ImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "ImmutableArrays.jl")) -using .Main.ImmutableArrays +isdefined(Main, :SimpleImmutableArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SimpleImmutableArrays.jl")) +using .Main.SimpleImmutableArrays @testset "Conversion to AbstractArray" begin # tests corresponding to #34995 - v1 = ImmutableArray([1, 2]) - v2 = ImmutableArray([3, 4, 5]) - v3 = ImmutableArray([6, 7]) + v1 = SimpleImmutableArray([1, 2]) + v2 = SimpleImmutableArray([3, 4, 5]) + v3 = SimpleImmutableArray([6, 7]) T = Tridiagonal(v1, v2, v3) Tsym = SymTridiagonal(v2, v1) - @test convert(AbstractArray{Float64}, T)::Tridiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == T - @test convert(AbstractMatrix{Float64}, T)::Tridiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == T - @test convert(AbstractArray{Float64}, Tsym)::SymTridiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Tsym - @test convert(AbstractMatrix{Float64}, Tsym)::SymTridiagonal{Float64,ImmutableArray{Float64,1,Array{Float64,1}}} == Tsym + @test convert(AbstractArray{Float64}, T)::Tridiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == T + @test convert(AbstractMatrix{Float64}, T)::Tridiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == T + @test convert(AbstractArray{Float64}, Tsym)::SymTridiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Tsym + @test convert(AbstractMatrix{Float64}, Tsym)::SymTridiagonal{Float64,SimpleImmutableArray{Float64,1,Array{Float64,1}}} == Tsym end @testset "dot(x,A,y) for A::Tridiagonal or SymTridiagonal" begin diff --git a/test/choosetests.jl b/test/choosetests.jl index 98eb5f6d70fbc..54635c1edefb1 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -23,7 +23,7 @@ const TESTNAMES = [ "errorshow", "sets", "goto", "llvmcall", "llvmcall2", "ryu", "some", 
"meta", "stacktraces", "docs", "misc", "threads", "stress", "binaryplatforms", "atexit", - "enums", "cmdlineargs", "int", "interpreter", + "enums", "cmdlineargs", "immutablearray", "int", "interpreter", "checked", "bitset", "floatfuncs", "precompile", "boundscheck", "error", "ambiguous", "cartesian", "osutils", "channels", "iostream", "secretbuffer", "specificity", diff --git a/test/compiler/immutablearray.jl b/test/compiler/immutablearray.jl index 474e5dfc0f657..1fb8596e707b0 100644 --- a/test/compiler/immutablearray.jl +++ b/test/compiler/immutablearray.jl @@ -1,12 +1,421 @@ -using Base.Experimental: ImmutableArray -function simple() - a = Vector{Float64}(undef, 5) - for i = 1:5 - a[i] = i +import Core: arrayfreeze, mutating_arrayfreeze, arraythaw +import Core.Compiler: arrayfreeze_tfunc, mutating_arrayfreeze_tfunc, arraythaw_tfunc + +@testset "ImmutableArray tfuncs" begin + @test arrayfreeze_tfunc(Vector{Int}) === ImmutableVector{Int} + @test arrayfreeze_tfunc(Vector) === ImmutableVector + @test arrayfreeze_tfunc(Array) === ImmutableArray + @test arrayfreeze_tfunc(Any) === ImmutableArray + @test arrayfreeze_tfunc(ImmutableVector{Int}) === Union{} + @test arrayfreeze_tfunc(ImmutableVector) === Union{} + @test arrayfreeze_tfunc(ImmutableArray) === Union{} + @test mutating_arrayfreeze_tfunc(Vector{Int}) === ImmutableVector{Int} + @test mutating_arrayfreeze_tfunc(Vector) === ImmutableVector + @test mutating_arrayfreeze_tfunc(Array) === ImmutableArray + @test mutating_arrayfreeze_tfunc(Any) === ImmutableArray + @test mutating_arrayfreeze_tfunc(ImmutableVector{Int}) === Union{} + @test mutating_arrayfreeze_tfunc(ImmutableVector) === Union{} + @test mutating_arrayfreeze_tfunc(ImmutableArray) === Union{} + @test arraythaw_tfunc(ImmutableVector{Int}) === Vector{Int} + @test arraythaw_tfunc(ImmutableVector) === Vector + @test arraythaw_tfunc(ImmutableArray) === Array + @test arraythaw_tfunc(Any) === Array + @test arraythaw_tfunc(Vector{Int}) === Union{} + @test arraythaw_tfunc(Vector) === Union{} + @test arraythaw_tfunc(Array) === Union{} +end + +# mutating_arrayfreeze optimization +# ================================= + +import Core.Compiler: argextype, singleton_type +const EMPTY_SPTYPES = Any[] + +code_typed1(args...; kwargs...) 
= first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
+
+# check if `x` is a `:new` expression
+isnew(@nospecialize x) = Meta.isexpr(x, :new)
+
+# check if `x` is a dynamic call of a given function
+iscall(y) = @nospecialize(x) -> iscall(y, x)
+function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x))
+    return iscall(x) do @nospecialize x
+        singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
+    end
+end
+iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
+
+# check if `x` is a statically-resolved call of a function whose name is `sym`
+isinvoke(y) = @nospecialize(x) -> isinvoke(y, x)
+isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x)
+isinvoke(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :invoke) && pred(x.args[1]::Core.MethodInstance)
+
+function is_array_alloc(@nospecialize x)
+    Meta.isexpr(x, :foreigncall) || return false
+    args = x.args
+    name = args[1]
+    isa(name, QuoteNode) && (name = name.value)
+    isa(name, Symbol) || return false
+    return Core.Compiler.alloc_array_ndims(name) !== nothing
+end
+
+# optimizable examples
+# --------------------
+
+let # simplest -- vector
+    function optimizable(gen)
+        a = [1,2,3,4,5]
+        return gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity)
+        allocated = @allocated optimizable(identity)
+        optimizable(ImmutableArray)
+        @test allocated == @allocated optimizable(ImmutableArray)
+    end
+end
+
+let # handle matrix etc. (actually this example also requires inter-procedural escape handling)
+    function optimizable(gen)
+        a = [1 2 3; 4 5 6]
+        b = [1 2 3 4 5 6]
+        return gen(a), gen(b)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},))
+        # @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 2
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity)
+        allocated = @allocated optimizable(identity)
+        optimizable(ImmutableArray)
+        @test allocated == @allocated optimizable(ImmutableArray)
+    end
+end
+
+let # multiple returns don't matter
+    function optimizable(gen)
+        a = [1,2,3,4,5]
+        return gen(a), gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 2
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity)
+        allocated = @allocated optimizable(identity)
+        optimizable(ImmutableArray)
+        @test allocated == @allocated optimizable(ImmutableArray)
+    end
+end
+
+let # arrayset
+    function optimizable1(gen)
+        a = Vector{Int}(undef, 5)
+        for i = 1:5
+            a[i] = i
+        end
+        return gen(a)
+    end
+    let src = code_typed1(optimizable1, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable1(identity)
+        allocated = @allocated optimizable1(identity)
+        optimizable1(ImmutableArray)
+        @test allocated == @allocated optimizable1(ImmutableArray)
+    end
+
+    function optimizable2(gen)
+        a = Matrix{Float64}(undef, 5, 2)
+        for i = 1:5
+            for j = 1:2
+                a[i, j] = i + j
+            end
+        end
+        return gen(a)
+    end
+    let src = code_typed1(optimizable2, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable2(identity)
+        allocated = @allocated optimizable2(identity)
+        optimizable2(ImmutableArray)
+        @test allocated == @allocated optimizable2(ImmutableArray)
+    end
+end
+
+let # arrayref
+    function optimizable(gen)
+        a = [1,2,3]
+        b = getindex(a, 2)
+        return gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity)
+        allocated = @allocated optimizable(identity)
+        optimizable(ImmutableArray)
+        @test allocated == @allocated optimizable(ImmutableArray)
+    end
+end
+
+let # array resize
+    function optimizable(gen, n)
+        a = Int[]
+        for i = 1:n
+            push!(a, i)
+        end
+        return gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},Int,))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity, 42)
+        allocated = @allocated optimizable(identity, 42)
+        optimizable(ImmutableArray, 42)
+        @test allocated == @allocated optimizable(ImmutableArray, 42)
+    end
+end
+
+@noinline function same′(a)
+    return reverse(reverse(a))
+end
+let # inter-procedural
+    function optimizable(gen)
+        a = ones(5)
+        a = same′(a)
+        return gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(isinvoke(:same′), src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity)
+        allocated = @allocated optimizable(identity)
+        optimizable(ImmutableArray)
+        @test allocated == @allocated optimizable(ImmutableArray)
+    end
+end
+
+let # ignore ThrownEscape if it never happens when `arrayfreeze` is called
+    function optimizable(gen, n)
+        a = Int[]
+        for i = 1:n
+            push!(a, i)
+        end
+        n > 100 && throw(a)
+        return gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},Int,))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity, 42)
+        allocated = @allocated optimizable(identity, 42)
+        optimizable(ImmutableArray, 42)
+        @test allocated == @allocated optimizable(ImmutableArray, 42)
     end
-    ImmutableArray(a)
 end
-let
-    @allocated(simple())
-    @test @allocated(simple()) < 100
+@noinline function ipo_getindex′(a, n)
+    ele = getindex(a, n)
+    return ele
+end
+let # ignore ThrownEscape if it never happens when `arrayfreeze` is called (interprocedural)
+    function optimizable(gen)
+        a = [1,2,3]
+        b = ipo_getindex′(a, 2)
+        return gen(a)
+    end
+    let src = code_typed1(optimizable, (Type{ImmutableArray},))
+        @test count(is_array_alloc, src.code) == 1
+        @test count(isinvoke(:ipo_getindex′), src.code) == 1
+        @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1
+        @test count(iscall((src, arrayfreeze)), src.code) == 0
+        optimizable(identity)
+        allocated = @allocated optimizable(identity)
+        optimizable(ImmutableArray)
+        @test allocated == @allocated optimizable(ImmutableArray)
+    end
+end
+
+let # nested case
+    function optimizable(gen, n)
+        a = [collect(1:m) for m in 1:n]
+        for i = 1:n
+            a[i][1] = i
+        end
+
return gen(a) + end + let src = code_typed1(optimizable, (Type{ImmutableArray},Int)) + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1 + @test count(iscall((src, arrayfreeze)), src.code) == 0 + optimizable(identity, 100) + allocated = @allocated optimizable(identity, 100) + optimizable(ImmutableArray, 100) + @test allocated == @allocated optimizable(ImmutableArray, 100) + end +end + +# demonstrate alias analysis +broadcast_identity(a) = broadcast(identity, a) +function optimizable_aa(gen, n) # can't be a closure somehow + return collect(1:n) |> + Ref |> Ref |> Ref |> + broadcast_identity |> broadcast_identity |> broadcast_identity |> + gen +end +let src = code_typed1(optimizable_aa, (Type{ImmutableArray},Int)) + @test count(is_array_alloc, src.code) == 1 + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 1 + @test count(iscall((src, arrayfreeze)), src.code) == 0 + optimizable_aa(identity, 100) + allocated = @allocated optimizable_aa(identity, 100) + optimizable_aa(ImmutableArray, 100) + @test allocated == @allocated optimizable_aa(ImmutableArray, 100) +end + +let # should be possible if we change BoundsError semantics (so that it doesn't capture the indexed array) + function optimizable(gen) + a = [1,2,3] + try + getindex(a, 4) + catch + end + return gen(a) + end + let src = code_typed1(optimizable, (Type{ImmutableArray},)) + @test count(is_array_alloc, src.code) == 1 + @test_broken count(iscall((src, mutating_arrayfreeze)), src.code) == 1 + @test_broken count(iscall((src, arrayfreeze)), src.code) == 0 + optimizable(identity) + allocated = @allocated optimizable(identity) + optimizable(ImmutableArray) + @test_broken allocated == @allocated optimizable(ImmutableArray) + end +end + +# unoptimizable examples +# ---------------------- + +const Rx = Ref{Any}() # global memory + +let # return escape + function unoptimizable(gen) + a = [1,2,3,4,5] + return a, gen(a) + end + let src = code_typed1(unoptimizable, (Type{ImmutableArray},)) + @test count(is_array_alloc, src.code) == 1 + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 0 + @test count(iscall((src, arrayfreeze)), src.code) == 1 + unoptimizable(ImmutableArray) + a, b = unoptimizable(ImmutableArray) + @test a !== b + @test !(a isa ImmutableArray) + end +end + +let # arg escape + unoptimizable(a, gen) = gen(a) + let src = code_typed1(unoptimizable, (Vector{Int}, Type{ImmutableArray},)) + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 0 + @test count(iscall((src, arrayfreeze)), src.code) == 1 + a = [1,2,3] + unoptimizable(a, ImmutableArray) + b = unoptimizable(a, ImmutableArray) + @test a !== b + @test !(a isa ImmutableArray) + @test b isa ImmutableArray + end +end + +let # global escape + function unoptimizable(gen) + a = [1,2,3,4,5] + global global_array = a + return gen(a) + end + let src = code_typed1(unoptimizable, (Type{ImmutableArray},)) + @test count(is_array_alloc, src.code) == 1 + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 0 + @test count(iscall((src, arrayfreeze)), src.code) == 1 + unoptimizable(identity) + unoptimizable(ImmutableArray) + a = unoptimizable(ImmutableArray) + @test global_array !== a + @test !(global_array isa ImmutableArray) + end +end + +let # global escape + function unoptimizable(gen) + a = [1,2,3,4,5] + Rx[] = a + return gen(a) + end + let src = code_typed1(unoptimizable, (Type{ImmutableArray},)) + @test count(is_array_alloc, src.code) == 1 + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 0 + @test count(iscall((src, 
arrayfreeze)), src.code) == 1 + unoptimizable(identity) + unoptimizable(ImmutableArray) + a = unoptimizable(ImmutableArray) + @test Rx[] !== a + @test !(Rx[] isa ImmutableArray) + end +end + +let # escapes via exception + function unoptimizable(gen) + a = [1,2,3,4,5] + try + throw(a) + catch err + global global_array = err + end + return gen(a) + end + let src = code_typed1(unoptimizable, (Type{ImmutableArray},)) + @test count(is_array_alloc, src.code) == 1 + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 0 + @test count(iscall((src, arrayfreeze)), src.code) == 1 + unoptimizable(identity) + allocated = @allocated unoptimizable(identity) + unoptimizable(ImmutableArray) + local a + @test allocated < @allocated a = unoptimizable(ImmutableArray) + @test global_array !== a + @test !(global_array isa ImmutableArray) + end +end + +const g = Ref{Any}() +let # escapes via BoundsError + function unoptimizable(gen) + a = [1,2,3] + try + getindex(a, 4) + catch e + g[] = e.a + end + return gen(a) + end + let src = code_typed1(unoptimizable, (Type{ImmutableArray},)) + @test count(is_array_alloc, src.code) == 1 + @test count(iscall((src, arrayfreeze)), src.code) == 1 + @test count(iscall((src, mutating_arrayfreeze)), src.code) == 0 + unoptimizable(identity) + unoptimizable(ImmutableArray) + ia = unoptimizable(ImmutableArray) + @test g[] !== ia + end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 79030bd910990..0fa4393f60c6a 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1542,9 +1542,17 @@ import Core.Compiler: Const, arrayref_tfunc, arrayset_tfunc, arraysize_tfunc @test arrayref_tfunc(Const(true), Vector{Int}, Int, Vararg{Int}) === Int @test arrayref_tfunc(Const(true), Vector{Int}, Vararg{Int}) === Int @test arrayref_tfunc(Const(true), Vector{Int}) === Union{} +@test arrayref_tfunc(Const(true), Core.ImmutableArray{Int,1}, Int) === Int +@test arrayref_tfunc(Const(true), Core.ImmutableArray{<:Integer,1}, Int) === Integer +@test arrayref_tfunc(Const(true), Core.ImmutableArray, Int) === Any +@test arrayref_tfunc(Const(true), Core.ImmutableArray{Int,1}, Int, Vararg{Int}) === Int +@test arrayref_tfunc(Const(true), Core.ImmutableArray{Int,1}, Vararg{Int}) === Int +@test arrayref_tfunc(Const(true), Core.ImmutableArray{Int,1}) === Union{} @test arrayref_tfunc(Const(true), String, Int) === Union{} @test arrayref_tfunc(Const(true), Vector{Int}, Float64) === Union{} @test arrayref_tfunc(Int, Vector{Int}, Int) === Union{} +@test arrayref_tfunc(Const(true), Core.ImmutableArray{Int,1}, Float64) === Union{} +@test arrayref_tfunc(Int, Core.ImmutableArray{Int,1}, Int) === Union{} @test arrayset_tfunc(Const(true), Vector{Int}, Int, Int) === Vector{Int} let ua = Vector{<:Integer} @test arrayset_tfunc(Const(true), ua, Int, Int) === ua @@ -1553,6 +1561,13 @@ end @test arrayset_tfunc(Const(true), Any, Int, Int) === Any @test arrayset_tfunc(Const(true), Vector{String}, String, Int, Vararg{Int}) === Vector{String} @test arrayset_tfunc(Const(true), Vector{String}, String, Vararg{Int}) === Vector{String} +@test arrayset_tfunc(Const(true), Core.ImmutableArray{Int,1}, Int, Int) === Union{} +let ua = Core.ImmutableArray{<:Integer,1} + @test arrayset_tfunc(Const(true), ua, Int, Int) === Union{} +end +@test arrayset_tfunc(Const(true), Core.ImmutableArray, Int, Int) === Union{} +@test arrayset_tfunc(Const(true), Core.ImmutableArray{String,1}, String, Int, Vararg{Int}) === Union{} +@test arrayset_tfunc(Const(true), Core.ImmutableArray{String,1}, String, 
Vararg{Int}) === Union{}
 @test arrayset_tfunc(Const(true), Vector{String}, String) === Union{}
 @test arrayset_tfunc(Const(true), String, Char, Int) === Union{}
 @test arrayset_tfunc(Const(true), Vector{Int}, Int, Float64) === Union{}
@@ -1560,6 +1575,8 @@ end
 @test arrayset_tfunc(Const(true), Vector{Int}, Float64, Int) === Union{}
 @test arraysize_tfunc(Vector, Int) === Int
 @test arraysize_tfunc(Vector, Float64) === Union{}
+@test arraysize_tfunc(Core.ImmutableArray, Int) === Int
+@test arraysize_tfunc(Core.ImmutableArray, Float64) === Union{}
 @test arraysize_tfunc(String, Int) === Union{}
 
 function f23024(::Type{T}, ::Int) where T
diff --git a/test/immutablearray.jl b/test/immutablearray.jl
new file mode 100644
index 0000000000000..f56d625900769
--- /dev/null
+++ b/test/immutablearray.jl
@@ -0,0 +1,183 @@
+
+using Test
+import Core: arrayfreeze, mutating_arrayfreeze, arraythaw
+
+@testset "basic ImmutableArray functionality" begin
+    eltypes = (Float16, Float32, Float64, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128)
+    for t in eltypes
+        a = rand(t, rand(1:100), rand(1:10))
+        b = ImmutableArray(a)
+        @test a == b
+        @test a !== b
+        @test length(a) == length(b)
+        for i in 1:length(a)
+            @test getindex(a, i) == getindex(b, i)
+        end
+        @test size(a) == size(b)
+        if !(t in (Float16, Float32, Float64))
+            @test sum(a) == sum(b)
+        end
+        @test reverse(a) == reverse(b)
+        @test ndims(a) == ndims(b)
+        for d in 1:ndims(a)
+            @test axes(a, d) == axes(b, d)
+        end
+        @test strides(a) == strides(b)
+        @test keys(a) == keys(b)
+        @test IndexStyle(a) == IndexStyle(b)
+        @test eachindex(a) == eachindex(b)
+        @test isempty(a) == isempty(b)
+        # Check that broadcast precedence is working correctly
+        @test typeof(a .+ b) <: ImmutableArray
+    end
+
+end
+
+@testset "ImmutableArray builtins" begin
+    a = [1,2,3]
+    b = ImmutableArray(a)
+    # errors
+    @test_throws ArgumentError arrayfreeze()
+    @test_throws ArgumentError arrayfreeze([1,2,3], nothing)
+    @test_throws TypeError arrayfreeze(b)
+    @test_throws TypeError arrayfreeze("not an array")
+    @test_throws ArgumentError mutating_arrayfreeze()
+    @test_throws ArgumentError mutating_arrayfreeze([1,2,3], nothing)
+    @test_throws TypeError mutating_arrayfreeze(b)
+    @test_throws TypeError mutating_arrayfreeze("not an array")
+    @test_throws ArgumentError arraythaw()
+    @test_throws ArgumentError arraythaw([1,2,3], nothing)
+    @test_throws TypeError arraythaw(a)
+    @test_throws TypeError arraythaw("not an array")
+
+    @test arrayfreeze(a) === b
+    @test arraythaw(b) !== a # arraythaw copies so not ===
+    @test arraythaw(arrayfreeze(a)) == a
+    @test arraythaw(arrayfreeze(a)) !== a
+    @test arrayfreeze(arraythaw(b)) === b
+    @test arraythaw(arrayfreeze(arraythaw(b))) == b
+    @test arraythaw(arrayfreeze(arraythaw(b))) !== b
+
+    mutating_arrayfreeze(a) # last because this mutates a
+    @test isa(a, ImmutableArray)
+    @test a === b
+    @test arraythaw(a) !== a
+    @test !isa(arraythaw(a), ImmutableArray)
+end
+
+A = ImmutableArray(rand(5,4,3))
+@testset "Bounds checking" begin
+    @test checkbounds(Bool, A, 1, 1, 1) == true
+    @test checkbounds(Bool, A, 5, 4, 3) == true
+    @test checkbounds(Bool, A, 0, 1, 1) == false
+    @test checkbounds(Bool, A, 1, 0, 1) == false
+    @test checkbounds(Bool, A, 1, 1, 0) == false
+    @test checkbounds(Bool, A, 6, 4, 3) == false
+    @test checkbounds(Bool, A, 5, 5, 3) == false
+    @test checkbounds(Bool, A, 5, 4, 4) == false
+    @test checkbounds(Bool, A, 1) == true # linear indexing
+    @test checkbounds(Bool, A, 60) == true
+    @test checkbounds(Bool, A, 61) == false
+    @test
checkbounds(Bool, A, 2, 2, 2, 1) == true # extra indices + @test checkbounds(Bool, A, 2, 2, 2, 2) == false + @test checkbounds(Bool, A, 1, 1) == false + @test checkbounds(Bool, A, 1, 12) == false + @test checkbounds(Bool, A, 5, 12) == false + @test checkbounds(Bool, A, 1, 13) == false + @test checkbounds(Bool, A, 6, 12) == false +end + +@testset "single CartesianIndex" begin + @test checkbounds(Bool, A, CartesianIndex((1, 1, 1))) == true + @test checkbounds(Bool, A, CartesianIndex((5, 4, 3))) == true + @test checkbounds(Bool, A, CartesianIndex((0, 1, 1))) == false + @test checkbounds(Bool, A, CartesianIndex((1, 0, 1))) == false + @test checkbounds(Bool, A, CartesianIndex((1, 1, 0))) == false + @test checkbounds(Bool, A, CartesianIndex((6, 4, 3))) == false + @test checkbounds(Bool, A, CartesianIndex((5, 5, 3))) == false + @test checkbounds(Bool, A, CartesianIndex((5, 4, 4))) == false + @test checkbounds(Bool, A, CartesianIndex((1,))) == false + @test checkbounds(Bool, A, CartesianIndex((60,))) == false + @test checkbounds(Bool, A, CartesianIndex((61,))) == false + @test checkbounds(Bool, A, CartesianIndex((2, 2, 2, 1,))) == true + @test checkbounds(Bool, A, CartesianIndex((2, 2, 2, 2,))) == false + @test checkbounds(Bool, A, CartesianIndex((1, 1,))) == false + @test checkbounds(Bool, A, CartesianIndex((1, 12,))) == false + @test checkbounds(Bool, A, CartesianIndex((5, 12,))) == false + @test checkbounds(Bool, A, CartesianIndex((1, 13,))) == false + @test checkbounds(Bool, A, CartesianIndex((6, 12,))) == false +end + +@testset "mix of CartesianIndex and Int" begin + @test checkbounds(Bool, A, CartesianIndex((1,)), 1, CartesianIndex((1,))) == true + @test checkbounds(Bool, A, CartesianIndex((5, 4)), 3) == true + @test checkbounds(Bool, A, CartesianIndex((0, 1)), 1) == false + @test checkbounds(Bool, A, 1, CartesianIndex((0, 1))) == false + @test checkbounds(Bool, A, 1, 1, CartesianIndex((0,))) == false + @test checkbounds(Bool, A, 6, CartesianIndex((4, 3))) == false + @test checkbounds(Bool, A, 5, CartesianIndex((5,)), 3) == false + @test checkbounds(Bool, A, CartesianIndex((5,)), CartesianIndex((4,)), CartesianIndex((4,))) == false +end + +@testset "vector indices" begin + @test checkbounds(Bool, A, 1:5, 1:4, 1:3) == true + @test checkbounds(Bool, A, 0:5, 1:4, 1:3) == false + @test checkbounds(Bool, A, 1:5, 0:4, 1:3) == false + @test checkbounds(Bool, A, 1:5, 1:4, 0:3) == false + @test checkbounds(Bool, A, 1:6, 1:4, 1:3) == false + @test checkbounds(Bool, A, 1:5, 1:5, 1:3) == false + @test checkbounds(Bool, A, 1:5, 1:4, 1:4) == false + @test checkbounds(Bool, A, 1:60) == true + @test checkbounds(Bool, A, 1:61) == false + @test checkbounds(Bool, A, 2, 2, 2, 1:1) == true # extra indices + @test checkbounds(Bool, A, 2, 2, 2, 1:2) == false + @test checkbounds(Bool, A, 1:5, 1:4) == false + @test checkbounds(Bool, A, 1:5, 1:12) == false + @test checkbounds(Bool, A, 1:5, 1:13) == false + @test checkbounds(Bool, A, 1:6, 1:12) == false +end + +@testset "logical" begin + @test checkbounds(Bool, A, trues(5), trues(4), trues(3)) == true + @test checkbounds(Bool, A, trues(6), trues(4), trues(3)) == false + @test checkbounds(Bool, A, trues(5), trues(5), trues(3)) == false + @test checkbounds(Bool, A, trues(5), trues(4), trues(4)) == false + @test checkbounds(Bool, A, trues(60)) == true + @test checkbounds(Bool, A, trues(61)) == false + @test checkbounds(Bool, A, 2, 2, 2, trues(1)) == true # extra indices + @test checkbounds(Bool, A, 2, 2, 2, trues(2)) == false + @test checkbounds(Bool, A, trues(5), 
trues(12)) == false + @test checkbounds(Bool, A, trues(5), trues(13)) == false + @test checkbounds(Bool, A, trues(6), trues(12)) == false + @test checkbounds(Bool, A, trues(5, 4, 3)) == true + @test checkbounds(Bool, A, trues(5, 4, 2)) == false + @test checkbounds(Bool, A, trues(5, 12)) == false + @test checkbounds(Bool, A, trues(1, 5), trues(1, 4, 1), trues(1, 1, 3)) == false + @test checkbounds(Bool, A, trues(1, 5), trues(1, 4, 1), trues(1, 1, 2)) == false + @test checkbounds(Bool, A, trues(1, 5), trues(1, 5, 1), trues(1, 1, 3)) == false + @test checkbounds(Bool, A, trues(1, 5), :, 2) == false + @test checkbounds(Bool, A, trues(5, 4), trues(3)) == true + @test checkbounds(Bool, A, trues(4, 4), trues(3)) == true + @test checkbounds(Bool, A, trues(5, 4), trues(2)) == false + @test checkbounds(Bool, A, trues(6, 4), trues(3)) == false + @test checkbounds(Bool, A, trues(5, 4), trues(4)) == false +end + +@testset "array of CartesianIndex" begin + @test checkbounds(Bool, A, [CartesianIndex((1, 1, 1))]) == true + @test checkbounds(Bool, A, [CartesianIndex((5, 4, 3))]) == true + @test checkbounds(Bool, A, [CartesianIndex((0, 1, 1))]) == false + @test checkbounds(Bool, A, [CartesianIndex((1, 0, 1))]) == false + @test checkbounds(Bool, A, [CartesianIndex((1, 1, 0))]) == false + @test checkbounds(Bool, A, [CartesianIndex((6, 4, 3))]) == false + @test checkbounds(Bool, A, [CartesianIndex((5, 5, 3))]) == false + @test checkbounds(Bool, A, [CartesianIndex((5, 4, 4))]) == false + @test checkbounds(Bool, A, [CartesianIndex((1, 1))], 1) == true + @test checkbounds(Bool, A, [CartesianIndex((5, 4))], 3) == true + @test checkbounds(Bool, A, [CartesianIndex((0, 1))], 1) == false + @test checkbounds(Bool, A, [CartesianIndex((1, 0))], 1) == false + @test checkbounds(Bool, A, [CartesianIndex((1, 1))], 0) == false + @test checkbounds(Bool, A, [CartesianIndex((6, 4))], 3) == false + @test checkbounds(Bool, A, [CartesianIndex((5, 5))], 3) == false + @test checkbounds(Bool, A, [CartesianIndex((5, 4))], 4) == false +end diff --git a/test/testhelpers/ImmutableArrays.jl b/test/testhelpers/ImmutableArrays.jl deleted file mode 100644 index df2a78387e07b..0000000000000 --- a/test/testhelpers/ImmutableArrays.jl +++ /dev/null @@ -1,28 +0,0 @@ -# This file is a part of Julia. License is MIT: https://julialang.org/license - -# ImmutableArrays (arrays that implement getindex but not setindex!) - -# This test file defines an array wrapper that is immutable. It can be used to -# test the action of methods on immutable arrays. - -module ImmutableArrays - -export ImmutableArray - -"An immutable wrapper type for arrays." -struct ImmutableArray{T,N,A<:AbstractArray} <: AbstractArray{T,N} - data::A -end - -ImmutableArray(data::AbstractArray{T,N}) where {T,N} = ImmutableArray{T,N,typeof(data)}(data) - -# Minimal AbstractArray interface -Base.size(A::ImmutableArray) = size(A.data) -Base.size(A::ImmutableArray, d) = size(A.data, d) -Base.getindex(A::ImmutableArray, i...) = getindex(A.data, i...) - -# The immutable array remains immutable after conversion to AbstractArray -AbstractArray{T}(A::ImmutableArray) where {T} = ImmutableArray(AbstractArray{T}(A.data)) -AbstractArray{T,N}(A::ImmutableArray{S,N}) where {S,T,N} = ImmutableArray(AbstractArray{T,N}(A.data)) - -end diff --git a/test/testhelpers/SimpleImmutableArrays.jl b/test/testhelpers/SimpleImmutableArrays.jl new file mode 100644 index 0000000000000..028d3c31627b8 --- /dev/null +++ b/test/testhelpers/SimpleImmutableArrays.jl @@ -0,0 +1,28 @@ +# This file is a part of Julia. 
License is MIT: https://julialang.org/license + +# SimpleImmutableArrays (arrays that implement getindex but not setindex!) + +# This test file defines an array wrapper that is immutable. It can be used to +# test the action of methods on immutable arrays. + +module SimpleImmutableArrays + +export SimpleImmutableArray + +"An immutable wrapper type for arrays." +struct SimpleImmutableArray{T,N,A<:AbstractArray} <: AbstractArray{T,N} + data::A +end + +SimpleImmutableArray(data::AbstractArray{T,N}) where {T,N} = SimpleImmutableArray{T,N,typeof(data)}(data) + +# Minimal AbstractArray interface +Base.size(A::SimpleImmutableArray) = size(A.data) +Base.size(A::SimpleImmutableArray, d) = size(A.data, d) +Base.getindex(A::SimpleImmutableArray, i...) = getindex(A.data, i...) + +# The immutable array remains immutable after conversion to AbstractArray +AbstractArray{T}(A::SimpleImmutableArray) where {T} = SimpleImmutableArray(AbstractArray{T}(A.data)) +AbstractArray{T,N}(A::SimpleImmutableArray{S,N}) where {S,T,N} = SimpleImmutableArray(AbstractArray{T,N}(A.data)) + +end \ No newline at end of file From 375147164e3a55a067192bae1d95c034356515c8 Mon Sep 17 00:00:00 2001 From: Ian Atol Date: Mon, 28 Feb 2022 18:32:02 -0500 Subject: [PATCH 5/5] Add simple immutable array memory optimization --- base/compiler/optimize.jl | 14 +-- .../ssair/EscapeAnalysis/EscapeAnalysis.jl | 4 +- base/compiler/ssair/passes.jl | 87 ++++--------------- 3 files changed, 28 insertions(+), 77 deletions(-) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index e84f77ae1ea48..d127239c023e5 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -551,12 +551,14 @@ function run_passes(ci::CodeInfo, sv::OptimizationState, caller::InferenceResult @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) - @timeit "SROA" ir, memory_opt = linear_pass!(ir) - if memory_opt - @timeit "memory_opt_pass!" begin - @timeit "Local EA" estate = analyze_escapes(ir, - nargs, #=call_resolved=#true, null_escape_cache) - @timeit "memory_opt_pass!" 
ir = memory_opt_pass!(ir, estate) + @timeit "SROA" ir, memory_opt, imarray_memory_opt = linear_pass!(ir) + if memory_opt || imarray_memory_opt + @timeit "Local EA" estate = analyze_escapes(ir, nargs, #=call_resolved=#true, null_escape_cache) + if memory_opt + @timeit "Memory opt" ir = memory_opt_pass!(ir, estate) + end + if imarray_memory_opt + @timeit "imarray opt" ir = imarray_memoryopt_pass!(ir, estate) end end @timeit "ADCE" ir = adce_pass!(ir) diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl index 73c6c7b3d0b2d..cb11c7f32dd6a 100644 --- a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -27,7 +27,7 @@ import ._TOP_MOD: # Base definitions pop!, push!, pushfirst!, empty!, delete!, max, min, enumerate, unwrap_unionall, ismutabletype import Core.Compiler: # Core.Compiler specific definitions - Bottom, InferenceResult, IRCode, IR_FLAG_EFFECT_FREE, + Arrayish, Bottom, InferenceResult, IRCode, IR_FLAG_EFFECT_FREE, isbitstype, isexpr, is_meta_expr_head, println, widenconst, argextype, singleton_type, fieldcount_noerror, try_compute_field, try_compute_fieldidx, hasintersect, ⊑, intrinsic_nothrow, array_builtin_common_typecheck, arrayset_typecheck, @@ -1608,7 +1608,7 @@ function escape_builtin!(::typeof(arrayref), astate::AnalysisState, pc::Int, arg argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] boundcheckt = argtypes[1] aryt = argtypes[2] - if !array_builtin_common_typecheck(Array, boundcheckt, aryt, argtypes, 3) + if !array_builtin_common_typecheck(Arrayish, boundcheckt, aryt, argtypes, 3) add_thrown_escapes!(astate, pc, args, 2) end ary = args[3] diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 1843de2f395f2..95068e06242ea 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -547,7 +547,11 @@ function linear_pass!(ir::IRCode) compact = IncrementalCompact(ir) lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() local memory_opt = false # whether or not to run the memory_opt_pass! pass later + local imarray_memory_opt = false for ((_, idx), stmt) in compact + # presence of arrayfreeze means possible copy elision opportunity, so run imarray_memoryopt_pass! 
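+        # (the check is deliberately coarse: any reference to the `arrayfreeze` binding in the
+        # statement stream schedules the pass; each call site is then re-validated against the
+        # escape state inside `imarray_memoryopt_pass!` itself)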
+        isa(stmt, GlobalRef) && stmt.name === :arrayfreeze && (imarray_memory_opt = true)
+
         isa(stmt, Expr) || continue
         field_ordering = :unspecified
         if isexpr(stmt, :new)
@@ -663,7 +667,7 @@ function linear_pass!(ir::IRCode)
     non_dce_finish!(compact)
     simple_dce!(compact)
     ir = complete(compact)
-    return ir, memory_opt
+    return ir, memory_opt, imarray_memory_opt
 end
 
 function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
@@ -1761,76 +1765,21 @@ function cfg_simplify!(ir::IRCode)
     return finish(compact)
 end
 
-function is_allocation(stmt)
-    isexpr(stmt, :foreigncall) || return false
-    s = stmt.args[1]
-    isa(s, QuoteNode) && (s = s.value)
-    return s === :jl_alloc_array_1d
-end
-
-function memory_opt!(ir::IRCode)
-    compact = IncrementalCompact(ir, false)
-    uses = IdDict{Int, Vector{Int}}()
-    relevant = IdSet{Int}()
-    revisit = Int[]
-    function mark_val(val)
-        isa(val, SSAValue) || return
-        val.id in relevant && pop!(relevant, val.id)
-    end
-    for ((_, idx), stmt) in compact
-        if isa(stmt, ReturnNode)
-            isdefined(stmt, :val) || continue
-            val = stmt.val
-            if isa(val, SSAValue) && val.id in relevant
-                (haskey(uses, val.id)) || (uses[val.id] = Int[])
-                push!(uses[val.id], idx)
+function imarray_memoryopt_pass!(ir::IRCode, estate::EscapeState)
+    # rewrite `arrayfreeze(ary)` calls to `mutating_arrayfreeze(ary)` when `ary` provably doesn't escape
+    for idx in 1:length(ir.stmts)
+        stmt = ir.stmts[idx][:inst]
+        isexpr(stmt, :call) || continue
+        if is_known_call(stmt, Core.arrayfreeze, ir)
+            # an array that appears here as an SSA value may have been allocated
+            # within this frame (and thus may not escape it)
+            ary = stmt.args[2]
+            if isa(ary, SSAValue)
+                # if the array doesn't escape, we can simply change its type tag in
+                # place and avoid the copy that `arrayfreeze` would otherwise make
+                has_no_escape(estate[ary]) || continue
+                stmt.args[1] = GlobalRef(Core, :mutating_arrayfreeze)
             end
-            continue
         end
-        (isexpr(stmt, :call) || isexpr(stmt, :foreigncall)) || continue
-        if is_allocation(stmt)
-            push!(relevant, idx)
-            # TODO: Mark everything else here
-            continue
-        end
-        # TODO: Replace this by interprocedural escape analysis
-        if is_known_call(stmt, arrayset, compact)
-            # The value being set escapes, everything else doesn't
-            mark_val(stmt.args[4])
-            arr = stmt.args[3]
-            if isa(arr, SSAValue) && arr.id in relevant
-                (haskey(uses, arr.id)) || (uses[arr.id] = Int[])
-                push!(uses[arr.id], idx)
-            end
-        elseif is_known_call(stmt, Core.arrayfreeze, compact) && isa(stmt.args[2], SSAValue)
-            push!(revisit, idx)
-        else
-            # For now we assume everything escapes
-            # TODO: We could handle PhiNodes specially and improve this
-            for ur in userefs(stmt)
-                mark_val(ur[])
-            end
-        end
-    end
-    ir = finish(compact)
-    isempty(revisit) && return ir
-    domtree = construct_domtree(ir.cfg.blocks)
-    for idx in revisit
-        # Make sure that the value we reference didn't escape
-        id = ir.stmts[idx][:inst].args[2].id
-        (id in relevant) || continue
-
-        # We're ok to steal the memory if we don't dominate any uses
-        ok = true
-        for use in uses[id]
-            if ssadominates(ir, domtree, idx, use)
-                ok = false
-                break
-            end
-        end
-        ok || continue
-
-        ir.stmts[idx][:inst].args[1] = Core.mutating_arrayfreeze
     end
     return ir
 end
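
A minimal end-to-end illustration of the rewrite performed by `imarray_memoryopt_pass!`, written in
the style of the `test/compiler/immutablearray.jl` cases above. This is a sketch rather than part of
the patch: it assumes a build with this series applied (so `Core.ImmutableArray` and the
`arrayfreeze`/`mutating_arrayfreeze` builtins exist), and the name `make` is illustrative only.

    using Test

    # `a` is allocated locally and never escapes as a mutable `Array`, so the
    # `arrayfreeze` call reached through the `Core.ImmutableArray(a)` conversion
    # is rewritten to `mutating_arrayfreeze`: freezing retags the existing buffer
    # instead of copying it.
    function make(gen)
        a = Vector{Int}(undef, 3)
        for i in 1:3
            a[i] = i
        end
        return gen(a)
    end

    make(identity); make(Core.ImmutableArray)  # compile both specializations first
    @test (@allocated make(Core.ImmutableArray)) == (@allocated make(identity))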