Skip to content

Commit

Permalink
Handle generic numbers in generic pivoted cholesky correctly (#54735)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Jun 14, 2024
1 parent 222231f commit e40b57f
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 89 deletions.
132 changes: 67 additions & 65 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ function _chol!(A::AbstractMatrix, ::Type{UpperTriangular})
A[k,k] = Akk
AkkInv = inv(copy(Akk'))
for j = k + 1:n
for i = 1:k - 1
@simd for i = 1:k - 1
A[k,j] -= A[i,k]'A[i,j]
end
A[k,j] = AkkInv*A[k,j]
Expand All @@ -236,14 +236,15 @@ function _chol!(A::AbstractMatrix, ::Type{LowerTriangular})
return LowerTriangular(A), convert(BlasInt, k)
end
A[k,k] = Akk
AkkInv = inv(Akk)
AkkInv = inv(copy(Akk'))
for j = 1:k - 1
Akjc = A[k,j]'
@simd for i = k + 1:n
A[i,k] -= A[i,j]*A[k,j]'
A[i,k] -= A[i,j]*Akjc
end
end
for i = k + 1:n
A[i,k] *= AkkInv'
@simd for i = k + 1:n
A[i,k] *= AkkInv
end
end
end
Expand Down Expand Up @@ -301,103 +302,104 @@ _cholpivoted!(A::StridedMatrix{<:BlasFloat}, ::Type{LowerTriangular}, tol::Real,
LAPACK.pstrf!('L', A, tol)
## Non BLAS/LAPACK element types (generic)
function _cholpivoted!(A::AbstractMatrix, ::Type{UpperTriangular}, tol::Real, check::Bool)
rTA = real(eltype(A))
# checks
Base.require_one_based_indexing(A)
n = LinearAlgebra.checksquare(A)
# initialization
piv = collect(1:n)
dots = zeros(real(eltype(A)), n)
dots = zeros(rTA, n)
temp = similar(dots)
info = 0
rank = n

@inbounds begin
# first step
ajj, q = findmax(i -> real(A[i,i]), 1:n)
stop = tol < 0 ? eps(eltype(A))*n*abs(ajj) : tol
if ajj ≤ stop
return A, piv, convert(BlasInt, 0), convert(BlasInt, 1)
end
Akk, q = findmax(i -> real(A[i,i]), 1:n)
stop = tol < 0 ? eps(rTA)*n*abs(Akk) : tol
Akk ≤ stop && return A, piv, convert(BlasInt, 0), convert(BlasInt, 1)
# swap
_swap_rowcols!(A, UpperTriangular, n, 1, q)
piv[1], piv[q] = piv[q], piv[1]
A[1,1] = ajj = sqrt(ajj)
@views A[1, 2:n] .= A[1, 2:n] ./ ajj
A[1,1] = Akk = sqrt(Akk)
AkkInv = inv(copy(Akk'))
@simd for j in 2:n
A[1, j] *= AkkInv
end

for j in 2:n
for k in j:n
dots[k] += abs2(A[j-1, k])
temp[k] = real(A[k,k]) - dots[k]
end
ajj, q = findmax(i -> temp[i], j:n)
if ajj ≤ stop
rank = j - 1
info = 1
break
for k in 2:n
@simd for j in k:n
dots[j] += abs2(A[k-1, j])
temp[j] = real(A[j,j]) - dots[j]
end
q += j - 1
Akk, q = findmax(j -> temp[j], k:n)
Akk ≤ stop && return A, piv, convert(BlasInt, k - 1), convert(BlasInt, 1)
q += k - 1
# swap
_swap_rowcols!(A, UpperTriangular, n, j, q)
dots[j], dots[q] = dots[q], dots[j]
piv[j], piv[q] = piv[q], piv[j]
_swap_rowcols!(A, UpperTriangular, n, k, q)
dots[k], dots[q] = dots[q], dots[k]
piv[k], piv[q] = piv[q], piv[k]
# update
A[j,j] = ajj = sqrt(ajj)
@views if j < n
mul!(A[j, (j+1):n], A[1:(j-1), (j+1):n]', A[1:(j-1), j], -1, true)
A[j, j+1:n] ./= ajj
A[k,k] = Akk = sqrt(Akk)
AkkInv = inv(copy(Akk'))
for j in (k+1):n
@simd for i in 1:(k-1)
A[k,j] -= A[i,k]'A[i,j]
end
A[k,j] = AkkInv * A[k,j]
end
end
return A, piv, convert(BlasInt, rank), convert(BlasInt, info)
return A, piv, convert(BlasInt, n), convert(BlasInt, 0)
end
end
function _cholpivoted!(A::AbstractMatrix, ::Type{LowerTriangular}, tol::Real, check::Bool)
rTA = real(eltype(A))
# checks
Base.require_one_based_indexing(A)
n = LinearAlgebra.checksquare(A)
# initialization
piv = collect(1:n)
dots = zeros(real(eltype(A)), n)
dots = zeros(rTA, n)
temp = similar(dots)
info = 0
rank = n

@inbounds begin
# first step
ajj, q = findmax(i -> real(A[i,i]), 1:n)
stop = tol < 0 ? eps(eltype(A))*n*abs(ajj) : tol
if ajj ≤ stop
return A, piv, convert(BlasInt, 0), convert(BlasInt, 1)
end
Akk, q = findmax(i -> real(A[i,i]), 1:n)
stop = tol < 0 ? eps(rTA)*n*abs(Akk) : tol
Akk ≤ stop && return A, piv, convert(BlasInt, 0), convert(BlasInt, 1)
# swap
_swap_rowcols!(A, LowerTriangular, n, 1, q)
piv[1], piv[q] = piv[q], piv[1]
A[1,1] = ajj = sqrt(ajj)
@views A[2:n, 1] .= A[2:n, 1] ./ ajj
A[1,1] = Akk = sqrt(Akk)
AkkInv = inv(copy(Akk'))
@simd for i in 2:n
A[i,1] *= AkkInv
end

for j in 2:n
for k in j:n
dots[k] += abs2(A[k, j-1])
temp[k] = real(A[k,k]) - dots[k]
end
ajj, q = findmax(i -> temp[i], j:n)
q += j - 1
if ajj ≤ stop
rank = j - 1
info = 1
break
for k in 2:n
@simd for j in k:n
dots[j] += abs2(A[j, k-1])
temp[j] = real(A[j,j]) - dots[j]
end
Akk, q = findmax(i -> temp[i], k:n)
Akk ≤ stop && return A, piv, convert(BlasInt, k-1), convert(BlasInt, 1)
q += k - 1
# swap
_swap_rowcols!(A, LowerTriangular, n, j, q)
dots[j], dots[q] = dots[q], dots[j]
piv[j], piv[q] = piv[q], piv[j]
_swap_rowcols!(A, LowerTriangular, n, k, q)
dots[k], dots[q] = dots[q], dots[k]
piv[k], piv[q] = piv[q], piv[k]
# update
A[j,j] = ajj = sqrt(ajj)
@views if j < n
mul!(A[(j+1):n, j], A[(j+1):n, 1:(j-1)], A[j, 1:(j-1)], -1, true)
A[j+1:n, j] ./= ajj
A[k,k] = Akk = sqrt(Akk)
for j in 1:(k-1)
Akjc = A[k,j]'
@simd for i in (k+1):n
A[i,k] -= A[i,j]*Akjc
end
end
AkkInv = inv(copy(Akk'))
@simd for i in (k+1):n
A[i, k] *= AkkInv
end
end
return A, piv, convert(BlasInt, rank), convert(BlasInt, info)
return A, piv, convert(BlasInt, n), convert(BlasInt, 0)
end
end
function _cholpivoted!(x::Number, tol)
Expand All @@ -411,7 +413,7 @@ end
# cholesky!. Destructive methods for computing Cholesky factorization of real symmetric
# or Hermitian matrix
## No pivoting (default)
function cholesky!(A::RealHermSymComplexHerm, ::NoPivot = NoPivot(); check::Bool = true)
function cholesky!(A::SelfAdjoint, ::NoPivot = NoPivot(); check::Bool = true)
C, info = _chol!(A.data, A.uplo == 'U' ? UpperTriangular : LowerTriangular)
check && checkpositivedefinite(info)
return Cholesky(C.data, A.uplo, info)
Expand Down Expand Up @@ -453,7 +455,7 @@ end

## With pivoting
### Non BLAS/LAPACK element types (generic).
function cholesky!(A::RealHermSymComplexHerm, ::RowMaximum; tol = 0.0, check::Bool = true)
function cholesky!(A::SelfAdjoint, ::RowMaximum; tol = 0.0, check::Bool = true)
AA, piv, rank, info = _cholpivoted!(A.data, A.uplo == 'U' ? UpperTriangular : LowerTriangular, tol, check)
C = CholeskyPivoted(AA, A.uplo, piv, rank, tol, info)
check && chkfullrank(C)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ end

# factorizations
function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot = NoPivot(); check::Bool = true)
T = choltype(eltype(S))
T = choltype(S)
B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), sym_uplo(S.uplo))
cholesky!(Hermitian(B, sym_uplo(S.uplo)), NoPivot(); check = check)
end
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}

size(A::HermOrSym) = size(A.data)
axes(A::HermOrSym) = axes(A.data)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ function cholesky(S::SymTridiagonal, ::NoPivot = NoPivot(); check::Bool = true)
check && checkpositivedefinite(-1)
return Cholesky(S, 'U', convert(BlasInt, -1))
end
T = choltype(eltype(S))
T = choltype(S)
cholesky!(Hermitian(Bidiagonal{T}(diag(S, 0), diag(S, 1), :U)), NoPivot(); check = check)
end

Expand Down
78 changes: 56 additions & 22 deletions stdlib/LinearAlgebra/test/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@ using Test, LinearAlgebra, Random
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, QRPivoted,
PosDefException, RankDeficientException, chkfullrank

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

function unary_ops_tests(a, ca, tol; n=size(a, 1))
@test inv(ca)*a ≈ Matrix(I, n, n)
@test a*inv(ca) ≈ Matrix(I, n, n)
@test abs((det(ca) - det(a))/det(ca)) <= tol # Ad hoc, but statistically verified, revisit
@test logdet(ca) ≈ logdet(a)
@test logdet(ca) ≈ log(det(ca)) # logdet is less likely to overflow
@test logdet(ca) ≈ logdet(a) broken = eltype(a) <: Quaternion
@test logdet(ca) ≈ log(det(ca)) # logdet is less likely to overflow
logabsdet_ca = logabsdet(ca)
logabsdet_a = logabsdet(a)
@test logabsdet_ca[1] ≈ logabsdet_a[1]
Expand Down Expand Up @@ -53,14 +58,21 @@ end
breal = randn(n,2)/2
bimg = randn(n,2)/2

for eltya in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int)
a = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal)
a2 = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(a2real, a2img) : a2real)
for eltya in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Complex{BigFloat}, Quaternion{Float64}, Int)
a = if eltya == Int
rand(1:7, n, n)
elseif eltya <: Real
convert(Matrix{eltya}, areal)
elseif eltya <: Complex
convert(Matrix{eltya}, complex.(areal, aimg))
else
convert(Matrix{eltya}, Quaternion.(areal, aimg, a2real, a2img))
end

ε = εa = eps(abs(float(one(eltya))))

# Test of symmetric pos. def. strided matrix
apd = a'*a
apd = Matrix(Hermitian(a'*a))
capd = @inferred cholesky(apd)
r = capd.U
κ = cond(apd, 1) #condition number
Expand All @@ -86,16 +98,16 @@ end
#but only with Random.seed!(1234321) set before the loops.
E = abs.(apd - r'*r)
for i=1:n, j=1:n
@test E[i,j] <= (n+1)ε/(1-(n+1)ε)*real(sqrt(apd[i,i]*apd[j,j]))
@test E[i,j] <= (n+1)ε/(1-(n+1)ε)*sqrt(real(apd[i,i]*apd[j,j]))
end
E = abs.(apd - Matrix(capd))
for i=1:n, j=1:n
@test E[i,j] <= (n+1)ε/(1-(n+1)ε)*real(sqrt(apd[i,i]*apd[j,j]))
@test E[i,j] <= (n+1)ε/(1-(n+1)ε)*sqrt(real(apd[i,i]*apd[j,j]))
end
@test LinearAlgebra.issuccess(capd)
@inferred(logdet(capd))

apos = apd[1,1]
apos = real(apd[1,1])
@test all(x -> x ≈ apos, cholesky(apos).factors)

# Test cholesky with Symmetric/Hermitian upper/lower
Expand Down Expand Up @@ -131,7 +143,7 @@ end
@test Matrix(@inferred cholesky(Symmetric(S, uplo))) ≈ S
end
end
@test Matrix(cholesky(S).U) ≈ [2 -1; 0 sqrt(eltya(3))] / sqrt(eltya(2))
@test Matrix(cholesky(S).U) ≈ [2 -1; 0 float(eltya)(sqrt(real(eltya)(3)))] / float(eltya)(sqrt(real(eltya)(2)))
@test Matrix(cholesky(S)) ≈ S

# test extraction of factor and re-creating original matrix
Expand All @@ -142,15 +154,25 @@ end
end

#pivoted upper Cholesky
cpapd = cholesky(apdh, RowMaximum())
unary_ops_tests(apdh, cpapd, ε*κ*n)
@test rank(cpapd) == n
@test all(diff(diag(real(cpapd.factors))).<=0.) # diagonal should be non-increasing
for tol in (0.0, -1.0), APD in (apdh, apdhL)
cpapd = cholesky(APD, RowMaximum(), tol=tol)
unary_ops_tests(APD, cpapd, ε*κ*n)
@test rank(cpapd) == n
@test all(diff(real(diag(cpapd.factors))).<=0.) # diagonal should be non-increasing

@test cpapd.P*cpapd.L*cpapd.U*cpapd.P' ≈ apd
@test cpapd.P*cpapd.L*cpapd.U*cpapd.P' ≈ apd
end

for eltyb in (Float32, Float64, ComplexF32, ComplexF64, Int)
b = eltyb == Int ? rand(1:5, n, 2) : convert(Matrix{eltyb}, eltyb <: Complex ? complex.(breal, bimg) : breal)
b = if eltya <: Quaternion
convert(Matrix{eltya}, Quaternion.(breal, bimg, bimg, bimg))
elseif eltyb == Int
rand(1:5, n, 2)
elseif eltyb <: Complex
convert(Matrix{eltyb}, complex.(breal, bimg))
elseif eltyb <: Real
convert(Matrix{eltyb}, breal)
end
εb = eps(abs(float(one(eltyb))))
ε = max(εa,εb)

Expand Down Expand Up @@ -181,7 +203,13 @@ end
for eltyb in (Float64, ComplexF64)
Breal = convert(Matrix{BigFloat}, randn(n,n)/2)
Bimg = convert(Matrix{BigFloat}, randn(n,n)/2)
B = (eltya <: Complex || eltyb <: Complex) ? complex.(Breal, Bimg) : Breal
B = if eltya <: Quaternion
Quaternion.(Float64.(Breal), Float64.(Bimg), Float64.(Bimg), Float64.(Bimg))
elseif eltya <: Complex || eltyb <: Complex
complex.(Breal, Bimg)
else
Breal
end
εb = eps(abs(float(one(eltyb))))
ε = max(εa,εb)

Expand All @@ -204,30 +232,36 @@ end
@testset "solve with generic Cholesky" begin
Breal = convert(Matrix{BigFloat}, randn(n,n)/2)
Bimg = convert(Matrix{BigFloat}, randn(n,n)/2)
B = eltya <: Complex ? complex.(Breal, Bimg) : Breal
B = if eltya <: Quaternion
eltya.(Breal, Bimg, Bimg, Bimg)
elseif eltya <: Complex
complex.(Breal, Bimg)
else
Breal
end
εb = eps(abs(float(one(eltype(B)))))
ε = max(εa,εb)

for B in (B, view(B, 1:n, 1:n)) # Array and SubArray

# Test error bound on linear solver: LAWNS 14, Theorem 2.1
# This is a surprisingly loose bound
cpapd = cholesky(eltya <: Complex ? apdh : apds)
cpapd = cholesky(eltya <: Real ? apds : apdh)
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
cpapd = cholesky(eltya <: Complex ? apdhL : apdsL)
cpapd = cholesky(eltya <: Real ? apdsL : apdhL)
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
cpapd = cholesky(eltya <: Complex ? apdh : apds, RowMaximum())
cpapd = cholesky(eltya <: Real ? apds : apdh, RowMaximum())
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
cpapd = cholesky(eltya <: Complex ? apdhL : apdsL, RowMaximum())
cpapd = cholesky(eltya <: Real ? apdsL : apdhL, RowMaximum())
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
Expand Down
Loading

0 comments on commit e40b57f

Please sign in to comment.