From 8e76b12c01e9a76e70d6b7b1d6c6db7b40f1cee3 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 22 Apr 2020 02:10:03 -0600 Subject: [PATCH] Sparse addition of Symmetric and Hermitian matrices (#35325) Co-authored-by: MasonProtter Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/symmetric.jl | 22 ++++++++++++++++---- stdlib/SparseArrays/src/linalg.jl | 20 +++++++++++++++++- stdlib/SparseArrays/test/sparse.jl | 30 +++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index b62f56fc3b41e..20a25084372dd 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -461,11 +461,25 @@ end (-)(A::Hermitian) = Hermitian(-A.data, sym_uplo(A.uplo)) ## Addition/subtraction +for f ∈ (:+, :-), (Wrapper, conjugation) ∈ ((:Hermitian, :adjoint), (:Symmetric, :transpose)) + @eval begin + function $f(A::$Wrapper, B::$Wrapper) + if A.uplo == B.uplo + return $Wrapper($f(parent(A), parent(B)), sym_uplo(A.uplo)) + elseif A.uplo == 'U' + return $Wrapper($f(parent(A), $conjugation(parent(B))), :U) + else + return $Wrapper($f($conjugation(parent(A)), parent(B)), :U) + end + end + end +end + for f in (:+, :-) - @eval $f(A::Symmetric, B::Symmetric) = Symmetric($f(A.data, B), sym_uplo(A.uplo)) - @eval $f(A::Hermitian, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) - @eval $f(A::Hermitian, B::Symmetric{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) - @eval $f(A::Symmetric{<:Real}, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) + @eval begin + $f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo))) + $f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B) + end end ## Matvec diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index 070ab01fb0ff7..d24cfaf0849a3 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -1,12 +1,30 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -import LinearAlgebra: checksquare +import LinearAlgebra: checksquare, sym_uplo using Random: rand! # In matrix-vector multiplication, the correct orientation of the vector is assumed. const StridedOrTriangularMatrix{T} = Union{StridedMatrix{T}, LowerTriangular{T}, UnitLowerTriangular{T}, UpperTriangular{T}, UnitUpperTriangular{T}} const AdjOrTransStridedOrTriangularMatrix{T} = Union{StridedOrTriangularMatrix{T},Adjoint{<:Any,<:StridedOrTriangularMatrix{T}},Transpose{<:Any,<:StridedOrTriangularMatrix{T}}} +for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric) + @eval begin + $op(A::AbstractSparseMatrix, B::$Wrapper{<:Any,<:AbstractSparseMatrix}) = $op(A, sparse(B)) + $op(A::$Wrapper{<:Any,<:AbstractSparseMatrix}, B::AbstractSparseMatrix) = $op(sparse(A), B) + + $op(A::AbstractSparseMatrix, B::$Wrapper) = $op(A, collect(B)) + $op(A::$Wrapper, B::AbstractSparseMatrix) = $op(collect(A), B) + end +end +for op ∈ (:+, :-) + @eval begin + $op(A::Symmetric{<:Any, <:AbstractSparseMatrix}, B::Hermitian{<:Any, <:AbstractSparseMatrix}) = $op(sparse(A), sparse(B)) + $op(A::Hermitian{<:Any, <:AbstractSparseMatrix}, B::Symmetric{<:Any, <:AbstractSparseMatrix}) = $op(sparse(A), sparse(B)) + $op(A::Symmetric{<:Real, <:AbstractSparseMatrix}, B::Hermitian{<:Any, <:AbstractSparseMatrix}) = $op(Hermitian(parent(A), sym_uplo(A.uplo)), B) + $op(A::Hermitian{<:Any, <:AbstractSparseMatrix}, B::Symmetric{<:Real, <:AbstractSparseMatrix}) = $op(A, Hermitian(parent(B), sym_uplo(B.uplo))) + end +end + function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number) size(A, 2) == size(B, 1) || throw(DimensionMismatch()) size(A, 1) == size(C, 1) || throw(DimensionMismatch()) diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 9449895f16b66..3961d0926bca0 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -2906,4 +2906,34 @@ end @test B ≈ mapreduce(identity, +, Matrix(A), dims=2) end +@testset "Symmetric and Hermitian #35325" begin + A = sprandn(ComplexF64, 10, 10, 0.1) + B = sprandn(ComplexF64, 10, 10, 0.1) + + @test Symmetric(real(A)) + Hermitian(B) isa Hermitian{ComplexF64, <:SparseMatrixCSC} + @test Hermitian(A) + Symmetric(real(B)) isa Hermitian{ComplexF64, <:SparseMatrixCSC} + @test Hermitian(A) + Symmetric(B) isa SparseMatrixCSC + @testset "$Wrapper $op" for op ∈ (+, -), Wrapper ∈ (Hermitian, Symmetric) + AWU = Wrapper(A, :U) + AWL = Wrapper(A, :L) + BWU = Wrapper(B, :U) + BWL = Wrapper(B, :L) + + @test op(AWU, B) isa SparseMatrixCSC + @test op(A, BWL) isa SparseMatrixCSC + + @test op(AWU, B) ≈ op(collect(AWU), B) + @test op(AWL, B) ≈ op(collect(AWL), B) + @test op(A, BWU) ≈ op(A, collect(BWU)) + @test op(A, BWL) ≈ op(A, collect(BWL)) + + @test op(AWU, BWL) isa Wrapper{ComplexF64, <:SparseMatrixCSC} + + @test op(AWU, BWU) ≈ op(collect(AWU), collect(BWU)) + @test op(AWU, BWL) ≈ op(collect(AWU), collect(BWL)) + @test op(AWL, BWU) ≈ op(collect(AWL), collect(BWU)) + @test op(AWL, BWL) ≈ op(collect(AWL), collect(BWL)) + end +end + end # module