Skip to content

Commit

Permalink
Make sparse hvcat inferable
Browse files Browse the repository at this point in the history
This also requires making sparse `vcat` and `hcat` inferable in the
vararg case which in turn requires a different way to determine the
resulting index type, now implemented similar to `promote_eltype`.
  • Loading branch information
martinholters authored and johanmon committed Jul 5, 2021
1 parent 84bfa12 commit 89e78fd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
8 changes: 6 additions & 2 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3274,6 +3274,10 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :)

# Sparse concatenation

promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} =
promote_type(Ti, promote_idxtype(X...))

function vcat(X::AbstractSparseMatrixCSC...)
num = length(X)
mX = Int[ size(x, 1) for x in X ]
Expand All @@ -3288,7 +3292,7 @@ function vcat(X::AbstractSparseMatrixCSC...)
end

Tv = promote_eltype(X...)
Ti = promote_eltype(map(x->rowvals(x), X)...)
Ti = promote_idxtype(X...)

nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
Expand Down Expand Up @@ -3340,7 +3344,7 @@ function hcat(X::AbstractSparseMatrixCSC...)
n = sum(nX)

Tv = promote_eltype(X...)
Ti = promote_eltype(map(x->rowvals(x), X)...)
Ti = promote_idxtype(X...)

colptr = Vector{Ti}(undef, n+1)
nnzX = Int[ nnz(x) for x in X ]
Expand Down
17 changes: 8 additions & 9 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1091,17 +1091,16 @@ function vcat(Xin::_SparseConcatGroup...)
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
vcat(X...)
end
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
nbr = length(rows) # number of block rows

tmp_rows = Vector{SparseMatrixCSC}(undef, nbr)
k = 0
@inbounds for i = 1 : nbr
tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...)
k += rows[i]
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
vcat(_hvcat_rows(rows, X...)...)
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if row1 0
throw(ArgumentError("length of block row must be positive, got $row1"))
end
vcat(tmp_rows...)
# provide X[1] separately to convince inference that we don't call hcat() without arguments
return (hcat(X[1], X[2 : row1]...), _hvcat_rows(rows, X[row1+1:end]...)...)
end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
Expand Down
4 changes: 2 additions & 2 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ end
sz34 = spzeros(3, 4)
se77 = sparse(1.0I, 7, 7)
@testset "h+v concatenation" begin
@test [se44 sz42 sz41; sz34 se33] == se77
@test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33]
@test length(nonzeros([sp33 0I; 1I 0I])) == 6
end

Expand Down Expand Up @@ -2201,7 +2201,7 @@ end
# Test that concatenations of pairs of sparse matrices yield sparse arrays
@test issparse(vcat(spmat, spmat))
@test issparse(hcat(spmat, spmat))
@test issparse(hvcat((2,), spmat, spmat))
@test issparse(@inferred(hvcat((2,), spmat, spmat)))
@test issparse(cat(spmat, spmat; dims=(1,2)))
# Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays
@test issparse(vcat(spmat, densemat))
Expand Down

0 comments on commit 89e78fd

Please sign in to comment.