Skip to content

Commit

Permalink
fix matrix power of Symmetric/Hermitian (JuliaLang#22923)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed Aug 1, 2017
1 parent 73ca642 commit 738d042
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 62 deletions.
60 changes: 32 additions & 28 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,34 +357,12 @@ kron(a::AbstractMatrix, b::AbstractVector) = kron(a, reshape(b, length(b), 1))
kron(a::AbstractVector, b::AbstractMatrix) = kron(reshape(a, length(a), 1), b)

# Matrix power
(^)(A::AbstractMatrix{T}, p::Integer) where {T} = p < 0 ? Base.power_by_squaring(inv(A), -p) : Base.power_by_squaring(A, p)
function (^)(A::AbstractMatrix{T}, p::Real) where T
# For integer powers, use repeated squaring
if isinteger(p)
TT = Base.promote_op(^, eltype(A), typeof(p))
return (TT == eltype(A) ? A : copy!(similar(A, TT), A))^Integer(p)
end

# If possible, use diagonalization
if T <: Real && issymmetric(A)
return (Symmetric(A)^p)
end
if ishermitian(A)
return (Hermitian(A)^p)
end

n = checksquare(A)

# Quicker return if A is diagonal
if isdiag(A)
retmat = copy(A)
for i in 1:n
retmat[i, i] = retmat[i, i] ^ p
end
return retmat
end

# Otherwise, use Schur decomposition
(^)(A::AbstractMatrix, p::Integer) = p < 0 ? Base.power_by_squaring(inv(A), -p) : Base.power_by_squaring(A, p)
function integerpow(A::AbstractMatrix{T}, p) where T
TT = Base.promote_op(^, T, typeof(p))
return (TT == T ? A : copy!(similar(A, TT), A))^Integer(p)
end
function schurpow(A::AbstractMatrix, p)
if istriu(A)
# Integer part
retmat = A ^ floor(p)
Expand Down Expand Up @@ -416,6 +394,32 @@ function (^)(A::AbstractMatrix{T}, p::Real) where T
return retmat
end
end
function (^)(A::AbstractMatrix{T}, p::Real) where T
n = checksquare(A)

# For integer powers, use power_by_squaring
isinteger(p) && return integerpow(A, p)

# If possible, use diagonalization
if issymmetric(A)
return (Symmetric(A)^p)
end
if ishermitian(A)
return (Hermitian(A)^p)
end

# Quicker return if A is diagonal
if isdiag(A)
retmat = copy(A)
for i in 1:n
retmat[i, i] = retmat[i, i] ^ p
end
return retmat
end

# Otherwise, use Schur decomposition
return schurpow(A, p)
end
(^)(A::AbstractMatrix, p::Number) = expm(p*logm(A))

# Matrix exponential
Expand Down
26 changes: 15 additions & 11 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,50 +538,54 @@ function svdvals!(A::RealHermSymComplexHerm)
end

# Matrix functions
function ^(A::Symmetric{T}, p::Integer) where T<:Real
^(A::Symmetric{<:Real}, p::Integer) = sympow(A, p)
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
function sympow(A::Symmetric, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
function ^(A::Symmetric{T}, p::Real) where T<:Real
function ^(A::Symmetric{<:Real}, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigfact(A)
if all-> λ 0, F.values)
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
else
retmat = (F.vectors * Diagonal((complex(F.values)).^p)) * F.vectors'
return Symmetric((F.vectors * Diagonal((complex(F.values)).^p)) * F.vectors')
end
return Symmetric(retmat)
end
function ^(A::Symmetric{<:Complex}, p::Real)
isinteger(p) && return integerpow(A, p)
return Symmetric(schurpow(A, p))
end
function ^(A::Hermitian, p::Integer)
n = checksquare(A)
if p < 0
retmat = Base.power_by_squaring(inv(A), -p)
else
retmat = Base.power_by_squaring(A, p)
end
for i = 1:n
for i = 1:size(A,1)
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
function ^(A::Hermitian{T}, p::Real) where T
n = checksquare(A)
isinteger(p) && return integerpow(A, p)
F = eigfact(A)
if all-> λ 0, F.values)
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
if T <: Real
return Hermitian(retmat)
else
for i = 1:n
for i = 1:size(A,1)
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
else
retmat = (F.vectors * Diagonal((complex(F.values).^p))) * F.vectors'
return retmat
return (F.vectors * Diagonal((complex(F.values).^p))) * F.vectors'
end
end

Expand Down
44 changes: 21 additions & 23 deletions test/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,29 +245,27 @@ end
end

@testset "pow" begin
@test (asym)^2 Array(Symmetric(asym)^2)
@test (asym)^-2 Array(Symmetric(asym)^-2)
@test (aherm)^2 Array(Hermitian(aherm)^2)
@test (aherm)^-2 Array(Hermitian(aherm)^-2)
if eltya == Int
@test (asym)^2.0 real(Array(Symmetric(asym)^2.0))
@test (asym)^-2.0 real(Array(Symmetric(asym)^-2.0))
@test (aherm)^2.0 real(Array(Hermitian(aherm)^2.0))
@test (aherm)^-2.0 real(Array(Hermitian(aherm)^-2.0))
@test (apos)^2.0 real(Array(Hermitian(apos)^2.0))
elseif eltya <: Real
@test (asym)^2.0 real(Array(Symmetric(asym)^2.0)) rtol=100*n^2*eps(real(eltya))
@test (asym)^-2.0 real(Array(Symmetric(asym)^-2.0)) rtol=100*n^2*eps(real(eltya))
@test (aherm)^2.0 real(Array(Hermitian(aherm)^2.0)) rtol=100*n^2*eps(real(eltya))
@test (aherm)^-2.0 real(Array(Hermitian(aherm)^-2.0)) rtol=100*n^2*eps(real(eltya))
@test (apos)^2.0 real(Array(Hermitian(apos)^2.0)) rtol=100*n^2*eps(real(eltya))
else
@test (asym)^2.0 Array(Symmetric(asym)^2.0) rtol=100*n^2*eps(real(eltya))
@test (asym)^-2.0 Array(Symmetric(asym)^-2.0) rtol=100*n^2*eps(real(eltya))
@test (aherm)^2.0 Array(Hermitian(aherm)^2.0) rtol=100*n^2*eps(real(eltya))
@test (aherm)^-2.0 Array(Hermitian(aherm)^-2.0) rtol=100*n^2*eps(real(eltya))
@test (apos)^2.0 Array(Hermitian(apos)^2.0) rtol=100*n^2*eps(real(eltya))
end
# Integer power
@test (asym)^2 (Symmetric(asym)^2)::Symmetric
@test (asym)^-2 (Symmetric(asym)^-2)::Symmetric
@test (aposs)^2 (Symmetric(aposs)^2)::Symmetric
@test (aherm)^2 (Hermitian(aherm)^2)::Hermitian
@test (aherm)^-2 (Hermitian(aherm)^-2)::Hermitian
@test (apos)^2 (Hermitian(apos)^2)::Hermitian
# integer floating point power
@test (asym)^2.0 (Symmetric(asym)^2.0)::Symmetric
@test (asym)^-2.0 (Symmetric(asym)^-2.0)::Symmetric
@test (aposs)^2.0 (Symmetric(aposs)^2.0)::Symmetric
@test (aherm)^2.0 (Hermitian(aherm)^2.0)::Hermitian
@test (aherm)^-2.0 (Hermitian(aherm)^-2.0)::Hermitian
@test (apos)^2.0 (Hermitian(apos)^2.0)::Hermitian
# non-integer floating point power
@test (asym)^2.5 (Symmetric(asym)^2.5)::Symmetric
@test (asym)^-2.5 (Symmetric(asym)^-2.5)::Symmetric
@test (aposs)^2.5 (Symmetric(aposs)^2.5)::Symmetric
@test (aherm)^2.5 (Hermitian(aherm)^2.5)#::Hermitian
@test (aherm)^-2.5 (Hermitian(aherm)^-2.5)#::Hermitian
@test (apos)^2.5 (Hermitian(apos)^2.5)::Hermitian
end
end
end
Expand Down

0 comments on commit 738d042

Please sign in to comment.