Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/Testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ jobs:
strategy:
matrix:
version:
- '1.10'
- 'min'
- '1'
- 'pre'
# TODO(mhauru) Reenable the below once there is a 'pre' version different from '1'.
# - 'pre'
os:
- ubuntu-latest
- windows-latest
Expand Down
11 changes: 11 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# 0.9.6

Add support for Julia v1.12.

# 0.9.0

From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where previously they were only deprecated. Additionally, the internals have been completely overhauled, and the public interface more precisely defined. See the docs for more info.

# 0.6.0

From v0.6.0, Libtask is implemented by recording all computation onto a tape and copying that tape. Before that version, it was based on a tricky hack of the Julia internals. You can check the commit history of this repo to see the details.
8 changes: 0 additions & 8 deletions NEWS.md

This file was deleted.

2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.5"
version = "0.9.6"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
102 changes: 73 additions & 29 deletions src/bbcode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,44 @@ end

collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts))

struct BBCode
blocks::Vector{BBlock}
argtypes::Vector{Any}
sptypes::Vector{CC.VarState}
linetable::Vector{Core.LineInfoNode}
meta::Vector{Expr}
end
@static if VERSION >= v"1.12-"
struct BBCode
blocks::Vector{BBlock}
argtypes::Vector{Any}
sptypes::Vector{CC.VarState}
debuginfo::CC.DebugInfoStream
meta::Vector{Expr}
valid_worlds::CC.WorldRange
end

function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock})
return BBCode(
new_blocks,
CC.copy(ir.argtypes),
CC.copy(ir.sptypes),
CC.copy(ir.linetable),
CC.copy(ir.meta),
)
function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock})
return BBCode(
new_blocks,
CC.copy(ir.argtypes),
CC.copy(ir.sptypes),
CC.copy(ir.debuginfo),
CC.copy(ir.meta),
ir.valid_worlds,
)
end
else
struct BBCode
blocks::Vector{BBlock}
argtypes::Vector{Any}
sptypes::Vector{CC.VarState}
linetable::Vector{Core.LineInfoNode}
meta::Vector{Expr}
end

function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock})
return BBCode(
new_blocks,
CC.copy(ir.argtypes),
CC.copy(ir.sptypes),
CC.copy(ir.linetable),
CC.copy(ir.meta),
)
end
end

# Makes use of the above outer constructor for `BBCode`.
Expand Down Expand Up @@ -352,20 +374,42 @@ function CC.IRCode(bb_code::BBCode)
insts = _ids_to_line_numbers(bb_code)
cfg = control_flow_graph(bb_code)
insts = _lines_to_blocks(insts, cfg)
return IRCode(
CC.InstructionStream(
map(x -> x.stmt, insts),
map(x -> x.type, insts),
map(x -> x.info, insts),
map(x -> x.line, insts),
map(x -> x.flag, insts),
),
cfg,
CC.copy(bb_code.linetable),
CC.copy(bb_code.argtypes),
CC.copy(bb_code.meta),
CC.copy(bb_code.sptypes),
)
@static if VERSION >= v"1.12-"
# See e.g. here for how the NTuple{3,Int}s get flattened for InstructionStream:
# https://github.com/JuliaLang/julia/blob/16a2bf0a3b106b03dda23b8c9478aab90ffda5e1/Compiler/src/ssair/ir.jl#L299
lines = map(x -> x.line, insts)
lines = collect(Iterators.flatten(lines))
return IRCode(
CC.InstructionStream(
map(x -> x.stmt, insts),
collect(Any, map(x -> x.type, insts)),
collect(CC.CallInfo, map(x -> x.info, insts)),
lines,
map(x -> x.flag, insts),
),
cfg,
CC.copy(bb_code.debuginfo),
CC.copy(bb_code.argtypes),
CC.copy(bb_code.meta),
CC.copy(bb_code.sptypes),
bb_code.valid_worlds,
)
else
return IRCode(
CC.InstructionStream(
map(x -> x.stmt, insts),
map(x -> x.type, insts),
map(x -> x.info, insts),
map(x -> x.line, insts),
map(x -> x.flag, insts),
),
cfg,
CC.copy(bb_code.linetable),
CC.copy(bb_code.argtypes),
CC.copy(bb_code.meta),
CC.copy(bb_code.sptypes),
)
end
end

function _lower_switch_statements(bb_code::BBCode)
Expand Down
29 changes: 23 additions & 6 deletions src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ function build_callable(sig::Type{<:Tuple})
unoptimised_ir = IRCode(bb)
optimised_ir = optimise_ir!(unoptimised_ir)
mc_ret_type = callable_ret_type(sig, types)
mc = misty_closure(mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true)
mc = optimized_misty_closure(
mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true
)
mc_cache[key] = mc
return mc, refs[end]
end
Expand Down Expand Up @@ -277,6 +279,13 @@ The above gives the broad outline of how `TapedTask`s are implemented. We refer
readers to the code, which is extensively commented to explain implementation details.
"""
function TapedTask(taped_globals::Any, fargs...; kwargs...)
@static if v"1.12.1" > VERSION >= v"1.12.0-"
@warn """
Libtask.jl does not work correctly on Julia v1.12.0 and may crash your Julia
session. Please upgrade to at least v1.12.1. See
https://github.com/JuliaLang/julia/issues/59222 for the bug in question.
"""
end
all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...)
seed_id!() # a BBCode thing.
mc, count_ref = build_callable(typeof(all_args))
Expand Down Expand Up @@ -441,8 +450,10 @@ get_value(x) = x
expression, otherwise `false`.
"""
function is_produce_stmt(x)::Bool
if Meta.isexpr(x, :invoke) && length(x.args) == 3 && x.args[1] isa Core.MethodInstance
return x.args[1].specTypes <: Tuple{typeof(produce),Any}
if Meta.isexpr(x, :invoke) &&
length(x.args) == 3 &&
x.args[1] isa Union{Core.MethodInstance,Core.CodeInstance}
return get_mi(x.args[1]).specTypes <: Tuple{typeof(produce),Any}
elseif Meta.isexpr(x, :call) && length(x.args) == 2
return get_value(x.args[1]) === produce
else
Expand All @@ -465,7 +476,7 @@ function stmt_might_produce(x, ret_type::Type)::Bool

# Statement will terminate in the usual fashion, so _do_ bother recusing.
is_produce_stmt(x) && return true
Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes)
Meta.isexpr(x, :invoke) && return might_produce(get_mi(x.args[1]).specTypes)
if Meta.isexpr(x, :call)
# This is a hack -- it's perfectly possible for `DataType` calls to produce in general.
f = get_function(x.args[1])
Expand Down Expand Up @@ -1029,7 +1040,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}

# Derive TapedTask for this statement.
(callable, callable_args) = if Meta.isexpr(stmt, :invoke)
sig = stmt.args[1].specTypes
sig = get_mi(stmt.args[1]).specTypes
v = Any[Any]
(LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end])
elseif Meta.isexpr(stmt, :call)
Expand Down Expand Up @@ -1144,7 +1155,13 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
new_argtypes = vcat(typeof(refs), copy(ir.argtypes))

# Return BBCode and the `Ref`s.
new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta)
@static if VERSION >= v"1.12-"
new_ir = BBCode(
new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds
)
else
new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta)
end
return new_ir, refs, possible_produce_types
end

Expand Down
126 changes: 113 additions & 13 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
    get_mi(x::Union{Core.MethodInstance,Core.CodeInstance})

Return the `Core.MethodInstance` associated with `x`.

A `MethodInstance` is returned unchanged. For a `CodeInstance`, `CC.get_ci_mi` is used when
the compiler module defines it; otherwise the `def` field of the `CodeInstance` is returned.
"""
get_mi(mi::Core.MethodInstance) = mi
function get_mi(ci::Core.CodeInstance)
    # The branch is chosen once, at macro-expansion time, based on what `CC` provides.
    @static if isdefined(CC, :get_ci_mi)
        return CC.get_ci_mi(ci)
    else
        return ci.def
    end
end

"""
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}
Expand Down Expand Up @@ -68,7 +72,11 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)

ir = CC.compact!(ir)
# CC.verify_ir(ir, true, false, CC.optimizer_lattice(local_interp))
CC.verify_linetable(ir.linetable, true)
@static if VERSION >= v"1.12-"
CC.verify_linetable(ir.debuginfo, div(length(ir.debuginfo.codelocs), 3), true)
else
CC.verify_linetable(ir.linetable, true)
end
if show_ir
println("Post-optimization")
display(ir)
Expand Down Expand Up @@ -96,13 +104,27 @@ end
# Run type inference and constant propagation on the ir. Credit to @oxinabox:
# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54
function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance)
method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=#
min_world = world = get_inference_world(interp)
max_world = Base.get_world_counter()
irsv = CC.IRInterpretationState(
interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world
)
rt = CC._ir_abstract_constant_propagation(interp, irsv)
@static if VERSION >= v"1.12-"
nargs = length(ir.argtypes) - 1
# TODO(mhauru) How should we figure out isva? I don't think it's in ir or mi.
isva = false
propagate_inbounds = true
spec_info = CC.SpecInfo(nargs, isva, propagate_inbounds, nothing)
min_world = world = get_inference_world(interp)
max_world = Base.get_world_counter()
irsv = CC.IRInterpretationState(
interp, spec_info, ir, mi, ir.argtypes, world, min_world, max_world
)
rt = CC.ir_abstract_constant_propagation(interp, irsv)
else
method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=#
min_world = world = get_inference_world(interp)
max_world = Base.get_world_counter()
irsv = CC.IRInterpretationState(
interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world
)
rt = CC._ir_abstract_constant_propagation(interp, irsv)
end
return ir
end

Expand Down Expand Up @@ -168,19 +190,85 @@ function opaque_closure(
)
# This implementation is copied over directly from `Core.OpaqueClosure`.
ir = CC.copy(ir)
nargs = length(ir.argtypes) - 1
sig = Base.Experimental.compute_oc_signature(ir, nargs, isva)
@static if VERSION >= v"1.12-"
# On v1.12 OpaqueClosure expects the first arg to be the environment.
ir.argtypes[1] = typeof(env)
end
nargtypes = length(ir.argtypes)
nargs = nargtypes - 1
@static if VERSION >= v"1.12-"
sig = CC.compute_oc_signature(ir, nargs, isva)
else
sig = Base.Experimental.compute_oc_signature(ir, nargs, isva)
end
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
src.slotnames = fill(:none, nargs + 1)
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
src.slotnames = [Symbol(:_, i) for i in 1:nargtypes]
src.slotflags = fill(zero(UInt8), nargtypes)
src.slottypes = copy(ir.argtypes)
src.rettype = ret_type
@static if VERSION > v"1.12-"
ir.debuginfo.def === nothing &&
(ir.debuginfo.def = :var"generated IR for OpaqueClosure")
src.min_world = ir.valid_worlds.min_world
src.max_world = ir.valid_worlds.max_world
src.isva = isva
src.nargs = nargtypes
end
src = CC.ir_to_codeinf!(src, ir)
src.rettype = ret_type
return Base.Experimental.generate_opaque_closure(
sig, Union{}, ret_type, src, nargs, isva, env...; do_compile
)::Core.OpaqueClosure{sig,ret_type}
end

"""
    optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)

Build an opaque closure from `ir` and then attempt to replace it with an optimised one.

The closure is first constructed with `opaque_closure`. The world bounds of its cached
inferred code are then pinned via `set_world_bounds_for_optimization!`, after which
`optimize_opaque_closure` re-infers and inlines that code and builds a new closure from the
result. If nothing could be optimised, `optimize_opaque_closure` hands back the original
closure.

`rtype`, `env` and `kwargs` are forwarded unchanged to `opaque_closure`.
"""
function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)
    oc = opaque_closure(rtype, ir, env...; kwargs...)
    # Must happen before re-inference, which reads the cached code of `oc`.
    set_world_bounds_for_optimization!(oc)
    return optimize_opaque_closure(oc, rtype, env...; kwargs...)
end

"""
    optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)

Re-infer and inline the code cached for `oc` (in the world `oc.world`) and build a new
opaque closure from the resulting IR, forwarding `rtype`, `env` and `kwargs` to
`opaque_closure`.

If `reinfer_and_inline` returns `nothing`, there is nothing to optimise and `oc` itself is
returned.
"""
function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)
    code_instance = oc.source.specializations.cache
    new_ir = reinfer_and_inline(code_instance, UInt(oc.world))
    if new_ir === nothing
        return oc
    end
    return opaque_closure(rtype, new_ir, env...; kwargs...)
end

"""
    set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)

Set both `min_world` and `max_world` of the inferred code cached for `oc` to `oc.world`.

This allows optimization to make assumptions about binding access, enabling inlining and
other optimizations. If the cached code instance has no inferred code, nothing is modified
and `nothing` is returned.
"""
function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)
    code_instance = oc.source.specializations.cache
    if code_instance.inferred === nothing
        return nothing
    end
    code_instance.inferred.min_world = oc.world
    return code_instance.inferred.max_world = oc.world
end

"""
    reinfer_and_inline(ci::Core.CodeInstance, world::UInt)

Run constant propagation followed by the SSA inlining pass on the IR of `ci`, using a
`CC.NativeInterpreter` for `world`, and return the resulting IR.

Returns `nothing` if no `CC.IRInterpretationState` could be constructed for `ci`.
"""
function reinfer_and_inline(ci::Core.CodeInstance, world::UInt)
    interp = CC.NativeInterpreter(world)
    mi = get_mi(ci)
    argtypes = collect(Any, mi.specTypes.parameters)
    irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world)
    irsv === nothing && return nothing

    # Set `IR_FLAG_REFINED` on every statement, except for gotos and the expression heads
    # listed below, which are left untouched.
    untouched_heads = (:loopinfo, :pop_exception, :copyast)
    for stmt in irsv.ir.stmts
        inst = stmt[:inst]
        leave_alone =
            inst isa CC.GotoIfNot ||
            inst isa CC.GotoNode ||
            Meta.isexpr(inst, untouched_heads)
        leave_alone || (stmt[:flag] |= CC.IR_FLAG_REFINED)
    end

    CC.ir_abstract_constant_propagation(interp, irsv)
    inlining_state = CC.InliningState(interp)
    return CC.ssa_inlining_pass!(irsv.ir, inlining_state, CC.propagate_inbounds(irsv))
end

"""
misty_closure(
ret_type::Type,
Expand All @@ -202,3 +290,15 @@ function misty_closure(
)
return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir))
end

"""
    optimized_misty_closure(
        ret_type::Type, ir::IRCode, env...; isva::Bool=false, do_compile::Bool=true
    )

Construct a `MistyClosure` whose underlying opaque closure is built by
`optimized_opaque_closure` rather than `opaque_closure`.

`ret_type`, `ir`, `env`, `isva` and `do_compile` are forwarded to
`optimized_opaque_closure`. The `IRCode` stored alongside the closure is the `ir` that was
passed in.
"""
function optimized_misty_closure(
    ret_type::Type,
    ir::IRCode,
    @nospecialize env...;
    isva::Bool=false,
    do_compile::Bool=true,
)
    oc = optimized_opaque_closure(ret_type, ir, env...; isva=isva, do_compile=do_compile)
    return MistyClosure(oc, Ref(ir))
end
Loading