Commit fcbafaa

Implement ImmutableArray
This rebases #31630 with several fixes and modifications. After #31630, we had originally decided to hold off on said PR in favor of implementing either more efficient layouts for tuples or some sort of variable-sized struct type. However, in the two years since, neither of those has happened (I had a go at improving tuples and made some progress, but there is much still to be done there). In the meantime, all across the package ecosystem, we've seen an increasing creep of pre-allocation and mutating operations, primarily caused by our lack of sufficiently powerful immutable array abstractions and array optimizations.

This works fine for the individual packages in question, but it causes a fair bit of trouble when trying to compose these packages with transformation passes such as AD or domain-specific optimizations, since many of those passes do not play well with mutation. More generally, we would like to avoid people needing to pierce abstractions for performance reasons.

Given these developments, I think it's getting quite important that we start to seriously look at arrays and try to provide performant and well-optimized arrays in the language. More importantly, I think this is somewhat independent from the actual implementation details. To be sure, it would be nice to move more of the array implementation into Julia by making use of one of the above-mentioned language features, but that is a bit of an orthogonal concern and not absolutely required.

This PR provides an `ImmutableArray` type that is identical in functionality and implementation to `Array`, except that it is immutable. Two new intrinsics, `Core.arrayfreeze` and `Core.arraythaw`, are provided. Both are semantically copies: they turn a mutable array into an immutable array and vice versa. In the original PR, I additionally provided generic functions `freeze` and `thaw` that would simply forward to these intrinsics. However, said generic functions have been omitted from this PR in favor of simply using constructors to go between mutable and immutable arrays at the high level. Generic `freeze`/`thaw` functions can always be added later, once we have a more complete picture of how these functions would work on non-Array datatypes.

Some basic compiler support is provided to elide these copies when the compiler can prove that the original object is dead after the copy. For instance, in the following example:

```
function simple()
    a = Vector{Float64}(undef, 5)
    for i = 1:5
        a[i] = i
    end
    ImmutableArray(a)
end
```

the compiler will recognize that the array `a` is dead after its use in `ImmutableArray`, and the optimized implementation will simply rewrite the type tag in the originally allocated array to mark it as immutable. It should be pointed out, however, that *semantically* there is still no mutation of the original array; this is simply an optimization.

At the moment this compiler transform is rather limited, since the analysis requires escape information in order to compute whether or not the copy may be elided. However, more complete escape analysis is being worked on at the moment, so hopefully this analysis should become more powerful in the very near future.

I would like to get this cleaned up and merged reasonably quickly, and then crowdsource some improvements to the Array APIs more generally. There are still a number of APIs that are quite bound to the notion of mutable `Array`s. StaticArrays and other packages have been inventing conventions for how to generalize those, but we should form a view in Base of what those APIs should look like and harmonize them. Having the `ImmutableArray` in Base should help with that.
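The "semantically a copy" contract described in the message can be modelled on stock Julia without this commit. The sketch below is only an illustration under that assumption: `FrozenVector`, `freeze`, and `thaw` are hypothetical names, not part of this change, and the wrapper always copies rather than eliding anything.

```julia
# Hypothetical stand-in for ImmutableArray that runs on stock Julia.
# FrozenVector, freeze and thaw are illustrative names, not APIs from this commit.
struct FrozenVector{T} <: AbstractVector{T}
    data::Vector{T}
    # The inner constructor copies, mirroring "arrayfreeze is semantically a copy".
    FrozenVector(a::Vector{T}) where {T} = new{T}(copy(a))
end

Base.IndexStyle(::Type{<:FrozenVector}) = IndexLinear()
Base.size(f::FrozenVector) = size(f.data)
Base.getindex(f::FrozenVector, i::Int) = f.data[i]
# No setindex! method is defined, so writes through the wrapper raise an error.

freeze(a::Vector) = FrozenVector(a)
thaw(f::FrozenVector) = copy(f.data)   # thawing is also a copy

a = [1.0, 2.0, 3.0]
f = freeze(a)
a[1] = 10.0      # mutating the source does not change the frozen value
b = thaw(f)
b[2] = 20.0      # mutating the thawed copy does not change the frozen value
```

The compiler support in this commit is about skipping the first copy when `a` is provably unused afterwards; the observable behaviour is meant to stay the same as in this always-copy model.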
1 parent f711f0a commit fcbafaa

25 files changed: +312 -29 lines changed

base/array.jl

Lines changed: 23 additions & 5 deletions
@@ -147,12 +147,20 @@ function vect(X...)
     return copyto!(Vector{T}(undef, length(X)), X)
 end
 
-size(a::Array, d::Integer) = arraysize(a, convert(Int, d))
-size(a::Vector) = (arraysize(a,1),)
-size(a::Matrix) = (arraysize(a,1), arraysize(a,2))
-size(a::Array{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val(N))::Dims)
+const ImmutableArray = Core.ImmutableArray
+const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
+const IMVector{T} = IMArray{T, 1}
+const IMMatrix{T} = IMArray{T, 2}
 
-asize_from(a::Array, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...)
+ImmutableArray(a::Array) = Core.arrayfreeze(a)
+Array(a::ImmutableArray) = Core.arraythaw(a)
+
+size(a::IMArray, d::Integer) = arraysize(a, convert(Int, d))
+size(a::IMVector) = (arraysize(a,1),)
+size(a::IMMatrix) = (arraysize(a,1), arraysize(a,2))
+size(a::IMArray{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val(N))::Dims)
+
+asize_from(a::IMArray, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...)
 
 allocatedinline(T::Type) = (@_pure_meta; ccall(:jl_stored_inline, Cint, (Any,), T) != Cint(0))
 
@@ -223,6 +231,13 @@ function isassigned(a::Array, i::Int...)
     ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1
 end
 
+function isassigned(a::ImmutableArray, i::Int...)
+    @_inline_meta
+    ii = (_sub2ind(size(a), i...) % UInt) - 1
+    @boundscheck ii < length(a) % UInt || return false
+    ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1
+end
+
 ## copy ##
 
 """
@@ -895,6 +910,9 @@ function getindex end
 @eval getindex(A::Array, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1)
 @eval getindex(A::Array, i1::Int, i2::Int, I::Int...) = (@_inline_meta; arrayref($(Expr(:boundscheck)), A, i1, i2, I...))
 
+@eval getindex(A::ImmutableArray, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1)
+@eval getindex(A::ImmutableArray, i1::Int, i2::Int, I::Int...) = (@_inline_meta; arrayref($(Expr(:boundscheck)), A, i1, i2, I...))
+
 # Faster contiguous indexing using copyto! for UnitRange and Colon
 function getindex(A::Array, I::AbstractUnitRange{<:Integer})
     @_inline_meta
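The `IMArray` alias in this file lets a single method serve both the mutable and the immutable array type. A minimal sketch of the same pattern on stock Julia follows; `FrozenBox`, `EitherArray`, `EitherVector`, `payload`, and `dims` are hypothetical names chosen for illustration and do not appear in the commit.

```julia
# Hypothetical second array-like type standing in for ImmutableArray.
struct FrozenBox{T,N}
    data::Array{T,N}
end

# Union aliases in the style of IMArray / IMVector.
const EitherArray{T,N} = Union{Array{T,N}, FrozenBox{T,N}}
const EitherVector{T} = EitherArray{T,1}

# Small accessor so the shared methods below can reach the underlying storage.
payload(a::Array) = a
payload(a::FrozenBox) = a.data

# One definition per shape covers both concrete types.
dims(a::EitherVector) = (length(payload(a)),)
dims(a::EitherArray) = size(payload(a))
```

In the commit itself no accessor is needed, because `arraysize` is taught to accept both types at the builtin level.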

base/compiler/optimize.jl

Lines changed: 1 addition & 1 deletion
@@ -307,7 +307,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
     ir = adce_pass!(ir)
     #@Base.show ("after_adce", ir)
     @timeit "type lift" ir = type_lift_pass!(ir)
-    @timeit "compact 3" ir = compact!(ir)
+    ir = memory_opt!(ir)
     #@Base.show ir
     if JLOptions().debug_level == 2
         @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))

base/compiler/ssair/ir.jl

Lines changed: 7 additions & 0 deletions
@@ -319,6 +319,13 @@ function setindex!(x::IRCode, @nospecialize(repl), s::SSAValue)
     return x
 end
 
+function ssadominates(ir::IRCode, domtree::DomTree, ssa1::Int, ssa2::Int)
+    bb1 = block_for_inst(ir.cfg, ssa1)
+    bb2 = block_for_inst(ir.cfg, ssa2)
+    bb1 == bb2 && return ssa1 < ssa2
+    return dominates(domtree, bb1, bb2)
+end
+
 # SSA values that need renaming
 struct OldSSAValue
     id::Int

base/compiler/ssair/passes.jl

Lines changed: 74 additions & 0 deletions
@@ -1255,3 +1255,77 @@ function cfg_simplify!(ir::IRCode)
     compact.active_result_bb = length(bb_starts)
     return finish(compact)
 end
+
+function is_allocation(stmt)
+    isexpr(stmt, :foreigncall) || return false
+    s = stmt.args[1]
+    isa(s, QuoteNode) && (s = s.value)
+    return s === :jl_alloc_array_1d
+end
+
+function memory_opt!(ir::IRCode)
+    compact = IncrementalCompact(ir, false)
+    uses = IdDict{Int, Vector{Int}}()
+    relevant = IdSet{Int}()
+    revisit = Int[]
+    function mark_val(val)
+        isa(val, SSAValue) || return
+        val.id in relevant && pop!(relevant, val.id)
+    end
+    for ((_, idx), stmt) in compact
+        if isa(stmt, ReturnNode)
+            isdefined(stmt, :val) || continue
+            val = stmt.val
+            if isa(val, SSAValue) && val.id in relevant
+                (haskey(uses, val.id)) || (uses[val.id] = Int[])
+                push!(uses[val.id], idx)
+            end
+            continue
+        end
+        (isexpr(stmt, :call) || isexpr(stmt, :foreigncall)) || continue
+        if is_allocation(stmt)
+            push!(relevant, idx)
+            # TODO: Mark everything else here
+            continue
+        end
+        # TODO: Replace this by interprocedural escape analysis
+        if is_known_call(stmt, arrayset, compact)
+            # The value being set escapes, everything else doesn't
+            mark_val(stmt.args[4])
+            arr = stmt.args[3]
+            if isa(arr, SSAValue) && arr.id in relevant
+                (haskey(uses, arr.id)) || (uses[arr.id] = Int[])
+                push!(uses[arr.id], idx)
+            end
+        elseif is_known_call(stmt, Core.arrayfreeze, compact) && isa(stmt.args[2], SSAValue)
+            push!(revisit, idx)
+        else
+            # For now we assume everything escapes
+            # TODO: We could handle PhiNodes specially and improve this
+            for ur in userefs(stmt)
+                mark_val(ur[])
+            end
+        end
+    end
+    ir = finish(compact)
+    isempty(revisit) && return ir
+    domtree = construct_domtree(ir.cfg.blocks)
+    for idx in revisit
+        # Make sure that the value we reference didn't escape
+        id = ir.stmts[idx][:inst].args[2].id
+        (id in relevant) || continue
+
+        # We're ok to steal the memory if we don't dominate any uses
+        ok = true
+        for use in uses[id]
+            if ssadominates(ir, domtree, idx, use)
+                ok = false
+                break
+            end
+        end
+        ok || continue
+
+        ir.stmts[idx][:inst].args[1] = Core.mutating_arrayfreeze
+    end
+    return ir
+end
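The decision `memory_opt!` makes can be seen more easily on a toy, straight-line program. The sketch below is a simplified model for stock Julia, not the pass itself: `Stmt`, the op labels, and `elidable_freezes` are hypothetical, statement arguments only list references to earlier statements, and "dominates" is reduced to "comes earlier", which only holds without control flow.

```julia
# Toy statement: an op label plus the indices of earlier statements it references.
struct Stmt
    op::Symbol
    args::Vector{Int}
end

# Return the indices of :freeze statements whose copy could be skipped.
function elidable_freezes(stmts::Vector{Stmt})
    relevant = Set{Int}()             # allocations not yet known to escape
    uses = Dict{Int,Vector{Int}}()    # allocation => statements that write or return it
    freezes = Int[]
    for (idx, s) in enumerate(stmts)
        if s.op === :alloc
            push!(relevant, idx)
        elseif s.op === :arrayset || s.op === :return
            arr = s.args[1]
            arr in relevant && push!(get!(uses, arr, Int[]), idx)
            # For :arrayset, any stored value is treated as escaping.
            for v in s.args[2:end]
                delete!(relevant, v)
            end
        elseif s.op === :freeze
            push!(freezes, idx)
        else
            # Unknown operation: assume every referenced value escapes.
            foreach(v -> delete!(relevant, v), s.args)
        end
    end
    result = Int[]
    for idx in freezes
        arr = stmts[idx].args[1]
        arr in relevant || continue
        # A recorded use after the freeze means the array is still live.
        any(use -> idx < use, get(uses, arr, Int[])) && continue
        push!(result, idx)
    end
    return result
end

# Shape of `simple()` from the commit message: allocate, fill, freeze, return.
simple  = [Stmt(:alloc, Int[]), Stmt(:arrayset, [1]), Stmt(:freeze, [1]), Stmt(:return, [3])]
# The array is passed to an unknown call before the freeze.
escaped = [Stmt(:alloc, Int[]), Stmt(:call, [1]), Stmt(:freeze, [1]), Stmt(:return, [3])]
# The array is written again after the freeze.
reused  = [Stmt(:alloc, Int[]), Stmt(:freeze, [1]), Stmt(:arrayset, [1]), Stmt(:return, [2])]
```

Only the first program qualifies: `elidable_freezes(simple)` returns `[3]`, while the other two return an empty vector.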

base/compiler/tfuncs.jl

Lines changed: 15 additions & 0 deletions
@@ -1532,6 +1532,21 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
                            sv::Union{InferenceState,Nothing})
     if f === tuple
         return tuple_tfunc(argtypes)
+    elseif f === Core.arrayfreeze || f === Core.arraythaw
+        if length(argtypes) != 1
+            isva && return Any
+            return Bottom
+        end
+        a = widenconst(argtypes[1])
+        at = (f === Core.arrayfreeze ? Array : ImmutableArray)
+        rt = (f === Core.arrayfreeze ? ImmutableArray : Array)
+        if a <: at
+            unw = unwrap_unionall(a)
+            if isa(unw, DataType)
+                return rewrap_unionall(rt{unw.parameters[1], unw.parameters[2]}, a)
+            end
+        end
+        return rt
     end
     if isa(f, IntrinsicFunction)
         if is_pure_intrinsic_infer(f) && _all(@nospecialize(a) -> isa(a, Const), argtypes)
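The return-type rule added above carries the element type and dimensionality of the argument over to the result type. It can be tried in isolation on stock Julia with a placeholder type. In this sketch `FrozenArray` and `freeze_return_type` are hypothetical names, and `Base.unwrap_unionall` / `Base.rewrap_unionall` are unexported Base helpers, so treat their use here as an assumption about the running Julia version.

```julia
# Placeholder parametric type standing in for ImmutableArray{T,N}.
struct FrozenArray{T,N} end

# Map an argument type Array{T,N} to FrozenArray{T,N}, keeping any `where` clauses.
function freeze_return_type(@nospecialize(a))
    if a <: Array
        unw = Base.unwrap_unionall(a)      # strip `where` wrappers
        if isa(unw, DataType)
            rt = FrozenArray{unw.parameters[1], unw.parameters[2]}
            return Base.rewrap_unionall(rt, a)   # put the wrappers back
        end
    end
    return FrozenArray                     # fallback: no parameters known
end
```

For example, `freeze_return_type(Vector{Float64})` gives `FrozenArray{Float64,1}`, and a non-array argument such as `Int` falls through to the unparameterized `FrozenArray`.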

base/dict.jl

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ end
 function setindex!(h::Dict{K,V}, v0, key0) where V where K
     key = convert(K, key0)
     if !isequal(key, key0)
-        throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K"))
+        throw(KeyTypeError(K, key0))
     end
     setindex!(h, v0, key)
 end

base/experimental.jl

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@ module Experimental
 
 using Base: Threads, sync_varname
 using Base.Meta
+using Base: ImmutableArray
+
 
 """
     Const(A::Array)

src/Makefile

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ $(BUILDDIR)/interpreter.o $(BUILDDIR)/interpreter.dbg.obj: $(SRCDIR)/builtin_pro
 $(BUILDDIR)/jitlayers.o $(BUILDDIR)/jitlayers.dbg.obj: $(SRCDIR)/jitlayers.h $(SRCDIR)/codegen_shared.h
 $(BUILDDIR)/jltypes.o $(BUILDDIR)/jltypes.dbg.obj: $(SRCDIR)/builtin_proto.h
 $(build_shlibdir)/libllvmcalltest.$(SHLIB_EXT): $(SRCDIR)/codegen_shared.h $(BUILDDIR)/julia_version.h
-$(BUILDDIR)/llvm-alloc-opt.o $(BUILDDIR)/llvm-alloc-opt.dbg.obj: $(SRCDIR)/codegen_shared.h
+$(BUILDDIR)/llvm-alloc-opt.o $(BUILDDIR)/llvm-alloc-opt.dbg.obj: $(SRCDIR)/codegen_shared.h $(SRCDIR)/llvm-pass-helpers.h
 $(BUILDDIR)/llvm-final-gc-lowering.o $(BUILDDIR)/llvm-final-gc-lowering.dbg.obj: $(SRCDIR)/llvm-pass-helpers.h
 $(BUILDDIR)/llvm-gc-invariant-verifier.o $(BUILDDIR)/llvm-gc-invariant-verifier.dbg.obj: $(SRCDIR)/codegen_shared.h
 $(BUILDDIR)/llvm-late-gc-lowering.o $(BUILDDIR)/llvm-late-gc-lowering.dbg.obj: $(SRCDIR)/llvm-pass-helpers.h

src/builtin_proto.h

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,9 @@ DECLARE_BUILTIN(typeassert);
 DECLARE_BUILTIN(_typebody);
 DECLARE_BUILTIN(typeof);
 DECLARE_BUILTIN(_typevar);
+DECLARE_BUILTIN(arrayfreeze);
+DECLARE_BUILTIN(arraythaw);
+DECLARE_BUILTIN(mutating_arrayfreeze);
 
 JL_CALLABLE(jl_f_invoke_kwsorter);
 JL_CALLABLE(jl_f__structtype);

src/builtins.c

Lines changed: 59 additions & 2 deletions
@@ -1330,7 +1330,9 @@ JL_CALLABLE(jl_f__typevar)
 JL_CALLABLE(jl_f_arraysize)
 {
     JL_NARGS(arraysize, 2, 2);
-    JL_TYPECHK(arraysize, array, args[0]);
+    if (!jl_is_arrayish(args[0])) {
+        jl_type_error("arraysize", (jl_value_t*)jl_array_type, args[0]);
+    }
     jl_array_t *a = (jl_array_t*)args[0];
     size_t nd = jl_array_ndims(a);
     JL_TYPECHK(arraysize, long, args[1]);
@@ -1369,7 +1371,9 @@ JL_CALLABLE(jl_f_arrayref)
 {
     JL_NARGSV(arrayref, 3);
     JL_TYPECHK(arrayref, bool, args[0]);
-    JL_TYPECHK(arrayref, array, args[1]);
+    if (!jl_is_arrayish(args[1])) {
+        jl_type_error("arrayref", (jl_value_t*)jl_array_type, args[1]);
+    }
     jl_array_t *a = (jl_array_t*)args[1];
     size_t i = array_nd_index(a, &args[2], nargs - 2, "arrayref");
     return jl_arrayref(a, i);
@@ -1645,6 +1649,54 @@ JL_CALLABLE(jl_f__equiv_typedef)
     return equiv_type(args[0], args[1]) ? jl_true : jl_false;
 }
 
+JL_CALLABLE(jl_f_arrayfreeze)
+{
+    JL_NARGSV(arrayfreeze, 1);
+    JL_TYPECHK(arrayfreeze, array, args[0]);
+    jl_array_t *a = (jl_array_t*)args[0];
+    jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type,
+        jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
+    JL_GC_PUSH1(&it);
+    // The idea is to elide this copy if the compiler or runtime can prove that
+    // doing so is safe to do.
+    jl_array_t *na = jl_array_copy(a);
+    jl_set_typeof(na, it);
+    JL_GC_POP();
+    return (jl_value_t*)na;
+}
+
+JL_CALLABLE(jl_f_mutating_arrayfreeze)
+{
+    // N.B.: These error checks pretend to be arrayfreeze since this is a drop
+    // in replacement and we don't want to change the visible error type in the
+    // optimizer
+    JL_NARGSV(arrayfreeze, 1);
+    JL_TYPECHK(arrayfreeze, array, args[0]);
+    jl_array_t *a = (jl_array_t*)args[0];
+    jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type,
+        jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
+    jl_set_typeof(a, it);
+    return (jl_value_t*)a;
+}
+
+JL_CALLABLE(jl_f_arraythaw)
+{
+    JL_NARGSV(arraythaw, 1);
+    if (((jl_datatype_t*)jl_typeof(args[0]))->name != jl_immutable_array_typename) {
+        jl_type_error("arraythaw", (jl_value_t*)jl_immutable_array_type, args[0]);
+    }
+    jl_array_t *a = (jl_array_t*)args[0];
+    jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type,
+        jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
+    JL_GC_PUSH1(&it);
+    // The idea is to elide this copy if the compiler or runtime can prove that
+    // doing so is safe to do.
+    jl_array_t *na = jl_array_copy(a);
+    jl_set_typeof(na, it);
+    JL_GC_POP();
+    return (jl_value_t*)na;
+}
+
 // IntrinsicFunctions ---------------------------------------------------------
 
 static void (*runtime_fp[num_intrinsics])(void);
@@ -1797,6 +1849,10 @@ void jl_init_primitives(void) JL_GC_DISABLED
     jl_builtin_arrayset = add_builtin_func("arrayset", jl_f_arrayset);
     jl_builtin_arraysize = add_builtin_func("arraysize", jl_f_arraysize);
 
+    jl_builtin_arrayfreeze = add_builtin_func("arrayfreeze", jl_f_arrayfreeze);
+    jl_builtin_mutating_arrayfreeze = add_builtin_func("mutating_arrayfreeze", jl_f_mutating_arrayfreeze);
+    jl_builtin_arraythaw = add_builtin_func("arraythaw", jl_f_arraythaw);
+
     // method table utils
     jl_builtin_applicable = add_builtin_func("applicable", jl_f_applicable);
     jl_builtin_invoke = add_builtin_func("invoke", jl_f_invoke);
@@ -1868,6 +1924,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
     add_builtin("AbstractArray", (jl_value_t*)jl_abstractarray_type);
     add_builtin("DenseArray", (jl_value_t*)jl_densearray_type);
     add_builtin("Array", (jl_value_t*)jl_array_type);
+    add_builtin("ImmutableArray", (jl_value_t*)jl_immutable_array_type);
 
     add_builtin("Expr", (jl_value_t*)jl_expr_type);
     add_builtin("LineNumberNode", (jl_value_t*)jl_linenumbernode_type);
