From 27b31d188dd4f8bc690be7400a3442ea6246ac63 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 9 Feb 2024 16:26:32 +0530 Subject: [PATCH] Reroute algebraic functions for `Symmetric`/`Hermitian` through triangular (#52942) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This ensures that only the triangular indices are accessed for strided parent matrices. Fix #52895 ```julia julia> M = Matrix{Complex{BigFloat}}(undef, 2, 2); julia> M[1,1] = M[2,2] = M[1,2] = 2; julia> H = Hermitian(M) 2×2 Hermitian{Complex{BigFloat}, Matrix{Complex{BigFloat}}}: 2.0+0.0im 2.0+0.0im 2.0-0.0im 2.0+0.0im julia> H + H # works after this 2×2 Hermitian{Complex{BigFloat}, Matrix{Complex{BigFloat}}}: 4.0+0.0im 4.0+0.0im 4.0-0.0im 4.0+0.0im ``` This also provides a speed-up in several common cases (allocations mentioned only when they differ): ```julia julia> H = Hermitian(rand(ComplexF64,1000,1000)); julia> H2 = Hermitian(rand(ComplexF64,1000,1000),:L); ``` | Operation | master | PR | | ---- | ---- | ---- | |`-H` |2.247 ms | 1.384 ms | | `real(H)` |1.544 ms |1.175 ms | |`H + H` |2.288 ms |1.978 ms | |`H + H2` |5.139 ms |3.287 ms | | `isdiag(H)` |23.042 ns (1 allocation: 16 bytes) |16.778 ns (0 allocations: 0 bytes) | I'm not entirely certain why `isdiag(H)` allocates on master, as union splitting should handle this automatically, but manually splitting the union appears to help. --- stdlib/LinearAlgebra/src/symmetric.jl | 72 ++++++++++++++++---------- stdlib/LinearAlgebra/test/symmetric.jl | 43 +++++++++++++++ 2 files changed, 88 insertions(+), 27 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index d0066ddecbfe0..21047dad8fcd9 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -269,10 +269,34 @@ end end end +_conjugation(::Symmetric) = transpose +_conjugation(::Hermitian) = adjoint + diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo)) -isdiag(A::HermOrSym) = isdiag(A.uplo == 'U' ? UpperTriangular(A.data) : LowerTriangular(A.data)) +function applytri(f, A::HermOrSym) + if A.uplo == 'U' + f(UpperTriangular(A.data)) + else + f(LowerTriangular(A.data)) + end +end + +function applytri(f, A::HermOrSym, B::HermOrSym) + if A.uplo == B.uplo == 'U' + f(UpperTriangular(A.data), UpperTriangular(B.data)) + elseif A.uplo == B.uplo == 'L' + f(LowerTriangular(A.data), LowerTriangular(B.data)) + elseif A.uplo == 'U' + f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data))) + else # A.uplo == 'L' + f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data)) + end +end +parentof_applytri(f, args...) = applytri(parent ∘ f, args...) + +isdiag(A::HermOrSym) = applytri(isdiag, A) # For A<:Union{Symmetric,Hermitian}, similar(A[, neweltype]) should yield a matrix with the same # symmetry type, uplo flag, and underlying storage type as A. The following methods cover these cases. @@ -314,8 +338,8 @@ Hermitian{T,S}(A::Hermitian) where {T,S<:AbstractMatrix{T}} = Hermitian{T,S}(con AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo)) AbstractMatrix{T}(A::Hermitian{T}) where {T} = copy(A) -copy(A::Symmetric{T,S}) where {T,S} = (B = copy(A.data); Symmetric{T,typeof(B)}(B,A.uplo)) -copy(A::Hermitian{T,S}) where {T,S} = (B = copy(A.data); Hermitian{T,typeof(B)}(B,A.uplo)) +copy(A::Symmetric) = (Symmetric(parentof_applytri(copy, A), sym_uplo(A.uplo))) +copy(A::Hermitian) = (Hermitian(parentof_applytri(copy, A), sym_uplo(A.uplo))) function copyto!(dest::Symmetric, src::Symmetric) if src.uplo == dest.uplo @@ -389,9 +413,9 @@ transpose(A::Hermitian) = Transpose(A) real(A::Symmetric{<:Real}) = A real(A::Hermitian{<:Real}) = A -real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo)) -real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo)) -imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo)) +real(A::Symmetric) = Symmetric(parentof_applytri(real, A), sym_uplo(A.uplo)) +real(A::Hermitian) = Hermitian(parentof_applytri(real, A), sym_uplo(A.uplo)) +imag(A::Symmetric) = Symmetric(parentof_applytri(imag, A), sym_uplo(A.uplo)) Base.copy(A::Adjoint{<:Any,<:Symmetric}) = Symmetric(copy(adjoint(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U)) @@ -401,8 +425,9 @@ Base.copy(A::Transpose{<:Any,<:Hermitian}) = tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) tr(A::Hermitian) = real(tr(A.data)) -Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo) -Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo) +Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo)) +Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo)) +Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo) # tril/triu function tril(A::Hermitian, k::Integer=0) @@ -496,21 +521,14 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni end end -(-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo)) -(-)(A::Hermitian) = Hermitian(-A.data, sym_uplo(A.uplo)) +(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo)) +(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo)) ## Addition/subtraction -for f ∈ (:+, :-), (Wrapper, conjugation) ∈ ((:Hermitian, :adjoint), (:Symmetric, :transpose)) - @eval begin - function $f(A::$Wrapper, B::$Wrapper) - if A.uplo == B.uplo - return $Wrapper($f(parent(A), parent(B)), sym_uplo(A.uplo)) - elseif A.uplo == 'U' - return $Wrapper($f(parent(A), $conjugation(parent(B))), :U) - else - return $Wrapper($f($conjugation(parent(A)), parent(B)), :U) - end - end +for f ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric) + @eval function $f(A::$Wrapper, B::$Wrapper) + uplo = A.uplo == B.uplo ? sym_uplo(A.uplo) : (:U) + $Wrapper(parentof_applytri($f, A, B), uplo) end end @@ -555,12 +573,12 @@ function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector) end # Scaling with Number -*(A::Symmetric, x::Number) = Symmetric(A.data*x, sym_uplo(A.uplo)) -*(x::Number, A::Symmetric) = Symmetric(x*A.data, sym_uplo(A.uplo)) -*(A::Hermitian, x::Real) = Hermitian(A.data*x, sym_uplo(A.uplo)) -*(x::Real, A::Hermitian) = Hermitian(x*A.data, sym_uplo(A.uplo)) -/(A::Symmetric, x::Number) = Symmetric(A.data/x, sym_uplo(A.uplo)) -/(A::Hermitian, x::Real) = Hermitian(A.data/x, sym_uplo(A.uplo)) +*(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo)) +*(x::Number, A::Symmetric) = Symmetric(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo)) +*(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo)) +*(x::Real, A::Hermitian) = Hermitian(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo)) +/(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo)) +/(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo)) factorize(A::HermOrSym) = _factorize(A) function _factorize(A::HermOrSym{T}; check::Bool=true) where T diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index b3e0b7b560e7a..d3b24ccf78b0b 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -470,6 +470,42 @@ end end end +@testset "non-isbits algebra" begin + for ST in (Symmetric, Hermitian), uplo in (:L, :U) + M = Matrix{Complex{BigFloat}}(undef,2,2) + M[1,1] = rand() + M[2,2] = rand() + M[1+(uplo==:L), 1+(uplo==:U)] = rand(ComplexF64) + S = ST(M, uplo) + MS = Matrix(S) + @test real(S) == real(MS) + @test imag(S) == imag(MS) + @test conj(S) == conj(MS) + @test conj!(copy(S)) == conj(MS) + @test -S == -MS + @test S + S == MS + MS + @test S - S == MS - MS + @test S*2 == 2*S == 2*MS + @test S/2 == MS/2 + end + @testset "mixed uplo" begin + Mu = Matrix{Complex{BigFloat}}(undef,2,2) + Mu[1,1] = Mu[2,2] = 3 + Mu[1,2] = 2 + 3im + Ml = Matrix{Complex{BigFloat}}(undef,2,2) + Ml[1,1] = Ml[2,2] = 4 + Ml[2,1] = 4 + 5im + for ST in (Symmetric, Hermitian) + Su = ST(Mu, :U) + MSu = Matrix(Su) + Sl = ST(Ml, :L) + MSl = Matrix(Sl) + @test Su + Sl == Sl + Su == MSu + MSl + @test Su - Sl == -(Sl - Su) == MSu - MSl + end + end +end + # bug identified in PR #52318: dot products of quaternionic Hermitian matrices, # or any number type where conj(a)*conj(b) ≠ conj(a*b): @testset "dot Hermitian quaternion #52318" begin @@ -932,4 +968,11 @@ end end end +@testset "conj for immutable" begin + S = Symmetric(reshape((1:16)*im, 4, 4)) + @test conj(S) == conj(Array(S)) + H = Hermitian(reshape((1:16)*im, 4, 4)) + @test conj(H) == conj(Array(H)) +end + end # module TestSymmetric