Skip to content

Commit

Permalink
Fix find() when called with function and non-StridedArray
Browse files Browse the repository at this point in the history
The previous code expected testfun to be vectorized, which does not match
the behavior of other methods, nor the documentation. Especially visible
when passing a sparse matrix. Always return an Array rather than an object
similar to the input, because it makes no sense to return a sparse vector.
Remove the method matching any object so that an error is raised rather
than returning incorrect results on non-AbstractArray (e.g. Set).
  • Loading branch information
nalimilan committed Sep 12, 2014
1 parent 13da365 commit ae641c9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 3 additions & 3 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ function findnext(testf::Function, A, start::Integer)
end
findfirst(testf::Function, A) = findnext(testf, A, 1)

function find(testf::Function, A::StridedArray)
function find(testf::Function, A::AbstractArray)
# use a dynamic-length array to store the indexes, then copy to a non-padded
# array for the return
tmpI = Array(Int, 0)
Expand All @@ -1063,7 +1063,7 @@ function find(testf::Function, A::StridedArray)
push!(tmpI, i)
end
end
I = similar(A, Int, length(tmpI))
I = Array(Int, length(tmpI))
copy!(I, tmpI)
I
end
Expand All @@ -1082,7 +1082,7 @@ function find(A::StridedArray)
end

find(x::Number) = x == 0 ? Array(Int,0) : [1]
find(testf::Function, x) = find(testf(x))
find(testf::Function, x::Number) = testf(x) == 0 ? Array(Int,0) : [1]

findn(A::AbstractVector) = find(A)

Expand Down
4 changes: 4 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ mfe22 = eye(Float64, 2)
K,J,V = findnz(SparseMatrixCSC(2,1,[1,3],[1,2],[1.0,0.0]))
@test length(K) == length(J) == length(V) == 1

# https://groups.google.com/d/msg/julia-users/Yq4dh8NOWBQ/GU57L90FZ3EJ
A = speye(Bool, 5)
@test find(A) == find(x -> x == true, A) == find(full(A))

# issue #5437
@test nnz(sparse([1,2,3],[1,2,3],[0.0,1.0,2.0])) == 2

Expand Down

0 comments on commit ae641c9

Please sign in to comment.