Skip to content

Commit

Permalink
use Integer during broadcast when possible.
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Jan 6, 2024
1 parent ae6af52 commit dd7f1f8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
20 changes: 8 additions & 12 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -571,15 +571,15 @@ an `Int`.
Any remaining indices in `I` beyond the length of the `keep` tuple are truncated. The `keep` and `default`
tuples may be created by `newindexer(argument)`.
"""
Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = CartesianIndex(_newindex(axes(arg), I.I))
Base.@propagate_inbounds newindex(arg, I::Integer) = CartesianIndex(_newindex(axes(arg), (I,)))
Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = to_index(_newindex(axes(arg), I.I))
Base.@propagate_inbounds newindex(arg, I::Integer) = to_index(_newindex(axes(arg), (I,)))
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(length(ax[1]) == 1, ax[1][1], I[1]), _newindex(tail(ax), tail(I))...)
Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple) = ()
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][1], _newindex(tail(ax), ())...)
Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()

# If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`.
@inline newindex(I::CartesianIndex, keep, Idefault) = CartesianIndex(_newindex(I.I, keep, Idefault))
@inline newindex(I::CartesianIndex, keep, Idefault) = to_index(_newindex(I.I, keep, Idefault))
@inline newindex(i::Integer, keep::Tuple, idefault) = ifelse(keep[1], i, idefault[1])
@inline newindex(i::Integer, keep::Tuple{}, idefault) = CartesianIndex(())
@inline _newindex(I, keep, Idefault) =
Expand All @@ -599,18 +599,14 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()
(Base.length(ind1)::Integer != 1, keep...), (first(ind1), Idefault...)
end

@inline function Base.getindex(bc::Broadcasted, I::Union{Integer,CartesianIndex})
@inline function Base.getindex(bc::Broadcasted, Is::Vararg{Union{Integer,CartesianIndex},N}) where {N}
I = to_index(Base.IteratorsMD.flatten(Is))
@boundscheck checkbounds(bc, I)
@inbounds _broadcast_getindex(bc, I)
end
Base.@propagate_inbounds Base.getindex(
bc::Broadcasted,
i1::Union{Integer,CartesianIndex},
i2::Union{Integer,CartesianIndex},
I::Union{Integer,CartesianIndex}...,
) =
bc[CartesianIndex((i1, i2, I...))]
Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())]
to_index(::Tuple{}) = CartesianIndex()
to_index(Is::Tuple{Any}) = Is[1]
to_index(Is::Tuple) = CartesianIndex(Is)

@inline Base.checkbounds(bc::Broadcasted, I::Union{Integer,CartesianIndex}) =
Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,))
Expand Down
19 changes: 16 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ ci(x) = CartesianIndex(x)
@test @inferred(newindex(ci((2,2)), (true, false), (-1,-1))) == ci((2,-1))
@test @inferred(newindex(ci((2,2)), (false, true), (-1,-1))) == ci((-1,2))
@test @inferred(newindex(ci((2,2)), (false, false), (-1,-1))) == ci((-1,-1))
@test @inferred(newindex(ci((2,2)), (true,), (-1,-1))) == ci((2,))
@test @inferred(newindex(ci((2,2)), (true,), (-1,))) == ci((2,))
@test @inferred(newindex(ci((2,2)), (false,), (-1,))) == ci((-1,))
@test @inferred(newindex(ci((2,2)), (true,), (-1,-1))) == 2
@test @inferred(newindex(ci((2,2)), (true,), (-1,))) == 2
@test @inferred(newindex(ci((2,2)), (false,), (-1,))) == -1
@test @inferred(newindex(ci((2,2)), (), ())) == ci(())

end
Expand Down Expand Up @@ -1175,3 +1175,16 @@ import Base.Broadcast: BroadcastStyle, DefaultArrayStyle

f51129(v, x) = (1 .- (v ./ x) .^ 2)
@test @inferred(f51129([13.0], 6.5)) == [-3.0]

@testset "broadcast for `AbstractArray` without `CartesianIndex` support" begin
struct BVec52775 <: AbstractVector{Int}
a::Vector{Int}
end
Base.size(a::BVec52775) = size(a.a)
Base.getindex(a::BVec52775, i::Real) = a.a[i]
Base.getindex(a::BVec52775, i) = error("unsupported index!")
a = BVec52775([1,2,3])
bc = Base.broadcasted(identity, a)
@test bc[1] == bc[CartesianIndex(1)] == bc[1, CartesianIndex()]
@test a .+ [1 2] == a.a .+ [1 2]
end

0 comments on commit dd7f1f8

Please sign in to comment.