diff --git a/base/iterators.jl b/base/iterators.jl index 11e94d3384de8..8f97e6e42db38 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1001,6 +1001,35 @@ struct ProductIterator{T<:Tuple} iterators::T end +""" +turns a tuple of iterators `Tuple{It1,...,Itn}` with Eltypes `T1,...,Tn` into +a `Tuple{T1,...,Tn}` which is the return value of `Product` types +""" +function _iterator_eltype_tuple(::Type{T}) where {T<:Tuple} + return Tuple{ntuple(n -> eltype(fieldtype(T, n)), _counttuple(T))...} +end + +""" +Whenever ProductIterator can be a subtype of AbstractArray, it should be. This +is this case +""" +struct ProductArray{T<:Tuple,Eltype,N} <: AbstractArray{Eltype,N} + iterators::T + ProductArray(iterators::T) where {T} = new{T,_iterator_eltype_tuple(T),_counttuple(T)}(iterators) +end + +struct GeneralIterators end +struct AllLinearIndexed end +function _all_linear_indexed(::T) where {T<:Tuple} + all(ntuple( + n -> hasmethod(Base.getindex, (fieldtype(T, n), Int)), + Base._counttuple(T) + )) && return AllLinearIndexed() + return GeneralIterators() +end + +const Product{T<:Tuple} where T = Union{ProductIterator{T}, ProductArray{T}} + """ product(iters...) @@ -1021,10 +1050,12 @@ julia> ans == [(x,y) for x in 1:2, y in 3:5] # collects a generator involving I true ``` """ -product(iters...) = ProductIterator(iters) +product(iters...) = _product(_all_linear_indexed(iters), iters) +_product(::AllLinearIndexed, iters) = ProductArray(iters) +_product(::GeneralIterators, iters) = ProductIterator(iters) -IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape{0}() -IteratorSize(::Type{ProductIterator{T}}) where {T<:Tuple} = +IteratorSize(::Type{PT{Tuple{}}}) where {PT<:Product} = HasShape{0}() +IteratorSize(::Type{PT{T}}) where {T<:Tuple, PT<:Product} = prod_iteratorsize(ntuple(n -> IteratorSize(fieldtype(T, n)), _counttuple(T)::Int)..., HasShape{0}()) prod_iteratorsize() = HasShape{0}() @@ -1042,7 +1073,7 @@ prod_iteratorsize(::IsInfinite, b) = IsInfinite() prod_iteratorsize(a, b) = SizeUnknown() prod_iteratorsize(a, b, tail...) = prod_iteratorsize(a, prod_iteratorsize(b, tail...)) -size(P::ProductIterator) = _prod_size(P.iterators) +size(P::PT) where {PT<:Product} = _prod_size(P.iterators) _prod_size(::Tuple{}) = () _prod_size(t::Tuple) = (_prod_size1(t[1], IteratorSize(t[1]))..., _prod_size(tail(t))...) _prod_size1(a, ::HasShape) = size(a) @@ -1050,7 +1081,7 @@ _prod_size1(a, ::HasLength) = (length(a),) _prod_size1(a, A) = throw(ArgumentError("Cannot compute size for object of type $(typeof(a))")) -axes(P::ProductIterator) = _prod_indices(P.iterators) +axes(P::PT) where {PT<:Product} = _prod_indices(P.iterators) _prod_indices(::Tuple{}) = () _prod_indices(t::Tuple) = (_prod_axes1(t[1], IteratorSize(t[1]))..., _prod_indices(tail(t))...) _prod_axes1(a, ::HasShape) = axes(a) @@ -1058,26 +1089,26 @@ _prod_axes1(a, ::HasLength) = (OneTo(length(a)),) _prod_axes1(a, A) = throw(ArgumentError("Cannot compute indices for object of type $(typeof(a))")) -ndims(p::ProductIterator) = length(axes(p)) -length(P::ProductIterator) = reduce(checked_mul, size(P); init=1) +ndims(p::PT) where {PT<:Product} = length(axes(p)) +length(P::PT) where {PT<:Product} = reduce(checked_mul, size(P); init=1) -IteratorEltype(::Type{ProductIterator{Tuple{}}}) = HasEltype() -IteratorEltype(::Type{ProductIterator{Tuple{I}}}) where {I} = IteratorEltype(I) +IteratorEltype(::Type{PT{Tuple{}}}) where {PT<:Product} = HasEltype() +IteratorEltype(::Type{PT{Tuple{I}}}) where {I,PT<:Product} = IteratorEltype(I) -function IteratorEltype(::Type{ProductIterator{T}}) where {T<:Tuple} +function IteratorEltype(::Type{PT{T}}) where {T<:Tuple,PT<:Product} E = ntuple(n -> IteratorEltype(fieldtype(T, n)), _counttuple(T)::Int) any(I -> I == EltypeUnknown(), E) && return EltypeUnknown() return E[end] end -eltype(::Type{ProductIterator{I}}) where {I} = _prod_eltype(I) +eltype(::Type{PT{I}}) where {I,PT<:Product} = _prod_eltype(I) _prod_eltype(::Type{Tuple{}}) = Tuple{} _prod_eltype(::Type{I}) where {I<:Tuple} = TupleOrBottom(ntuple(n -> eltype(fieldtype(I, n)), _counttuple(I)::Int)...) -iterate(::ProductIterator{Tuple{}}) = (), true -iterate(::ProductIterator{Tuple{}}, state) = nothing +iterate(::Product{Tuple{}}) = (), true +iterate(::Product{Tuple{}}, state) = nothing -@inline isdone(P::ProductIterator) = any(isdone, P.iterators) +@inline isdone(P::PT) where {PT<:Product} = any(isdone, P.iterators) @inline function _pisdone(iters, states) iter1 = first(iters) done1 = isdone(iter1, first(states)[2]) # check step @@ -1086,8 +1117,8 @@ iterate(::ProductIterator{Tuple{}}, state) = nothing done1 === true || return done1 # false or missing return _pisdone(tail(iters), tail(states)) # check tail end -@inline isdone(::ProductIterator{Tuple{}}, states) = true -@inline isdone(P::ProductIterator, states) = _pisdone(P.iterators, states) +@inline isdone(::Product{Tuple{}}, states) = true +@inline isdone(P::PT, states) where {PT<:Product} = _pisdone(P.iterators, states) @inline _piterate() = () @inline function _piterate(iter1, rest...) @@ -1097,7 +1128,7 @@ end restnext === nothing && return nothing return (next, restnext...) end -@inline function iterate(P::ProductIterator) +@inline function iterate(P::PT) where {PT<:Product} isdone(P) === true && return nothing next = _piterate(P.iterators...) next === nothing && return nothing @@ -1118,16 +1149,19 @@ end end return (next, restnext...) end -@inline function iterate(P::ProductIterator, states) +@inline function iterate(P::PT, states) where {PT<:Product} isdone(P, states) === true && return nothing next = _piterate1(P.iterators, states) next === nothing && return nothing return (Base.map(x -> x[1], next), next) end -reverse(p::ProductIterator) = ProductIterator(Base.map(reverse, p.iterators)) -last(p::ProductIterator) = Base.map(last, p.iterators) -intersect(a::ProductIterator, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators)) +reverse(p::PT) where {PT<:Product} = PT(Base.map(reverse, p.iterators)) +last(p::PT) = Base.map(last, p.iterators) +intersect(a::PT, b::PT) where {PT<:Product} = ProductIterator(intersect.(a.iterators, b.iterators)) +intersect(a::ProductIterator, b::ProductArray) = ProductIterator(intersect.(a.iterators, b.iterators)) +intersect(a::ProductArray, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators)) +getindex(p::PT, inds...) where {PT<:Product} = map(getindex, p.iterators, inds) # flatten an iterator of iterators