Skip to content

Commit

Permalink
Add dimension parameter to HasShape
Browse files Browse the repository at this point in the history
Needed to determine the shape of indices in a type-stable way.
  • Loading branch information
nalimilan committed Jan 22, 2018
1 parent b72d9eb commit 8da655f
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 20 deletions.
5 changes: 3 additions & 2 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,9 @@ _similar_for(c, T, itr, isz) = similar(c, T)
collect(collection)
Return an `Array` of all items in a collection or iterator. For dictionaries, returns
`Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the `HasShape()`
trait, the result will have the same shape and number of dimensions as the argument.
`Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the
[`HasShape`](@ref IteratorSize) trait, the result will have the same shape
and number of dimensions as the argument.
# Examples
```jldoctest
Expand Down
2 changes: 1 addition & 1 deletion base/asyncmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function verify_ntasks(iterable, ntasks)

if ntasks == 0
chklen = IteratorSize(iterable)
if (chklen == HasLength()) || (chklen == HasShape())
if (chklen isa HasLength) || (chklen isa HasShape)
ntasks = max(1,min(100, length(iterable)))
else
ntasks = 100
Expand Down
11 changes: 6 additions & 5 deletions base/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ end
abstract type IteratorSize end
struct SizeUnknown <: IteratorSize end
struct HasLength <: IteratorSize end
struct HasShape <: IteratorSize end
struct HasShape{N} <: IteratorSize end
struct IsInfinite <: IteratorSize end

"""
Expand All @@ -63,8 +63,9 @@ Given the type of an iterator, return one of the following values:
* `SizeUnknown()` if the length (number of elements) cannot be determined in advance.
* `HasLength()` if there is a fixed, finite length.
* `HasShape()` if there is a known length plus a notion of multidimensional shape (as for an array).
In this case the [`size`](@ref) function is valid for the iterator.
* `HasShape{N}()` if there is a known length plus a notion of multidimensional shape (as for an array).
In this case `N` should give the number of dimensions, and the [`size`](@ref) function is valid
for the iterator.
* `IsInfinite()` if the iterator yields values forever.
The default value (for iterators that do not define this function) is `HasLength()`.
Expand All @@ -75,7 +76,7 @@ result, and algorithms that resize their result incrementally.
```jldoctest
julia> Base.IteratorSize(1:5)
Base.HasShape()
Base.HasShape{1}()
julia> Base.IteratorSize((2,3))
Base.HasLength()
Expand Down Expand Up @@ -110,7 +111,7 @@ Base.HasEltype()
IteratorEltype(x) = IteratorEltype(typeof(x))
IteratorEltype(::Type) = HasEltype() # HasEltype is the default

IteratorSize(::Type{<:AbstractArray}) = HasShape()
IteratorSize(::Type{<:AbstractArray{<:Any,N}}) where {N} = HasShape{N}()
IteratorSize(::Type{Generator{I,F}}) where {I,F} = IteratorSize(I)
length(g::Generator) = length(g.iter)
size(g::Generator) = size(g.iter)
Expand Down
8 changes: 6 additions & 2 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -705,11 +705,15 @@ julia> collect(Iterators.product(1:2,3:5))
"""
product(iters...) = ProductIterator(iters)

IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape()
IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape{0}()
IteratorSize(::Type{ProductIterator{T}}) where {T<:Tuple} =
prod_iteratorsize( IteratorSize(tuple_type_head(T)), IteratorSize(ProductIterator{tuple_type_tail(T)}) )

prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
prod_iteratorsize(::HasLength, ::HasLength) = HasShape{2}()
prod_iteratorsize(::HasLength, ::HasShape{N}) where {N} = HasShape{N+1}()
prod_iteratorsize(::HasShape{N}, ::HasLength) where {N} = HasShape{N+1}()
prod_iteratorsize(::HasShape{M}, ::HasShape{N}) where {M,N} = HasShape{M+N}()

# products can have an infinite iterator
prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite()
prod_iteratorsize(a, ::IsInfinite) = IsInfinite()
Expand Down
2 changes: 1 addition & 1 deletion base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ module IteratorsMD
eltype(R::CartesianIndices) = eltype(typeof(R))
eltype(::Type{CartesianIndices{N}}) where {N} = CartesianIndex{N}
eltype(::Type{CartesianIndices{N,TT}}) where {N,TT} = CartesianIndex{N}
IteratorSize(::Type{<:CartesianIndices}) = Base.HasShape()
IteratorSize(::Type{<:CartesianIndices{N}}) where {N} = Base.HasShape{N}()

@inline function start(iter::CartesianIndices)
iterfirst, iterlast = first(iter), last(iter)
Expand Down
2 changes: 1 addition & 1 deletion base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ ndims(x::Number) = 0
ndims(::Type{<:Number}) = 0
length(x::Number) = 1
endof(x::Number) = 1
IteratorSize(::Type{<:Number}) = HasShape()
IteratorSize(::Type{<:Number}) = HasShape{0}()
keys(::Number) = OneTo(1)

getindex(x::Number) = x
Expand Down
4 changes: 2 additions & 2 deletions doc/src/manual/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ to generically build upon those behaviors.
| `next(iter, state)` |   | Returns the current item and the next state |
| `done(iter, state)` |   | Tests if there are any items remaining |
| **Important optional methods** | **Default definition** | **Brief description** |
| `IteratorSize(IterType)` | `HasLength()` | One of `HasLength()`, `HasShape()`, `IsInfinite()`, or `SizeUnknown()` as appropriate |
| `IteratorSize(IterType)` | `HasLength()` | One of `HasLength()`, `HasShape{N}()`, `IsInfinite()`, or `SizeUnknown()` as appropriate |
| `IteratorEltype(IterType)` | `HasEltype()` | Either `EltypeUnknown()` or `HasEltype()` as appropriate |
| `eltype(IterType)` | `Any` | The type of the items returned by `next()` |
| `length(iter)` | (*undefined*) | The number of items, if known |
Expand All @@ -22,7 +22,7 @@ to generically build upon those behaviors.
| Value returned by `IteratorSize(IterType)` | Required Methods |
|:------------------------------------------ |:------------------------------------------ |
| `HasLength()` | `length(iter)` |
| `HasShape()` | `length(iter)` and `size(iter, [dim...])` |
| `HasShape{N}()` | `length(iter)` and `size(iter, [dim...])` |
| `IsInfinite()` | (*none*) |
| `SizeUnknown()` | (*none*) |

Expand Down
2 changes: 1 addition & 1 deletion test/generic_map_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function testmap_equivalence(mapf, f, c...)
x1 = mapf(f,c...)
x2 = map(f,c...)

if Base.IteratorSize == Base.HasShape()
if Base.IteratorSize isa Base.HasShape
@test size(x1) == size(x2)
else
@test length(x1) == length(x2)
Expand Down
10 changes: 5 additions & 5 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,11 @@ end
@test Base.IteratorSize(product(1:2, countfrom(1))) == Base.IsInfinite()
@test Base.IteratorSize(product(countfrom(2), countfrom(1))) == Base.IsInfinite()
@test Base.IteratorSize(product(countfrom(1), 1:2)) == Base.IsInfinite()
@test Base.IteratorSize(product(1:2)) == Base.HasShape()
@test Base.IteratorSize(product(1:2, 1:2)) == Base.HasShape()
@test Base.IteratorSize(product(take(1:2, 1), take(1:2, 1))) == Base.HasShape()
@test Base.IteratorSize(product(take(1:2, 2))) == Base.HasShape()
@test Base.IteratorSize(product([1 2; 3 4])) == Base.HasShape()
@test Base.IteratorSize(product(1:2)) == Base.HasShape{1}()
@test Base.IteratorSize(product(1:2, 1:2)) == Base.HasShape{2}()
@test Base.IteratorSize(product(take(1:2, 1), take(1:2, 1))) == Base.HasShape{2}()
@test Base.IteratorSize(product(take(1:2, 2))) == Base.HasShape{1}()
@test Base.IteratorSize(product([1 2; 3 4])) == Base.HasShape{2}()

# IteratorEltype trait business
let f1 = Iterators.filter(i->i>0, 1:10)
Expand Down

0 comments on commit 8da655f

Please sign in to comment.