From 89e78fdf4a89d1cfcb1d2b9af13b29ebe4757850 Mon Sep 17 00:00:00 2001 From: Martin Holters Date: Mon, 11 Jan 2021 17:31:35 +0100 Subject: [PATCH] Make sparse `hvcat` inferable This also requires making sparse `vcat` and `hcat` inferable in the vararg case which in turn requires a different way to determine the resulting index type, now implemented similar to `promote_eltype`. --- stdlib/SparseArrays/src/sparsematrix.jl | 8 ++++++-- stdlib/SparseArrays/src/sparsevector.jl | 17 ++++++++--------- stdlib/SparseArrays/test/sparse.jl | 4 ++-- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 8673b52f96bcf..5bcdc9fff9ff3 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -3274,6 +3274,10 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :) # Sparse concatenation +promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti +promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} = + promote_type(Ti, promote_idxtype(X...)) + function vcat(X::AbstractSparseMatrixCSC...) num = length(X) mX = Int[ size(x, 1) for x in X ] @@ -3288,7 +3292,7 @@ function vcat(X::AbstractSparseMatrixCSC...) end Tv = promote_eltype(X...) - Ti = promote_eltype(map(x->rowvals(x), X)...) + Ti = promote_idxtype(X...) nnzX = Int[ nnz(x) for x in X ] nnz_res = sum(nnzX) @@ -3340,7 +3344,7 @@ function hcat(X::AbstractSparseMatrixCSC...) n = sum(nX) Tv = promote_eltype(X...) - Ti = promote_eltype(map(x->rowvals(x), X)...) + Ti = promote_idxtype(X...) colptr = Vector{Ti}(undef, n+1) nnzX = Int[ nnz(x) for x in X ] diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl index 777be897ea7df..30747ffda82e7 100644 --- a/stdlib/SparseArrays/src/sparsevector.jl +++ b/stdlib/SparseArrays/src/sparsevector.jl @@ -1091,17 +1091,16 @@ function vcat(Xin::_SparseConcatGroup...) X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin) vcat(X...) end -function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) - nbr = length(rows) # number of block rows - - tmp_rows = Vector{SparseMatrixCSC}(undef, nbr) - k = 0 - @inbounds for i = 1 : nbr - tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...) - k += rows[i] +hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) = + vcat(_hvcat_rows(rows, X...)...) +function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) + if row1 ≤ 0 + throw(ArgumentError("length of block row must be positive, got $row1")) end - vcat(tmp_rows...) + # provide X[1] separately to convince inference that we don't call hcat() without arguments + return (hcat(X[1], X[2 : row1]...), _hvcat_rows(rows, X[row1+1:end]...)...) end +_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = () # make sure UniformScaling objects are converted to sparse matrices for concatenation promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 1558c1fa02b22..35b1d8aec74d4 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -165,7 +165,7 @@ end sz34 = spzeros(3, 4) se77 = sparse(1.0I, 7, 7) @testset "h+v concatenation" begin - @test [se44 sz42 sz41; sz34 se33] == se77 + @test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33] @test length(nonzeros([sp33 0I; 1I 0I])) == 6 end @@ -2201,7 +2201,7 @@ end # Test that concatenations of pairs of sparse matrices yield sparse arrays @test issparse(vcat(spmat, spmat)) @test issparse(hcat(spmat, spmat)) - @test issparse(hvcat((2,), spmat, spmat)) + @test issparse(@inferred(hvcat((2,), spmat, spmat))) @test issparse(cat(spmat, spmat; dims=(1,2))) # Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays @test issparse(vcat(spmat, densemat))