Skip to content

Commit

Permalink
Merge pull request JuliaLang#34770 from JuliaLang/teh/oa92
Browse files Browse the repository at this point in the history
Convert range type in `reduced_index`
  • Loading branch information
timholy committed Feb 21, 2020
2 parents dc46ddd + 55dab67 commit 3f0b8c9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 3 additions & 3 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# for reductions that expand 0 dims to 1
reduced_index(i::OneTo) = OneTo(1)
reduced_index(i::Union{Slice, IdentityUnitRange}) = first(i):first(i)
reduced_index(i::Union{Slice, IdentityUnitRange}) = oftype(i, first(i):first(i))
reduced_index(i::AbstractUnitRange) =
throw(ArgumentError(
"""
Expand All @@ -20,7 +20,7 @@ reduced_indices0(a::AbstractArray, region) = reduced_indices0(axes(a), region)
function reduced_indices(inds::Indices{N}, d::Int) where N
d < 1 && throw(ArgumentError("dimension must be ≥ 1, got $d"))
if d == 1
return (reduced_index(inds[1]), tail(inds)...)
return (reduced_index(inds[1]), tail(inds)...)::typeof(inds)
elseif 1 < d <= N
return tuple(inds[1:d-1]..., oftype(inds[d], reduced_index(inds[d])), inds[d+1:N]...)::typeof(inds)
else
Expand All @@ -34,7 +34,7 @@ function reduced_indices0(inds::Indices{N}, d::Int) where N
ind = inds[d]
rd = isempty(ind) ? ind : reduced_index(inds[d])
if d == 1
return (rd, tail(inds)...)
return (rd, tail(inds)...)::typeof(inds)
else
return tuple(inds[1:d-1]..., oftype(inds[d], rd), inds[d+1:N]...)::typeof(inds)
end
Expand Down
8 changes: 8 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,14 @@ I = findall(!iszero, z)
@test std(A_3_3, dims=2) == OffsetArray(reshape([3,3,3], (3,1)), A_3_3.offsets)
@test sum(OffsetArray(fill(1,3000), -1000)) == 3000

# https://github.com/JuliaArrays/OffsetArrays.jl/issues/92
A92 = OffsetArray(reshape(1:27, 3, 3, 3), -2, -2, -2)
B92 = view(A92, :, :, -1:0)
@test axes(B92) == (-1:1, -1:1, 1:2)
@test sum(B92, dims=(2,3)) == OffsetArray(reshape([51,57,63], Val(3)), -2, -2, 0)
B92 = view(A92, :, :, Base.IdentityUnitRange(-1:0))
@test sum(B92, dims=(2,3)) == OffsetArray(reshape([51,57,63], Val(3)), -2, -2, -2)

@test norm(v) norm(parent(v))
@test norm(A) norm(parent(A))
@test dot(v, v) dot(v0, v0)
Expand Down

0 comments on commit 3f0b8c9

Please sign in to comment.