Skip to content

Commit

Permalink
eliminate the dead iterate branch in _unsafe_(get)setindex!. (Jul…
Browse files Browse the repository at this point in the history
…iaLang#52809)

It's sad that compiler can't do this automatically.
Some benchmark with `setindex!`:
```julia
julia> a = zeros(Int, 100, 100);
julia> @Btime $a[:,:] = $(1:10000);
  1.340 μs (0 allocations: 0 bytes) #master: 3.350 μs (0 allocations: 0 bytes)

julia> @Btime $a[:,:] = $(view(LinearIndices(a), 1:100, 1:100));
  10.000 μs (0 allocations: 0 bytes) #master: 11.000 μs (0 allocations: 0 bytes)
```

BTW optimization for `FastSubArray` introduced in JuliaLang#45371 still work
after this change as the parent array might have their own `copyto!`
optimization.
  • Loading branch information
N5N3 committed Feb 13, 2024
1 parent e8bf9bc commit 75cb2a5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
43 changes: 30 additions & 13 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,28 @@ _maybe_linear_logical_index(::IndexLinear, A, i) = LogicalIndex{Int}(i)
uncolon(::Tuple{}) = Slice(OneTo(1))
uncolon(inds::Tuple) = Slice(inds[1])

"""
_prechecked_iterate(iter[, state])
Internal function used to eliminate the dead branch in `iterate`.
Fallback to `iterate` by default, but optimized for indices type in `Base`.
"""
@propagate_inbounds _prechecked_iterate(iter) = iterate(iter)
@propagate_inbounds _prechecked_iterate(iter, state) = iterate(iter, state)

_prechecked_iterate(iter::AbstractUnitRange, i = first(iter)) = i, convert(eltype(iter), i + step(iter))
_prechecked_iterate(iter::LinearIndices, i = first(iter)) = i, i + 1
_prechecked_iterate(iter::CartesianIndices) = first(iter), first(iter)
function _prechecked_iterate(iter::CartesianIndices, i)
i′ = IteratorsMD.inc(i.I, iter.indices)
return i′, i′
end
_prechecked_iterate(iter::SCartesianIndices2) = first(iter), first(iter)
function _prechecked_iterate(iter::SCartesianIndices2{K}, (;i, j)) where {K}
I = i < K ? SCartesianIndex2{K}(i + 1, j) : SCartesianIndex2{K}(1, j + 1)
return I, I
end

### From abstractarray.jl: Internal multidimensional indexing definitions ###
getindex(x::Union{Number,AbstractChar}, ::CartesianIndex{0}) = x
getindex(t::Tuple, i::CartesianIndex{1}) = getindex(t, i.I[1])
Expand Down Expand Up @@ -910,14 +932,11 @@ function _generate_unsafe_getindex!_body(N::Int)
quote
@inline
D = eachindex(dest)
Dy = iterate(D)
Dy = _prechecked_iterate(D)
@inbounds @nloops $N j d->I[d] begin
# This condition is never hit, but at the moment
# the optimizer is not clever enough to split the union without it
Dy === nothing && return dest
(idx, state) = Dy
(idx, state) = Dy::NTuple{2,Any}
dest[idx] = @ncall $N getindex src j
Dy = iterate(D, state)
Dy = _prechecked_iterate(D, state)
end
return dest
end
Expand Down Expand Up @@ -953,14 +972,12 @@ function _generate_unsafe_setindex!_body(N::Int)
@nexprs $N d->(I_d = unalias(A, I[d]))
idxlens = @ncall $N index_lengths I
@ncall $N setindex_shape_check x′ (d->idxlens[d])
Xy = iterate(x′)
X = eachindex(x′)
Xy = _prechecked_iterate(X)
@inbounds @nloops $N i d->I_d begin
# This is never reached, but serves as an assumption for
# the optimizer that it does not need to emit error paths
Xy === nothing && break
(val, state) = Xy
@ncall $N setindex! A val i
Xy = iterate(x′, state)
(idx, state) = Xy::NTuple{2,Any}
@ncall $N setindex! A x′[idx] i
Xy = _prechecked_iterate(X, state)
end
A
end
Expand Down
18 changes: 18 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1961,3 +1961,21 @@ end
@test zero([[2,2], [3,3,3]]) isa Vector{Vector{Int}}
@test zero([[2,2], [3,3,3]]) == [[0,0], [0, 0, 0]]
end

@testset "`_prechecked_iterate` optimization" begin
function test_prechecked_iterate(iter)
Js = Base._prechecked_iterate(iter)
for I in iter
J, s = Js::NTuple{2,Any}
@test J === I
Js = Base._prechecked_iterate(iter, s)
end
end
test_prechecked_iterate(1:10)
test_prechecked_iterate(Base.OneTo(10))
test_prechecked_iterate(CartesianIndices((3, 3)))
test_prechecked_iterate(CartesianIndices(()))
test_prechecked_iterate(LinearIndices((3, 3)))
test_prechecked_iterate(LinearIndices(()))
test_prechecked_iterate(Base.SCartesianIndices2{3}(1:3))
end

0 comments on commit 75cb2a5

Please sign in to comment.