Skip to content

Commit

Permalink
Conversion methods sparse matrix -> special linalg type (JuliaLang#40988
Browse files Browse the repository at this point in the history
)
  • Loading branch information
carstenbauer committed May 30, 2021
1 parent acdffeb commit 5cb5a87
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using LinearAlgebra

import Base: +, -, *, \, /, &, |, xor, ==, zero
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!,
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded,
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu, matprod

import Base: acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
Expand Down
11 changes: 11 additions & 0 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,17 @@ Array(S::AbstractSparseMatrixCSC) = Matrix(S)

convert(T::Type{<:AbstractSparseMatrixCSC}, m::AbstractMatrix) = m isa T ? m : T(m)

convert(T::Type{<:Diagonal}, m::AbstractSparseMatrixCSC) = m isa T ? m :
isdiag(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as Diagonal"))
convert(T::Type{<:SymTridiagonal}, m::AbstractSparseMatrixCSC) = m isa T ? m :
issymmetric(m) && isbanded(m, -1, 1) ? T(m) : throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
convert(T::Type{<:Tridiagonal}, m::AbstractSparseMatrixCSC) = m isa T ? m :
isbanded(m, -1, 1) ? T(m) : throw(ArgumentError("matrix cannot be represented as Tridiagonal"))
convert(T::Type{<:LowerTriangular}, m::AbstractSparseMatrixCSC) = m isa T ? m :
istril(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as LowerTriangular"))
convert(T::Type{<:UpperTriangular}, m::AbstractSparseMatrixCSC) = m isa T ? m :
istriu(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as UpperTriangular"))

float(S::SparseMatrixCSC) = SparseMatrixCSC(size(S, 1), size(S, 2), copy(getcolptr(S)), copy(rowvals(S)), float.(nonzeros(S)))
complex(S::SparseMatrixCSC) = SparseMatrixCSC(size(S, 1), size(S, 2), copy(getcolptr(S)), copy(rowvals(S)), complex(copy(nonzeros(S))))

Expand Down
18 changes: 18 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ end
@test Array(SparseMatrixCSC{eltype(a), Int8}(a)) == Array(a)
end

@testset "conversion to special LinearAlgebra types" begin
# issue 40924
@test convert(Diagonal, sparse(Diagonal(1:2))) isa Diagonal
@test convert(Diagonal, sparse(Diagonal(1:2))) == Diagonal(1:2)
@test convert(Tridiagonal, sparse(Tridiagonal(1:3, 4:7, 8:10))) isa Tridiagonal
@test convert(Tridiagonal, sparse(Tridiagonal(1:3, 4:7, 8:10))) == Tridiagonal(1:3, 4:7, 8:10)
@test convert(SymTridiagonal, sparse(SymTridiagonal(1:4, 5:7))) isa SymTridiagonal
@test convert(SymTridiagonal, sparse(SymTridiagonal(1:4, 5:7))) == SymTridiagonal(1:4, 5:7)

lt = LowerTriangular([1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0])
@test convert(LowerTriangular, sparse(lt)) isa LowerTriangular
@test convert(LowerTriangular, sparse(lt)) == lt

ut = UpperTriangular([1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0])
@test convert(UpperTriangular, sparse(ut)) isa UpperTriangular
@test convert(UpperTriangular, sparse(ut)) == ut
end

@testset "sparse matrix construction" begin
@test (A = fill(1.0+im,5,5); isequal(Array(sparse(A)), A))
@test_throws ArgumentError sparse([1,2,3], [1,2], [1,2,3], 3, 3)
Expand Down

0 comments on commit 5cb5a87

Please sign in to comment.