diff --git a/base/broadcast.jl b/base/broadcast.jl
index 5422f5edd3e6f..e45e944f3aa25 100644
--- a/base/broadcast.jl
+++ b/base/broadcast.jl
@@ -464,7 +464,8 @@ julia> Broadcast.combine_axes(1, 1, 1)
 ()
 ```
 """
-@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
+@inline combine_axes(A, B, C...) = broadcast_shape(axes(A), combine_axes(B, C...))
+@inline combine_axes(A, B) = broadcast_shape(axes(A), axes(B))
 combine_axes(A) = axes(A)
 
 # shape (i.e., tuple-of-indices) inputs
@@ -502,7 +503,7 @@ function check_broadcast_shape(shp, Ashp::Tuple)
     _bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
     check_broadcast_shape(tail(shp), tail(Ashp))
 end
-check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
+@inline check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
 # comparing many inputs
 @inline function check_broadcast_axes(shp, A, As...)
     check_broadcast_axes(shp, A)
@@ -911,7 +912,7 @@ _is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray{<:Any,1}) = axe
 _is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray) = axes(dest) == axes(x) # This can be better with other missing dimensions
 
 @inline _is_static_broadcast_28126_args(dest, args::Tuple) = _is_static_broadcast_28126(dest, args[1]) && _is_static_broadcast_28126_args(dest, tail(args))
-_is_static_broadcast_28126_args(dest, args::Tuple{Any}) = _is_static_broadcast_28126(dest, args[1])
+@inline _is_static_broadcast_28126_args(dest, args::Tuple{Any}) = _is_static_broadcast_28126(dest, args[1])
 _is_static_broadcast_28126_args(dest, args::Tuple{}) = true
 
 struct _NonExtruded28126{T}
diff --git a/test/broadcast.jl b/test/broadcast.jl
index a18c5fabca662..483e3ebe01ee5 100644
--- a/test/broadcast.jl
+++ b/test/broadcast.jl
@@ -791,14 +791,15 @@ let
 end
 
 @testset "large fusions vectorize and don't allocate (#28126)" begin
-    u, k1, k2, k3, k4, k5, k6, k7 = (ones(1000) for i in 1:8)
-    function goo(u, k1, k2, k3, k4, k5, k6, k7)
-        @. u = 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7)
+    using InteractiveUtils: code_llvm
+    u, uprev, k1, k2, k3, k4, k5, k6, k7 = (ones(1000) for i in 1:9)
+    function goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)
+        @. u = uprev + 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7)
         nothing
     end
-    @allocated goo(u, k1, k2, k3, k4, k5, k6, k7)
-    @test @allocated(goo(u, k1, k2, k3, k4, k5, k6, k7)) == 0
-    @test occursin("vector.body", sprint(code_llvm, goo, NTuple{8, Vector{Float32}}))
+    @allocated goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)
+    @test @allocated(goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)) == 0
+    @test occursin("vector.body", sprint(code_llvm, goo, NTuple{9, Vector{Float32}}))
 end
 
 # Broadcasted iterable/indexable APIs