Skip to content

Commit a164ef4

Browse files
Ian AtolIan Atol
authored andcommitted
ImmutableArray works as uType for DiffEq problems
memory_opt! performance and safety improvements Quick fixes to build
1 parent fcbafaa commit a164ef4

File tree

7 files changed

+337
-23
lines changed

7 files changed

+337
-23
lines changed

base/abstractarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,10 @@ function copy(a::AbstractArray)
10731073
copymutable(a)
10741074
end
10751075

1076+
function copy(a::Core.ImmutableArray)
1077+
a
1078+
end
1079+
10761080
function copyto!(B::AbstractVecOrMat{R}, ir_dest::AbstractRange{Int}, jr_dest::AbstractRange{Int},
10771081
A::AbstractVecOrMat{S}, ir_src::AbstractRange{Int}, jr_src::AbstractRange{Int}) where {R,S}
10781082
if length(ir_dest) != length(ir_src)

base/array.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,35 @@ Union type of [`DenseVector{T}`](@ref) and [`DenseMatrix{T}`](@ref).
118118
"""
119119
const DenseVecOrMat{T} = Union{DenseVector{T}, DenseMatrix{T}}
120120

121+
"""
122+
ImmutableArray
123+
124+
Dynamically allocated, immutable array.
125+
126+
"""
127+
const ImmutableArray = Core.ImmutableArray
128+
129+
"""
130+
IMArray{T,N}
131+
132+
Union type of [`Array{T,N}`](@ref) and [`ImmutableArray{T,N}`](@ref)
133+
"""
134+
const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
135+
136+
"""
137+
IMVector{T}
138+
139+
One-dimensional [`ImmutableArray`](@ref) or [`Array`](@ref) with elements of type `T`. Alias for `IMArray{T, 1}`.
140+
"""
141+
const IMVector{T} = IMArray{T, 1}
142+
143+
"""
144+
IMMatrix{T}
145+
146+
Two-dimensional [`ImmutableArray`](@ref) or [`Array`](@ref) with elements of type `T`. Alias for `IMArray{T,2}`.
147+
"""
148+
const IMMatrix{T} = IMArray{T, 2}
149+
121150
## Basic functions ##
122151

123152
import Core: arraysize, arrayset, arrayref, const_arrayref
@@ -147,14 +176,13 @@ function vect(X...)
147176
return copyto!(Vector{T}(undef, length(X)), X)
148177
end
149178

150-
const ImmutableArray = Core.ImmutableArray
151-
const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
152-
const IMVector{T} = IMArray{T, 1}
153-
const IMMatrix{T} = IMArray{T, 2}
154-
179+
# Freeze and thaw constructors
155180
ImmutableArray(a::Array) = Core.arrayfreeze(a)
156181
Array(a::ImmutableArray) = Core.arraythaw(a)
157182

183+
ImmutableArray(a::AbstractArray{T,N}) where {T,N} = ImmutableArray{T,N}(a)
184+
185+
# Size functions for arrays, both mutable and immutable
158186
size(a::IMArray, d::Integer) = arraysize(a, convert(Int, d))
159187
size(a::IMVector) = (arraysize(a,1),)
160188
size(a::IMMatrix) = (arraysize(a,1), arraysize(a,2))
@@ -393,6 +421,9 @@ similar(a::Array{T}, m::Int) where {T} = Vector{T}(undef, m)
393421
similar(a::Array, T::Type, dims::Dims{N}) where {N} = Array{T,N}(undef, dims)
394422
similar(a::Array{T}, dims::Dims{N}) where {T,N} = Array{T,N}(undef, dims)
395423

424+
ImmutableArray{T}(undef::UndefInitializer, m::Int) where T = ImmutableArray(Array{T}(undef, m))
425+
ImmutableArray{T}(undef::UndefInitializer, dims::Dims) where T = ImmutableArray(Array{T}(undef, dims))
426+
396427
# T[x...] constructs Array{T,1}
397428
"""
398429
getindex(type[, elements...])
@@ -626,8 +657,8 @@ oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)
626657

627658
## Conversions ##
628659

629-
convert(::Type{T}, a::AbstractArray) where {T<:Array} = a isa T ? a : T(a)
630660
convert(::Type{Union{}}, a::AbstractArray) = throw(MethodError(convert, (Union{}, a)))
661+
convert(T::Union{Type{<:Array},Type{<:Core.ImmutableArray}}, a::AbstractArray) = a isa T ? a : T(a)
631662

632663
promote_rule(a::Type{Array{T,n}}, b::Type{Array{S,n}}) where {T,n,S} = el_same(promote_type(T,S), a, b)
633664

@@ -637,6 +668,7 @@ if nameof(@__MODULE__) === :Base # avoid method overwrite
637668
# constructors should make copies
638669
Array{T,N}(x::AbstractArray{S,N}) where {T,N,S} = copyto_axcheck!(Array{T,N}(undef, size(x)), x)
639670
AbstractArray{T,N}(A::AbstractArray{S,N}) where {T,N,S} = copyto_axcheck!(similar(A,T), A)
671+
ImmutableArray{T,N}(Ar::AbstractArray{S,N}) where {T,N,S} = Core.arrayfreeze(copyto_axcheck!(Array{T,N}(undef, size(Ar)), Ar))
640672
end
641673

642674
## copying iterators to containers

base/broadcast.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,4 +1385,17 @@ function Base.show(io::IO, op::BroadcastFunction)
13851385
end
13861386
Base.show(io::IO, ::MIME"text/plain", op::BroadcastFunction) = show(io, op)
13871387

1388+
struct IMArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
1389+
BroadcastStyle(::Type{<:Core.ImmutableArray}) = IMArrayStyle()
1390+
1391+
#similar has to return mutable array
1392+
function Base.similar(bc::Broadcasted{IMArrayStyle}, ::Type{ElType}) where ElType
1393+
similar(Array{ElType}, axes(bc))
1394+
end
1395+
1396+
@inline function copy(bc::Broadcasted{IMArrayStyle})
1397+
ElType = combine_eltypes(bc.f, bc.args)
1398+
return Core.ImmutableArray(copyto!(similar(bc, ElType), bc))
1399+
end
1400+
13881401
end # module

base/compiler/optimize.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
307307
ir = adce_pass!(ir)
308308
#@Base.show ("after_adce", ir)
309309
@timeit "type lift" ir = type_lift_pass!(ir)
310+
#@timeit "compact 3" ir = compact!(ir)
310311
ir = memory_opt!(ir)
311312
#@Base.show ir
312313
if JLOptions().debug_level == 2

base/compiler/ssair/passes.jl

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,23 +1256,26 @@ function cfg_simplify!(ir::IRCode)
12561256
return finish(compact)
12571257
end
12581258

1259-
function is_allocation(stmt)
1259+
function is_allocation(stmt::Expr)
12601260
isexpr(stmt, :foreigncall) || return false
12611261
s = stmt.args[1]
12621262
isa(s, QuoteNode) && (s = s.value)
1263-
return s === :jl_alloc_array_1d
1263+
return (s === :jl_alloc_array_1d || s === :jl_alloc_array_2d || s === :jl_alloc_array_3d || s === :jl_new_array)
12641264
end
12651265

12661266
function memory_opt!(ir::IRCode)
12671267
compact = IncrementalCompact(ir, false)
12681268
uses = IdDict{Int, Vector{Int}}()
1269-
relevant = IdSet{Int}()
1270-
revisit = Int[]
1271-
function mark_val(val)
1269+
relevant = IdSet{Int}() # allocations
1270+
revisit = Int[] # potential targets for mutating_arrayfreeze
1271+
1272+
function mark_escape(@nospecialize val)
12721273
isa(val, SSAValue) || return
12731274
val.id in relevant && pop!(relevant, val.id)
12741275
end
1276+
12751277
for ((_, idx), stmt) in compact
1278+
12761279
if isa(stmt, ReturnNode)
12771280
isdefined(stmt, :val) || continue
12781281
val = stmt.val
@@ -1282,50 +1285,92 @@ function memory_opt!(ir::IRCode)
12821285
end
12831286
continue
12841287
end
1288+
12851289
(isexpr(stmt, :call) || isexpr(stmt, :foreigncall)) || continue
1290+
12861291
if is_allocation(stmt)
12871292
push!(relevant, idx)
12881293
# TODO: Mark everything else here
12891294
continue
12901295
end
1296+
12911297
# TODO: Replace this by interprocedural escape analysis
12921298
if is_known_call(stmt, arrayset, compact)
1299+
# arrayset expr.args:
1300+
# :(Base.arrayset)
1301+
# false
1302+
# :(%2) array
1303+
# :(%8) value
1304+
# :(%7) index
12931305
# The value being set escapes, everything else doesn't
1294-
mark_val(stmt.args[4])
1306+
(length(stmt.args) == 5) || continue # fix boundserror during precompile --- but how do we have arrayset with < 5 args?
1307+
mark_escape(stmt.args[4])
12951308
arr = stmt.args[3]
12961309
if isa(arr, SSAValue) && arr.id in relevant
12971310
(haskey(uses, arr.id)) || (uses[arr.id] = Int[])
12981311
push!(uses[arr.id], idx)
12991312
end
1313+
13001314
elseif is_known_call(stmt, Core.arrayfreeze, compact) && isa(stmt.args[2], SSAValue)
13011315
push!(revisit, idx)
1316+
1317+
elseif is_known_call(stmt, arraysize, compact) && isa(stmt.args[2], SSAValue) && isa(stmt.args[3], Number)
1318+
arr = stmt.args[2]
1319+
dim = stmt.args[3]
1320+
typ = types(compact)[arr]
1321+
1322+
while !isa(typ, Type)
1323+
typ = typeof(typ)
1324+
end
1325+
1326+
if isa(typ, Core.Const)
1327+
typ = typ.val
1328+
end
1329+
1330+
# make sure this call isn't going to throw
1331+
if typ <: AbstractArray && dim >= 1
1332+
# don't escape the array, but mark usage for dom analysis
1333+
if arr.id in relevant
1334+
(haskey(uses, arr.id)) || (uses[arr.id] = Int[])
1335+
push!(uses[arr.id], idx)
1336+
end
1337+
else # if this call throws or we can't tell, the array definitely escapes
1338+
for ur in userefs(stmt)
1339+
mark_escape(ur[])
1340+
end
1341+
end
13021342
else
13031343
# For now we assume everything escapes
13041344
# TODO: We could handle PhiNodes specially and improve this
13051345
for ur in userefs(stmt)
1306-
mark_val(ur[])
1346+
mark_escape(ur[])
13071347
end
13081348
end
13091349
end
1350+
13101351
ir = finish(compact)
13111352
isempty(revisit) && return ir
1353+
13121354
domtree = construct_domtree(ir.cfg.blocks)
1355+
13131356
for idx in revisit
13141357
# Make sure that the value we reference didn't escape
1315-
id = ir.stmts[idx][:inst].args[2].id
1358+
stmt = ir.stmts[idx][:inst]::Expr
1359+
id = (stmt.args[2]::SSAValue).id
13161360
(id in relevant) || continue
13171361

13181362
# We're ok to steal the memory if we don't dominate any uses
13191363
ok = true
1320-
for use in uses[id]
1321-
if ssadominates(ir, domtree, idx, use)
1322-
ok = false
1323-
break
1364+
if haskey(uses, id)
1365+
for use in uses[id]
1366+
if ssadominates(ir, domtree, idx, use)
1367+
ok = false
1368+
break
1369+
end
13241370
end
13251371
end
13261372
ok || continue
1327-
1328-
ir.stmts[idx][:inst].args[1] = Core.mutating_arrayfreeze
1373+
stmt.args[1] = Core.mutating_arrayfreeze
13291374
end
13301375
return ir
13311376
end

base/pointer.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ cconvert(::Type{Ptr{UInt8}}, s::AbstractString) = String(s)
6363
cconvert(::Type{Ptr{Int8}}, s::AbstractString) = String(s)
6464

6565
unsafe_convert(::Type{Ptr{T}}, a::Array{T}) where {T} = ccall(:jl_array_ptr, Ptr{T}, (Any,), a)
66+
unsafe_convert(::Type{Ptr{T}}, a::Core.ImmutableArray{T}) where {T} = ccall(:jl_array_ptr, Ptr{T}, (Any,), a)
6667
unsafe_convert(::Type{Ptr{S}}, a::AbstractArray{T}) where {S,T} = convert(Ptr{S}, unsafe_convert(Ptr{T}, a))
6768
unsafe_convert(::Type{Ptr{T}}, a::AbstractArray{T}) where {T} = error("conversion to pointer not defined for $(typeof(a))")
6869

0 commit comments

Comments
 (0)