Skip to content

Commit

Permalink
Implement accumulate and friends for arbitrary iterators (JuliaLang…
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Mar 26, 2020
1 parent 4c45a91 commit 5ecc17f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ New library features

* `isapprox` (or ``) now has a one-argument "curried" method `isapprox(x)` which returns a function, like `isequal` (or `==`)` ([#32305]).
* `Ref{NTuple{N,T}}` can be passed to `Ptr{T}`/`Ref{T}` `ccall` signatures ([#34199])
* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]).
* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]) and arbitrary iterators ([#34656]).
* In `splice!` with no replacement, values to be removed can now be specified with an
arbitrary iterable (instead of a `UnitRange`) ([#34524]).

Expand Down
27 changes: 23 additions & 4 deletions base/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ function cumsum(A::AbstractArray{T}; dims::Integer) where T
end

"""
cumsum(itr::Union{AbstractVector,Tuple})
cumsum(itr)
Cumulative sum an iterator. See also [`cumsum!`](@ref)
to use a preallocated output array, both for performance and to control the precision of the
output (e.g. to avoid overflow).
!!! compat "Julia 1.5"
`cumsum` on a tuple requires at least Julia 1.5.
`cumsum` on a non-array iterator requires at least Julia 1.5.
# Examples
```jldoctest
Expand All @@ -117,6 +117,12 @@ julia> cumsum([fill(1, 2) for i in 1:3])
julia> cumsum((1, 1, 1))
(1, 2, 3)
julia> cumsum(x^2 for x in 1:3)
3-element Array{Int64,1}:
1
5
14
```
"""
cumsum(x::AbstractVector) = cumsum(x, dims=1)
Expand Down Expand Up @@ -170,14 +176,14 @@ function cumprod(A::AbstractArray; dims::Integer)
end

"""
cumprod(itr::Union{AbstractVector,Tuple})
cumprod(itr)
Cumulative product of an iterator. See also
[`cumprod!`](@ref) to use a preallocated output array, both for performance and
to control the precision of the output (e.g. to avoid overflow).
!!! compat "Julia 1.5"
`cumprod` on a tuple requires at least Julia 1.5.
`cumprod` on a non-array iterator requires at least Julia 1.5.
# Examples
```jldoctest
Expand All @@ -195,6 +201,12 @@ julia> cumprod([fill(1//3, 2, 2) for i in 1:3])
julia> cumprod((1, 2, 1))
(1, 2, 2)
julia> cumprod(x^2 for x in 1:3)
3-element Array{Int64,1}:
1
4
36
```
"""
cumprod(x::AbstractVector) = cumprod(x, dims=1)
Expand All @@ -210,6 +222,9 @@ also [`accumulate!`](@ref) to use a preallocated output array, both for performa
to control the precision of the output (e.g. to avoid overflow). For common operations
there are specialized variants of `accumulate`, see: [`cumsum`](@ref), [`cumprod`](@ref)
!!! compat "Julia 1.5"
`accumulate` on a non-array iterator requires at least Julia 1.5.
# Examples
```jldoctest
julia> accumulate(+, [1,2,3])
Expand Down Expand Up @@ -250,6 +265,10 @@ julia> accumulate(+, fill(1, 3, 3), dims=2)
```
"""
function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...)
if dims === nothing && !(A isa AbstractVector)
# This branch takes care of the cases not handled by `_accumulate!`.
return collect(Iterators.accumulate(op, A; kw...))
end
nt = kw.data
if nt isa NamedTuple{()}
out = similar(A, promote_op(op, eltype(A), eltype(A)))
Expand Down
25 changes: 18 additions & 7 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,40 +443,51 @@ reverse(f::Filter) = Filter(f.flt, reverse(f.itr))

# Accumulate -- partial reductions of a function over an iterator

struct Accumulate{F,I}
struct Accumulate{F,I,T}
f::F
itr::I
init::T
end

"""
Iterators.accumulate(f, itr)
Iterators.accumulate(f, itr; [init])
Given a 2-argument function `f` and an iterator `itr`, return a new
iterator that successively applies `f` to the previous value and the
next element of `itr`.
This is effectively a lazy version of [`Base.accumulate`](@ref).
!!! compat "Julia 1.5"
Keyword argument `init` is added in Julia 1.5.
# Examples
```jldoctest
julia> f = Iterators.accumulate(+, [1,2,3,4])
Base.Iterators.Accumulate{typeof(+),Array{Int64,1}}(+, [1, 2, 3, 4])
julia> f = Iterators.accumulate(+, [1,2,3,4]);
julia> foreach(println, f)
1
3
6
10
julia> f = Iterators.accumulate(+, [1,2,3]; init = 100);
julia> foreach(println, f)
101
103
106
```
"""
accumulate(f, itr) = Accumulate(f, itr)
accumulate(f, itr; init = Base._InitialValue()) = Accumulate(f, itr, init)

function iterate(itr::Accumulate)
state = iterate(itr.itr)
if state === nothing
return nothing
end
return (state[1], state)
val = Base.BottomRF(itr.f)(itr.init, state[1])
return (val, (val, state[2]))
end

function iterate(itr::Accumulate, state)
Expand All @@ -491,7 +502,7 @@ end
length(itr::Accumulate) = length(itr.itr)
size(itr::Accumulate) = size(itr.itr)

IteratorSize(::Type{Accumulate{F,I}}) where {F,I} = IteratorSize(I)
IteratorSize(::Type{<:Accumulate{F,I}}) where {F,I} = IteratorSize(I)
IteratorEltype(::Type{<:Accumulate}) = EltypeUnknown()

# Rest -- iterate starting at the given state
Expand Down
9 changes: 9 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,17 @@ end
@test collect(Iterators.accumulate(+, [1,2])) == [1,3]
@test collect(Iterators.accumulate(+, [1,2,3])) == [1,3,6]
@test collect(Iterators.accumulate(=>, [:a,:b,:c])) == [:a, :a => :b, (:a => :b) => :c]
@test collect(Iterators.accumulate(+, (x for x in [true])))::Vector{Int} == [1]
@test collect(Iterators.accumulate(+, (x for x in [true, true, false])))::Vector{Int} == [1, 2, 2]
@test collect(Iterators.accumulate(+, (x for x in [true]), init=10.0))::Vector{Float64} == [11.0]
@test length(Iterators.accumulate(+, [10,20,30])) == 3
@test size(Iterators.accumulate(max, rand(2,3))) == (2,3)
@test Base.IteratorSize(Iterators.accumulate(max, rand(2,3))) === Base.IteratorSize(rand(2,3))
@test Base.IteratorEltype(Iterators.accumulate(*, ())) isa Base.EltypeUnknown
end

@testset "Base.accumulate" begin
@test cumsum(x^2 for x in 1:3) == [1, 5, 14]
@test cumprod(x + 1 for x in 1:3) == [2, 6, 24]
@test accumulate(+, (x^2 for x in 1:3); init=100) == [101, 105, 114]
end

0 comments on commit 5ecc17f

Please sign in to comment.