Skip to content

Commit

Permalink
Ensure that we're working with the right integer types consistently i…
Browse files Browse the repository at this point in the history
…n the sparse matrix constructor with triplets. (JuliaLang#26566)
  • Loading branch information
haampie authored and KristofferC committed Mar 22, 2018
1 parent 0f7cd03 commit f2b42c9
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,19 +613,19 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
csccolptr::Vector{Ti}, cscrowval::Vector{Ti}, cscnzval::Vector{Tv}) where {Tv,Ti<:Integer}

# Compute the CSR form's row counts and store them shifted forward by one in csrrowptr
fill!(csrrowptr, 0)
fill!(csrrowptr, Ti(0))
coolen = length(I)
@inbounds for k in 1:coolen
Ik = I[k]
if 1 > Ik || m < Ik
throw(ArgumentError("row indices I[k] must satisfy 1 <= I[k] <= m"))
end
csrrowptr[Ik+1] += 1
csrrowptr[Ik+1] += Ti(1)
end

# Compute the CSR form's rowptrs and store them shifted forward by one in csrrowptr
countsum = 1
csrrowptr[1] = 1
countsum = Ti(1)
csrrowptr[1] = Ti(1)
@inbounds for i in 2:(m+1)
overwritten = csrrowptr[i]
csrrowptr[i] = countsum
Expand All @@ -636,11 +636,11 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
# Tracking write positions in csrrowptr corrects the row pointers
@inbounds for k in 1:coolen
Ik, Jk = I[k], J[k]
if 1 > Jk || n < Jk
if Ti(1) > Jk || Ti(n) < Jk
throw(ArgumentError("column indices J[k] must satisfy 1 <= J[k] <= n"))
end
csrk = csrrowptr[Ik+1]
csrrowptr[Ik+1] = csrk+1
csrrowptr[Ik+1] = csrk + Ti(1)
csrcolval[csrk] = Jk
csrnzval[csrk] = V[k]
end
Expand All @@ -652,23 +652,23 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
#
# Minimizing extraneous communication and nonlocality of reference, primarily by using
# only a single auxiliary array in this step, is the key to this method's performance.
fill!(csccolptr, 0)
fill!(klasttouch, 0)
writek = 1
newcsrrowptri = 1
origcsrrowptri = 1
fill!(csccolptr, Ti(0))
fill!(klasttouch, Ti(0))
writek = Ti(1)
newcsrrowptri = Ti(1)
origcsrrowptri = Ti(1)
origcsrrowptrip1 = csrrowptr[2]
@inbounds for i in 1:m
for readk in origcsrrowptri:(origcsrrowptrip1-1)
for readk in origcsrrowptri:(origcsrrowptrip1-Ti(1))
j = csrcolval[readk]
if klasttouch[j] < newcsrrowptri
klasttouch[j] = writek
if writek != readk
csrcolval[writek] = j
csrnzval[writek] = csrnzval[readk]
end
writek += 1
csccolptr[j+1] += 1
writek += Ti(1)
csccolptr[j+1] += Ti(1)
else
klt = klasttouch[j]
csrnzval[klt] = combine(csrnzval[klt], csrnzval[readk])
Expand All @@ -681,27 +681,27 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
end

# Compute the CSC form's colptrs and store them shifted forward by one in csccolptr
countsum = 1
csccolptr[1] = 1
countsum = Ti(1)
csccolptr[1] = Ti(1)
@inbounds for j in 2:(n+1)
overwritten = csccolptr[j]
csccolptr[j] = countsum
countsum += overwritten
end

# Now knowing the CSC form's entry count, resize cscrowval and cscnzval if necessary
cscnnz = countsum - 1
cscnnz = countsum - Ti(1)
length(cscrowval) < cscnnz && resize!(cscrowval, cscnnz)
length(cscnzval) < cscnnz && resize!(cscnzval, cscnnz)

# Finally counting-sort the row and nonzero values from the CSR form into cscrowval and
# cscnzval. Tracking write positions in csccolptr corrects the column pointers.
@inbounds for i in 1:m
for csrk in csrrowptr[i]:(csrrowptr[i+1]-1)
for csrk in csrrowptr[i]:(csrrowptr[i+1]-Ti(1))
j = csrcolval[csrk]
x = csrnzval[csrk]
csck = csccolptr[j+1]
csccolptr[j+1] = csck+1
csccolptr[j+1] = csck + Ti(1)
cscrowval[csck] = i
cscnzval[csck] = x
end
Expand Down

0 comments on commit f2b42c9

Please sign in to comment.