Skip to content

Commit

Permalink
Merge pull request JuliaLang#15665 from KristofferC/kc/remove_colptr_…
Browse files Browse the repository at this point in the history
…copy

avoid unnecessary copy of colptr in spset! and spdelete!
  • Loading branch information
tanmaykm committed Apr 1, 2016
2 parents ccc8fc9 + 30d4af2 commit 2c3ef4e
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2283,29 +2283,28 @@ setindex!{Tv,T<:Integer}(A::SparseMatrixCSC{Tv}, x::Number, I::AbstractVector{T}
(0 == x) ? spdelete!(A, I, J) : spset!(A, convert(Tv,x), I, J)

function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector{Ti}, J::AbstractVector{Ti})
!issorted(I) && (@inbounds I = I[sortperm(I)])
!issorted(J) && (@inbounds J = J[sortperm(J)])
!issorted(I) && (I = sort(I))
!issorted(J) && (J = sort(J))

m, n = size(A)
lenI = length(I)
((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))
nnzA = nnz(A) + lenI * length(J)

colptrA = colptr = A.colptr
colptr = A.colptr
rowvalA = rowval = A.rowval
nzvalA = nzval = A.nzval

rowidx = 1
nadd = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(nadd > 0) && (colptrA[col] = colptr[col] + nadd)
(nadd > 0) && (colptr[col] = colptr[col] + nadd)

if col in J
if isempty(rrange) # set new vals only
nincl = lenI
if nadd == 0
colptrA = copy(colptr)
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
end
Expand Down Expand Up @@ -2333,7 +2332,6 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
old_ptr += 1
else
if nadd == 0
colptrA = copy(colptr)
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
end
Expand All @@ -2348,7 +2346,6 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
if old_ptr > old_stop
if new_ptr <= new_stop
if nadd == 0
colptrA = copy(colptr)
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
end
Expand Down Expand Up @@ -2379,11 +2376,10 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
end

if nadd > 0
colptrA[n+1] = rowidx
colptr[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.colptr = colptrA
A.rowval = rowvalA
A.nzval = nzvalA
end
Expand All @@ -2395,19 +2391,19 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
nnzA = nnz(A)
(nnzA == 0) && (return A)

!issorted(I) && (@inbounds I = I[sortperm(I)])
!issorted(J) && (@inbounds J = J[sortperm(J)])
!issorted(I) && (I = sort(I))
!issorted(J) && (J = sort(J))

((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))

colptr = colptrA = A.colptr
colptr = A.colptr
rowval = rowvalA = A.rowval
nzval = nzvalA = A.nzval
rowidx = 1
ndel = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(ndel > 0) && (colptrA[col] = colptr[col] - ndel)
(ndel > 0) && (colptr[col] = colptr[col] - ndel)
if isempty(rrange) || !(col in J)
nincl = length(rrange)
if(ndel > 0) && !isempty(rrange)
Expand All @@ -2419,7 +2415,6 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
for ridx in rrange
if rowval[ridx] in I
if ndel == 0
colptrA = copy(colptr)
rowvalA = copy(rowval)
nzvalA = copy(nzval)
end
Expand All @@ -2436,11 +2431,10 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
end

if ndel > 0
colptrA[n+1] = rowidx
colptr[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.colptr = colptrA
A.rowval = rowvalA
A.nzval = nzvalA
end
Expand Down

0 comments on commit 2c3ef4e

Please sign in to comment.