Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,35 @@ struct ProductIterator{T<:Tuple}
iterators::T
end

"""
turns a tuple of iterators `Tuple{It1,...,Itn}` with Eltypes `T1,...,Tn` into
a `Tuple{T1,...,Tn}` which is the return value of `Product` types
"""
function _iterator_eltype_tuple(::Type{T}) where {T<:Tuple}
return Tuple{ntuple(n -> eltype(fieldtype(T, n)), _counttuple(T))...}
end

"""
Whenever ProductIterator can be a subtype of AbstractArray, it should be. This
is this case
"""
struct ProductArray{T<:Tuple,Eltype,N} <: AbstractArray{Eltype,N}
iterators::T
ProductArray(iterators::T) where {T} = new{T,_iterator_eltype_tuple(T),_counttuple(T)}(iterators)
end

struct GeneralIterators end
struct AllLinearIndexed end
function _all_linear_indexed(::T) where {T<:Tuple}
all(ntuple(
n -> hasmethod(Base.getindex, (fieldtype(T, n), Int)),
Base._counttuple(T)
)) && return AllLinearIndexed()
return GeneralIterators()
end

const Product{T<:Tuple} where T = Union{ProductIterator{T}, ProductArray{T}}

"""
product(iters...)

Expand All @@ -1021,10 +1050,12 @@ julia> ans == [(x,y) for x in 1:2, y in 3:5] # collects a generator involving I
true
```
"""
product(iters...) = ProductIterator(iters)
product(iters...) = _product(_all_linear_indexed(iters), iters)
_product(::AllLinearIndexed, iters) = ProductArray(iters)
_product(::GeneralIterators, iters) = ProductIterator(iters)

IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape{0}()
IteratorSize(::Type{ProductIterator{T}}) where {T<:Tuple} =
IteratorSize(::Type{PT{Tuple{}}}) where {PT<:Product} = HasShape{0}()
IteratorSize(::Type{PT{T}}) where {T<:Tuple, PT<:Product} =
prod_iteratorsize(ntuple(n -> IteratorSize(fieldtype(T, n)), _counttuple(T)::Int)..., HasShape{0}())

prod_iteratorsize() = HasShape{0}()
Expand All @@ -1042,42 +1073,42 @@ prod_iteratorsize(::IsInfinite, b) = IsInfinite()
prod_iteratorsize(a, b) = SizeUnknown()
prod_iteratorsize(a, b, tail...) = prod_iteratorsize(a, prod_iteratorsize(b, tail...))

size(P::ProductIterator) = _prod_size(P.iterators)
size(P::PT) where {PT<:Product} = _prod_size(P.iterators)
_prod_size(::Tuple{}) = ()
_prod_size(t::Tuple) = (_prod_size1(t[1], IteratorSize(t[1]))..., _prod_size(tail(t))...)
_prod_size1(a, ::HasShape) = size(a)
_prod_size1(a, ::HasLength) = (length(a),)
_prod_size1(a, A) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(a))"))

axes(P::ProductIterator) = _prod_indices(P.iterators)
axes(P::PT) where {PT<:Product} = _prod_indices(P.iterators)
_prod_indices(::Tuple{}) = ()
_prod_indices(t::Tuple) = (_prod_axes1(t[1], IteratorSize(t[1]))..., _prod_indices(tail(t))...)
_prod_axes1(a, ::HasShape) = axes(a)
_prod_axes1(a, ::HasLength) = (OneTo(length(a)),)
_prod_axes1(a, A) =
throw(ArgumentError("Cannot compute indices for object of type $(typeof(a))"))

ndims(p::ProductIterator) = length(axes(p))
length(P::ProductIterator) = reduce(checked_mul, size(P); init=1)
ndims(p::PT) where {PT<:Product} = length(axes(p))
length(P::PT) where {PT<:Product} = reduce(checked_mul, size(P); init=1)

IteratorEltype(::Type{ProductIterator{Tuple{}}}) = HasEltype()
IteratorEltype(::Type{ProductIterator{Tuple{I}}}) where {I} = IteratorEltype(I)
IteratorEltype(::Type{PT{Tuple{}}}) where {PT<:Product} = HasEltype()
IteratorEltype(::Type{PT{Tuple{I}}}) where {I,PT<:Product} = IteratorEltype(I)

function IteratorEltype(::Type{ProductIterator{T}}) where {T<:Tuple}
function IteratorEltype(::Type{PT{T}}) where {T<:Tuple,PT<:Product}
E = ntuple(n -> IteratorEltype(fieldtype(T, n)), _counttuple(T)::Int)
any(I -> I == EltypeUnknown(), E) && return EltypeUnknown()
return E[end]
end

eltype(::Type{ProductIterator{I}}) where {I} = _prod_eltype(I)
eltype(::Type{PT{I}}) where {I,PT<:Product} = _prod_eltype(I)
_prod_eltype(::Type{Tuple{}}) = Tuple{}
_prod_eltype(::Type{I}) where {I<:Tuple} = TupleOrBottom(ntuple(n -> eltype(fieldtype(I, n)), _counttuple(I)::Int)...)

iterate(::ProductIterator{Tuple{}}) = (), true
iterate(::ProductIterator{Tuple{}}, state) = nothing
iterate(::Product{Tuple{}}) = (), true
iterate(::Product{Tuple{}}, state) = nothing

@inline isdone(P::ProductIterator) = any(isdone, P.iterators)
@inline isdone(P::PT) where {PT<:Product} = any(isdone, P.iterators)
@inline function _pisdone(iters, states)
iter1 = first(iters)
done1 = isdone(iter1, first(states)[2]) # check step
Expand All @@ -1086,8 +1117,8 @@ iterate(::ProductIterator{Tuple{}}, state) = nothing
done1 === true || return done1 # false or missing
return _pisdone(tail(iters), tail(states)) # check tail
end
@inline isdone(::ProductIterator{Tuple{}}, states) = true
@inline isdone(P::ProductIterator, states) = _pisdone(P.iterators, states)
@inline isdone(::Product{Tuple{}}, states) = true
@inline isdone(P::PT, states) where {PT<:Product} = _pisdone(P.iterators, states)

@inline _piterate() = ()
@inline function _piterate(iter1, rest...)
Expand All @@ -1097,7 +1128,7 @@ end
restnext === nothing && return nothing
return (next, restnext...)
end
@inline function iterate(P::ProductIterator)
@inline function iterate(P::PT) where {PT<:Product}
isdone(P) === true && return nothing
next = _piterate(P.iterators...)
next === nothing && return nothing
Expand All @@ -1118,16 +1149,19 @@ end
end
return (next, restnext...)
end
@inline function iterate(P::ProductIterator, states)
@inline function iterate(P::PT, states) where {PT<:Product}
isdone(P, states) === true && return nothing
next = _piterate1(P.iterators, states)
next === nothing && return nothing
return (Base.map(x -> x[1], next), next)
end

reverse(p::ProductIterator) = ProductIterator(Base.map(reverse, p.iterators))
last(p::ProductIterator) = Base.map(last, p.iterators)
intersect(a::ProductIterator, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators))
reverse(p::PT) where {PT<:Product} = PT(Base.map(reverse, p.iterators))
last(p::PT) = Base.map(last, p.iterators)
intersect(a::PT, b::PT) where {PT<:Product} = ProductIterator(intersect.(a.iterators, b.iterators))
intersect(a::ProductIterator, b::ProductArray) = ProductIterator(intersect.(a.iterators, b.iterators))
intersect(a::ProductArray, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators))
getindex(p::PT, inds...) where {PT<:Product} = map(getindex, p.iterators, inds)

# flatten an iterator of iterators

Expand Down