Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/collection/fused_assemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ function transform_assemble(e::Expr, sym)
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr
elseif e.head == Symbol(".=")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr
else
Expr(
transform_assemble(e.head, sym),
Expand Down Expand Up @@ -90,7 +95,7 @@ function check_restrictions_assemble(expr::Expr)
arg isa LineNumberNode && continue
s_error = if arg isa QuoteNode
"Dangling symbols are not allowed inside fused blocks"
elseif arg.head == :call
elseif arg.head == :call && !(isa_dot_op(arg[1]))
"Function calls are not allowed inside fused blocks"
elseif arg.head == :(=)
"Non-broadcast assignments are not allowed inside fused blocks"
Expand All @@ -109,6 +114,13 @@ function check_restrictions_assemble(expr::Expr)
elseif arg.head == :if
check_restrictions(arg.args[2])
elseif arg.head == :macrocall && arg.args[1] == Symbol("@inbounds")
elseif arg.head == :call && isa_dot_op(arg.args[1])
# Allows for :(a .+ foo(b))
# where foo(b) could be a getter to an array.
# This technically opens the door to incorrectness,
# as foo could change the pointer of `b` to something else
# however, this seems unlikely.
elseif isa_dot_op(arg.head) # dot function call
else
@show dump(arg)
error("Uncaught edge case")
Expand Down
14 changes: 13 additions & 1 deletion src/collection/fused_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ function transform(e::Expr)
margs = materialize_args(se)
subexpr = :(Pair($(margs[1]), $(margs[2])))
subexpr
elseif e.head == Symbol(".=")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :(Pair($(margs[1]), $(margs[2])))
subexpr
else
Expr(transform(e.head), transform.(e.args)...)
end
Expand Down Expand Up @@ -82,7 +87,7 @@ function check_restrictions(expr::Expr)
"Loops are not allowed inside fused blocks"
elseif _expr.head == :if
"If-statements are not allowed inside fused blocks"
elseif _expr.head == :call
elseif _expr.head == :call && !(isa_dot_op(_expr.args[1]))
"Function calls are not allowed inside fused blocks"
elseif _expr.head == :(=)
"Non-broadcast assignments are not allowed inside fused blocks"
Expand All @@ -95,6 +100,13 @@ function check_restrictions(expr::Expr)
end
isempty(s_error) || error(s_error)
if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__")
elseif _expr.head == :call && isa_dot_op(_expr.args[1])
# Allows for :(a .+ foo(b))
# where foo(b) could be a getter to an array.
# This technically opens the door to incorrectness,
# as foo could change the pointer of `b` to something else
# however, this seems unlikely.
elseif isa_dot_op(_expr.head) # dot function call
else
@show dump(_expr)
error("Uncaught edge case")
Expand Down
28 changes: 26 additions & 2 deletions src/collection/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@ end

function materialize_args(expr::Expr)
@assert expr.head == :call
@assert expr.args[1] == :(Base.materialize!)
return (expr.args[2], expr.args[3])
if expr.args[1] == :(Base.materialize!)
return (expr.args[2], expr.args[3])
elseif expr.args[1] == :(Base.materialize)
return (expr.args[2], expr.args[2])
else
error("Uncaught edge case.")
end
end

const dot_ops = (
Symbol(".+"),
Symbol(".-"),
Symbol(".*"),
Symbol("./"),
Symbol(".="),
Symbol(".=="),
Symbol(".≠"),
Symbol(".^"),
Symbol(".!="),
Symbol(".>"),
Symbol(".<"),
Symbol(".>="),
Symbol(".<="),
Symbol(".≤"),
Symbol(".≥"),
)
isa_dot_op(op) = any(x -> op == x, dot_ops)
19 changes: 19 additions & 0 deletions test/collection/expr_fused_assemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ import MultiBroadcastFusion as MBF
@test MBF.fused_assemble(expr_in, :tup) == expr_out
end

#! format: off
@testset "fused_assemble - simple sequential, explicit dots" begin
expr_in = quote
y1 .= x1 .+ x2 .+ x3 .+ x4
y2 .= x2 .+ x3 .+ x4 .+ x5
end

expr_out = quote
tup = ()
tup = (tup..., Pair(y1, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x1, x2), x3), x4)))
tup = (tup..., Pair(y2, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x2, x3), x4), x5)))
tup
end

@test MBF.linefilter!(MBF.fused_assemble(expr_in, :tup)) ==
MBF.linefilter!(expr_out)
@test MBF.fused_assemble(expr_in, :tup) == expr_out
end
#! format: on

@testset "fused_assemble - loop" begin
expr_in = quote
Expand Down
15 changes: 15 additions & 0 deletions test/collection/expr_fused_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,18 @@ import MultiBroadcastFusion as MBF
))
@test MBF.fused_direct(expr_in) == expr_out
end

#! format: off
@testset "fused_direct - explicit dots" begin
expr_in = quote
y1 .= x1 .+ x2 .+ x3 .+ x4
y2 .= x2 .+ x3 .+ x4 .+ x5
end

expr_out = :(tuple(
Pair(y1, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x1, x2), x3), x4)),
Pair(y2, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x2, x3), x4), x5)),
))
@test MBF.fused_direct(expr_in) == expr_out
end
#! format: on