Skip to content

Commit

Permalink
Preserve types when adding/subtracting Herm/Sym/UniformScaling (#29500)
Browse files Browse the repository at this point in the history
* Preserve types when adding/subtracting Herm/Sym/UniformScaling

* Make `real(::SymOrHerm{<:Real})` consistent with `real(::Array)`.

* Fix embarrassing ambiguity

* More tests, remove imag(::Hermitian), simplify code

* Remove `.λ`
  • Loading branch information
dalum authored and andreasnoack committed Dec 13, 2018
1 parent 640b155 commit 49023b5
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 3 deletions.
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ transpose(A::Hermitian{<:Real}) = A
adjoint(A::Symmetric) = Adjoint(A)
transpose(A::Hermitian) = Transpose(A)

real(A::Symmetric{<:Real}) = A
real(A::Hermitian{<:Real}) = A
real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo))
real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo))
imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo))

Base.copy(A::Adjoint{<:Any,<:Hermitian}) = copy(A.parent)
Base.copy(A::Transpose{<:Any,<:Symmetric}) = copy(A.parent)
Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
Expand Down Expand Up @@ -394,6 +400,14 @@ end
(-)(A::Symmetric{Tv,S}) where {Tv,S} = Symmetric{Tv,S}(-A.data, A.uplo)
(-)(A::Hermitian{Tv,S}) where {Tv,S} = Hermitian{Tv,S}(-A.data, A.uplo)

## Addition/subtraction
for f in (:+, :-)
@eval $f(A::Symmetric, B::Symmetric) = Symmetric($f(A.data, B), sym_uplo(A.uplo))
@eval $f(A::Hermitian, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
@eval $f(A::Hermitian, B::Symmetric{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
@eval $f(A::Symmetric{<:Real}, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
end

## Matvec
mul!(y::StridedVector{T}, A::Symmetric{T,<:StridedMatrix}, x::StridedVector{T}) where {T<:BlasFloat} =
BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)
Expand Down
25 changes: 25 additions & 0 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,31 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
end
end

# Adding a complex UniformScaling to the diagonal of a Hermitian
# matrix breaks the hermiticity, if the UniformScaling is non-real.
# However, to preserve type stability, we do not special-case a
# UniformScaling{<:Complex} that happens to be real.
function (+)(A::Hermitian{T,S}, J::UniformScaling{<:Complex}) where {T,S}
A_ = copytri!(copy(parent(A)), A.uplo)

This comment has been minimized.

Copy link
@stevengj

stevengj May 10, 2019

Member

This is missing the conj=true argument, so it gives the wrong answer for complex-Hermitian matrices (which weren't covered by the tests).

Incidentally, it also makes an extra copy of the matrix if the result type is different from eltype(A).

I'll file a PR shortly with a bugfix and a test.

B = convert(AbstractMatrix{Base._return_type(+, Tuple{eltype(A), typeof(J)})}, A_)
@inbounds for i in diagind(B)
B[i] += J
end
return B
end

function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,S}) where {T,S}
A_ = copytri!(copy(parent(A)), A.uplo)
B = convert(AbstractMatrix{Base._return_type(+, Tuple{eltype(A), typeof(J)})}, A_)
@inbounds for i in eachindex(B)
B[i] = -B[i]
end
@inbounds for i in diagind(B)
B[i] += J
end
return B
end

function (+)(A::AbstractMatrix, J::UniformScaling)
checksquare(A)
B = copy_oftype(A, Base._return_type(+, Tuple{eltype(A), typeof(J)}))
Expand Down
29 changes: 26 additions & 3 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ end
@test (-Hermitian(aherm))::typeof(Hermitian(aherm)) == -aherm
end

@testset "Addition and subtraction for Symmetric/Hermitian matrices" begin
for f in (+, -)
@test (f(Symmetric(asym), Symmetric(aposs)))::typeof(Symmetric(asym)) == f(asym, aposs)
@test (f(Hermitian(aherm), Hermitian(apos)))::typeof(Hermitian(aherm)) == f(aherm, apos)
@test (f(Symmetric(real(asym)), Hermitian(aherm)))::typeof(Hermitian(aherm)) == f(real(asym), aherm)
@test (f(Hermitian(aherm), Symmetric(real(asym))))::typeof(Hermitian(aherm)) == f(aherm, real(asym))
@test (f(Symmetric(asym), Hermitian(aherm))) == f(asym, aherm)
@test (f(Hermitian(aherm), Symmetric(asym))) == f(aherm, asym)
end
end

@testset "getindex and unsafe_getindex" begin
@test aherm[1,1] == Hermitian(aherm)[1,1]
@test asym[1,1] == Symmetric(asym)[1,1]
Expand Down Expand Up @@ -153,6 +164,21 @@ end
@test transpose(H) == Hermitian(copy(transpose(aherm)))
end
end

@testset "real, imag" begin
S = Symmetric(asym)
H = Hermitian(aherm)
@test issymmetric(real(S))
@test ishermitian(real(H))
if eltya <: Real
@test real(S) === S == asym
@test real(H) === H == aherm
elseif eltya <: Complex
@test issymmetric(imag(S))
@test !ishermitian(imag(H))
end
end

end

@testset "linalg unary ops" begin
Expand Down Expand Up @@ -415,9 +441,6 @@ end

@test T([true false; false true]) .+ true == T([2 1; 1 2])
end

@test_throws ArgumentError Hermitian(X) + 2im*I
@test_throws ArgumentError Hermitian(X) - 2im*I
end

@testset "Issue #21981" begin
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ let
@test @inferred(J - T) == J - Array(T)
@test @inferred(T\I) == inv(T)

if isa(A, Array)
T = Hermitian(randn(3,3))
else
T = Hermitian(view(randn(3,3), 1:3, 1:3))
end
@test @inferred(T + J) == Array(T) + J
@test @inferred(J + T) == J + Array(T)
@test @inferred(T - J) == Array(T) - J
@test @inferred(J - T) == J - Array(T)

@test @inferred(I\A) == A
@test @inferred(A\I) == inv(A)
@test @inferred\I) === UniformScaling(1/λ)
Expand Down

0 comments on commit 49023b5

Please sign in to comment.