diff --git a/base/broadcast.jl b/base/broadcast.jl index d284cc52081ae..b6280745be288 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -464,7 +464,8 @@ julia> Broadcast.combine_axes(1, 1, 1) () ``` """ -@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...)) +@inline combine_axes(A, B, C...) = broadcast_shape(axes(A), combine_axes(B, C...)) +@inline combine_axes(A, B) = broadcast_shape(axes(A), axes(B)) combine_axes(A) = axes(A) # shape (i.e., tuple-of-indices) inputs @@ -502,7 +503,7 @@ function check_broadcast_shape(shp, Ashp::Tuple) _bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination")) check_broadcast_shape(tail(shp), tail(Ashp)) end -check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A)) +@inline check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A)) # comparing many inputs @inline function check_broadcast_axes(shp, A, As...) check_broadcast_axes(shp, A) @@ -864,13 +865,14 @@ broadcast_unalias(::Nothing, src) = src # Preprocessing a `Broadcasted` does two things: # * unaliases any arguments from `dest` -# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices -@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes) -preprocess(dest, x) = extrude(broadcast_unalias(dest, x)) +# * calls `f` on the arguments (typically `extrude`, which pre-computes the broadcasted indices where advantageous) +@inline preprocess(dest, bc) = preprocess(extrude, dest, bc) +@inline preprocess(f, dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(f, dest, bc.args), bc.axes) +preprocess(f, dest, x) = f(broadcast_unalias(dest, x)) -@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...) -preprocess_args(dest, args::Tuple{Any}) = (preprocess(dest, args[1]),) -preprocess_args(dest, args::Tuple{}) = () +@inline preprocess_args(f, dest, args::Tuple) = (preprocess(f, dest, args[1]), preprocess_args(f, dest, tail(args))...) +@inline preprocess_args(f, dest, args::Tuple{Any}) = (preprocess(f, dest, args[1]),) +preprocess_args(f, dest, args::Tuple{}) = () # Specialize this method if all you want to do is specialize on typeof(dest) @inline function copyto!(dest::AbstractArray, bc::Broadcasted{Nothing}) @@ -882,13 +884,48 @@ preprocess_args(dest, args::Tuple{}) = () return copyto!(dest, A) end end - bc′ = preprocess(dest, bc) - @simd for I in eachindex(bc′) - @inbounds dest[I] = bc′[I] + # Ugly performance hack around issue #28126: determine if all arguments to the + # broadcast are sized such that the broadcasting core can statically determine + # whether a given dimension is "extruded" or not. If so, we don't need to check + # any array sizes within the inner loop. Ideally this really should be something + # that Julia and/or LLVM could figure out and eliminate... and indeed they can + # for limited numbers of arguments. + if _is_static_broadcast_28126(dest, bc) + bcs′ = preprocess(_nonextrude_28126, dest, bc) + @simd for I in eachindex(bcs′) + @inbounds dest[I] = bcs′[I] + end + else + bc′ = preprocess(extrude, dest, bc) + @simd for I in eachindex(bc′) + @inbounds dest[I] = bc′[I] + end end return dest end +@inline _is_static_broadcast_28126(dest, bc::Broadcasted{Style}) where {Style} = _is_static_broadcast_28126_args(dest, bc.args) +_is_static_broadcast_28126(dest, x) = false +_is_static_broadcast_28126(dest, x::Union{Ref, Tuple, Type, Number, AbstractArray{<:Any,0}}) = true +_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray{<:Any,0}) = true +_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray{<:Any,1}) = axes(dest, 1) == axes(x, 1) +_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray) = axes(dest) == axes(x) # This can be better with other missing dimensions + +@inline _is_static_broadcast_28126_args(dest, args::Tuple) = _is_static_broadcast_28126(dest, args[1]) && _is_static_broadcast_28126_args(dest, tail(args)) +@inline _is_static_broadcast_28126_args(dest, args::Tuple{Any}) = _is_static_broadcast_28126(dest, args[1]) +_is_static_broadcast_28126_args(dest, args::Tuple{}) = true + +struct _NonExtruded28126{T} + x::T +end +@inline axes(b::_NonExtruded28126) = axes(b.x) +Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126, i) = _broadcast_getindex(b, i) +Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126{<:AbstractArray{<:Any,0}}, i) = b.x[] +Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126{<:AbstractVector}, i) = b.x[i[1]] +Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126{<:AbstractArray}, i) = b.x[i] +_nonextrude_28126(x::AbstractArray) = _NonExtruded28126(x) +_nonextrude_28126(x) = x + # Performance optimization: for BitArray outputs, we cache the result # in a "small" Vector{Bool}, and then copy in chunks into the output @inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing}) diff --git a/test/boundscheck_exec.jl b/test/boundscheck_exec.jl index 62a20921bd44e..f86baca82bbef 100644 --- a/test/boundscheck_exec.jl +++ b/test/boundscheck_exec.jl @@ -251,5 +251,13 @@ if bc_opt == bc_default || bc_opt == bc_off @test occursin("vector.body", sprint(code_llvm, g27079, Tuple{Vector{Int}})) end +# Ensure broadcasting can vectorize when bounds checks are off +if bc_opt != bc_on + function goo28126(u, uprev, k1, k2, k3, k4, k5, k6, k7) + @. u = uprev + 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7) + nothing + end + @test occursin("vector.body", sprint(code_llvm, goo28126, NTuple{9, Vector{Float32}})) +end end diff --git a/test/broadcast.jl b/test/broadcast.jl index a8de5518f3724..0430dd2eb4bd6 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -790,6 +790,16 @@ let @test Dict(c .=> d) == Dict("foo" => 1, "bar" => 2) end +@testset "large fusions vectorize and don't allocate (#28126)" begin + u, uprev, k1, k2, k3, k4, k5, k6, k7 = (ones(1000) for i in 1:9) + function goo(u, uprev, k1, k2, k3, k4, k5, k6, k7) + @. u = uprev + 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7) + nothing + end + @allocated goo(u, uprev, k1, k2, k3, k4, k5, k6, k7) + @test @allocated(goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)) == 0 +end + # Broadcasted iterable/indexable APIs let bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))