Skip to content

Commit

Permalink
Generalize squeeze to work with OffsetArrays (JuliaLang#22663)
Browse files Browse the repository at this point in the history
* Generalize squeeze to work with OffsetArrays

* Test squeeze with OffsetArrays
  • Loading branch information
GregPlowman authored and mbauman committed Sep 22, 2017
1 parent 942b843 commit 72eed90
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,18 @@ julia> squeeze(a,3)
function squeeze(A::AbstractArray, dims::Dims)
for i in 1:length(dims)
1 <= dims[i] <= ndims(A) || throw(ArgumentError("squeezed dims must be in range 1:ndims(A)"))
size(A, dims[i]) == 1 || throw(ArgumentError("squeezed dims must all be size 1"))
length(indices(A, dims[i])) == 1 || throw(ArgumentError("squeezed dims must all be size 1"))
for j = 1:i-1
dims[j] == dims[i] && throw(ArgumentError("squeezed dims must be unique"))
end
end
d = ()
for i = 1:ndims(A)
if !in(i, dims)
d = tuple(d..., size(A, i))
d = tuple(d..., indices(A, i))
end
end
reshape(A, d::typeof(_sub(size(A), dims)))
reshape(A, d::typeof(_sub(indices(A), dims)))
end

squeeze(A::AbstractArray, dim::Integer) = squeeze(A, (Int(dim),))
Expand Down
15 changes: 15 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,21 @@ am = map(identity, a)
@test isa(am, OffsetArray)
@test am == a

# squeeze
a0 = rand(1,1,8,8,1)
a = OffsetArray(a0, (-1,2,3,4,5))
@test @inferred(squeeze(a, 1)) == @inferred(squeeze(a, (1,))) == OffsetArray(reshape(a, (1,8,8,1)), (2,3,4,5))
@test @inferred(squeeze(a, 5)) == @inferred(squeeze(a, (5,))) == OffsetArray(reshape(a, (1,1,8,8)), (-1,2,3,4))
@test @inferred(squeeze(a, (1,5))) == squeeze(a, (5,1)) == OffsetArray(reshape(a, (1,8,8)), (2,3,4))
@test @inferred(squeeze(a, (1,2,5))) == squeeze(a, (5,2,1)) == OffsetArray(reshape(a, (8,8)), (3,4))
@test_throws ArgumentError squeeze(a, 0)
@test_throws ArgumentError squeeze(a, (1,1))
@test_throws ArgumentError squeeze(a, (1,2,1))
@test_throws ArgumentError squeeze(a, (1,1,2))
@test_throws ArgumentError squeeze(a, 3)
@test_throws ArgumentError squeeze(a, 4)
@test_throws ArgumentError squeeze(a, 6)

# other functions
v = OffsetArray(v0, (-3,))
@test endof(v) == 1
Expand Down

0 comments on commit 72eed90

Please sign in to comment.