Skip to content

Commit

Permalink
HasLength trait for some Flatten iterators (JuliaLang#22691)
Browse files Browse the repository at this point in the history
  • Loading branch information
mschauer authored and rfourquet committed Aug 3, 2017
1 parent cd56e07 commit 5e32423
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
19 changes: 18 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,28 @@ julia> collect(Iterators.flatten((1:2, 8:9)))
flatten(itr) = Flatten(itr)

eltype(::Type{Flatten{I}}) where {I} = eltype(eltype(I))
iteratorsize(::Type{Flatten{I}}) where {I} = SizeUnknown()
iteratoreltype(::Type{Flatten{I}}) where {I} = _flatteneltype(I, iteratoreltype(I))
_flatteneltype(I, ::HasEltype) = iteratoreltype(eltype(I))
_flatteneltype(I, et) = EltypeUnknown()

flatten_iteratorsize(::Union{HasShape, HasLength}, b::Type{<:Tuple}) = isleaftype(b) ? HasLength() : SizeUnknown()
flatten_iteratorsize(::Union{HasShape, HasLength}, b::Type{<:Number}) = HasLength()
flatten_iteratorsize(a, b) = SizeUnknown()

iteratorsize(::Type{Flatten{I}}) where {I} = flatten_iteratorsize(iteratorsize(I), eltype(I))

function flatten_length(f, ::Type{T}) where {T<:Tuple}
if !isleaftype(T)
throw(ArgumentError(
"Cannot compute length of a tuple-type which is not a leaf-type"))
end
fieldcount(T)*length(f.it)
end
flatten_length(f, ::Type{<:Number}) = length(f.it)
flatten_length(f, T) = throw(ArgumentError(
"Iterates of the argument to Flatten are not known to have constant length"))
length(f::Flatten{I}) where {I} = flatten_length(f, eltype(I))

function start(f::Flatten)
local inner, s2
s = start(f.it)
Expand Down
4 changes: 4 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,11 @@ end
@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,6,7,8,9]
@test collect(flatten(Any[2:1])) == Any[]
@test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8
@test length(flatten(zip(1:3, 4:6))) == 6
@test length(flatten(1:6)) == 6
@test_throws ArgumentError collect(flatten(Any[]))
@test_throws ArgumentError length(flatten(NTuple[(1,), ()])) # #16680
@test_throws ArgumentError length(flatten([[1], [1]]))

@test Base.iteratoreltype(Base.Flatten((i for i=1:2) for j=1:1)) == Base.EltypeUnknown()

Expand Down

0 comments on commit 5e32423

Please sign in to comment.