Skip to content

Commit

Permalink
Implement generic pivoted Cholesky (#54619)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Jun 5, 2024
1 parent 97bf148 commit b8a058a
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 73 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Standard library changes
* Added keyword argument `alg` to `eigen`, `eigen!`, `eigvals` and `eigvals!` for self-adjoint
matrix types (i.e., the type union `RealHermSymComplexHerm`) that allows one to switch
between different eigendecomposition algorithms ([#49355]).
* Added a generic version of the (unblocked) pivoted Cholesky decomposition
(callable via `cholesky[!](A, RowMaximum())`) ([#54619]).

#### Logging

Expand Down
164 changes: 148 additions & 16 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,147 @@ function _chol!(x::Number, _)
return (rval, convert(BlasInt, rx != abs(x)))
end

## for StridedMatrices, check that matrix is symmetric/Hermitian
# _cholpivoted!. Internal methods for calling pivoted Cholesky
Base.@propagate_inbounds function _swap_rowcols!(A, ::Type{UpperTriangular}, n, j, q)
j == q && return
@assert j < q
# swap rows and cols without touching the possibly undef-ed triangle
A[q, q] = A[j, j]
for k in 1:j-1 # initial vertical segments
A[k,j], A[k,q] = A[k,q], A[k,j]
end
for k in j+1:q-1 # intermediate segments
A[j,k], A[k,q] = conj(A[k,q]), conj(A[j,k])
end
A[j,q] = conj(A[j,q]) # corner case
for k in q+1:n # final horizontal segments
A[j,k], A[q,k] = A[q,k], A[j,k]
end
return
end
Base.@propagate_inbounds function _swap_rowcols!(A, ::Type{LowerTriangular}, n, j, q)
j == q && return
@assert j < q
# swap rows and cols without touching the possibly undef-ed triangle
A[q, q] = A[j, j]
for k in 1:j-1 # initial horizontal segments
A[j,k], A[q,k] = A[q,k], A[j,k]
end
for k in j+1:q-1 # intermediate segments
A[k,j], A[q,k] = conj(A[q,k]), conj(A[k,j])
end
A[q,j] = conj(A[q,j]) # corner case
for k in q+1:n # final vertical segments
A[k,j], A[k,q] = A[k,q], A[k,j]
end
return
end
### BLAS/LAPACK element types
_cholpivoted!(A::StridedMatrix{<:BlasFloat}, ::Type{UpperTriangular}, tol::Real, check::Bool) =
LAPACK.pstrf!('U', A, tol)
_cholpivoted!(A::StridedMatrix{<:BlasFloat}, ::Type{LowerTriangular}, tol::Real, check::Bool) =
LAPACK.pstrf!('L', A, tol)
## Non BLAS/LAPACK element types (generic)
function _cholpivoted!(A::AbstractMatrix, ::Type{UpperTriangular}, tol::Real, check::Bool)
# checks
Base.require_one_based_indexing(A)
n = LinearAlgebra.checksquare(A)
# initialization
piv = collect(1:n)
dots = zeros(real(eltype(A)), n)
temp = similar(dots)
info = 0
rank = n

@inbounds begin
# first step
ajj, q = findmax(i -> real(A[i,i]), 1:n)
stop = tol < 0 ? eps(eltype(A))*n*abs(ajj) : tol
if ajj stop
return A, piv, convert(BlasInt, 0), convert(BlasInt, 1)
end
# swap
_swap_rowcols!(A, UpperTriangular, n, 1, q)
piv[1], piv[q] = piv[q], piv[1]
A[1,1] = ajj = sqrt(ajj)
@views A[1, 2:n] .= A[1, 2:n] ./ ajj

for j in 2:n
for k in j:n
dots[k] += abs2(A[j-1, k])
temp[k] = real(A[k,k]) - dots[k]
end
ajj, q = findmax(i -> temp[i], j:n)
if ajj stop
rank = j - 1
info = 1
break
end
q += j - 1
# swap
_swap_rowcols!(A, UpperTriangular, n, j, q)
dots[j], dots[q] = dots[q], dots[j]
piv[j], piv[q] = piv[q], piv[j]
# update
A[j,j] = ajj = sqrt(ajj)
@views if j < n
mul!(A[j, (j+1):n], A[1:(j-1), (j+1):n]', A[1:(j-1), j], -1, true)
A[j, j+1:n] ./= ajj
end
end
return A, piv, convert(BlasInt, rank), convert(BlasInt, info)
end
end
function _cholpivoted!(A::AbstractMatrix, ::Type{LowerTriangular}, tol::Real, check::Bool)
# checks
Base.require_one_based_indexing(A)
n = LinearAlgebra.checksquare(A)
# initialization
piv = collect(1:n)
dots = zeros(real(eltype(A)), n)
temp = similar(dots)
info = 0
rank = n

@inbounds begin
# first step
ajj, q = findmax(i -> real(A[i,i]), 1:n)
stop = tol < 0 ? eps(eltype(A))*n*abs(ajj) : tol
if ajj stop
return A, piv, convert(BlasInt, 0), convert(BlasInt, 1)
end
# swap
_swap_rowcols!(A, LowerTriangular, n, 1, q)
piv[1], piv[q] = piv[q], piv[1]
A[1,1] = ajj = sqrt(ajj)
@views A[2:n, 1] .= A[2:n, 1] ./ ajj

for j in 2:n
for k in j:n
dots[k] += abs2(A[k, j-1])
temp[k] = real(A[k,k]) - dots[k]
end
ajj, q = findmax(i -> temp[i], j:n)
q += j - 1
if ajj stop
rank = j - 1
info = 1
break
end
# swap
_swap_rowcols!(A, LowerTriangular, n, j, q)
dots[j], dots[q] = dots[q], dots[j]
piv[j], piv[q] = piv[q], piv[j]
# update
A[j,j] = ajj = sqrt(ajj)
@views if j < n
mul!(A[(j+1):n, j], A[(j+1):n, 1:(j-1)], A[j, 1:(j-1)], -1, true)
A[j+1:n, j] ./= ajj
end
end
return A, piv, convert(BlasInt, rank), convert(BlasInt, info)
end
end

# cholesky!. Destructive methods for computing Cholesky factorization of real symmetric
# or Hermitian matrix
Expand Down Expand Up @@ -295,7 +435,7 @@ Stacktrace:
function cholesky!(A::AbstractMatrix, ::NoPivot = NoPivot(); check::Bool = true)
checksquare(A)
if !ishermitian(A) # return with info = -1 if not Hermitian
check && checkpositivedefinite(-1)
check && checkpositivedefinite(convert(BlasInt, -1))
return Cholesky(A, 'U', convert(BlasInt, -1))
else
return cholesky!(Hermitian(A), NoPivot(); check = check)
Expand All @@ -305,23 +445,15 @@ end
@deprecate cholesky!(A::RealHermSymComplexHerm, ::Val{false}; check::Bool = true) cholesky!(A, NoPivot(); check) false

## With pivoting
### BLAS/LAPACK element types
function cholesky!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix},
::RowMaximum; tol = 0.0, check::Bool = true)
AA, piv, rank, info = LAPACK.pstrf!(A.uplo, A.data, tol)
C = CholeskyPivoted{eltype(AA),typeof(AA),typeof(piv)}(AA, A.uplo, piv, rank, tol, info)
### Non BLAS/LAPACK element types (generic).
function cholesky!(A::RealHermSymComplexHerm, ::RowMaximum; tol = 0.0, check::Bool = true)
AA, piv, rank, info = _cholpivoted!(A.data, A.uplo == 'U' ? UpperTriangular : LowerTriangular, tol, check)
C = CholeskyPivoted(AA, A.uplo, piv, rank, tol, info)
check && chkfullrank(C)
return C
end
@deprecate cholesky!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}, ::Val{true}; kwargs...) cholesky!(A, RowMaximum(); kwargs...) false

### Non BLAS/LAPACK element types (generic). Since generic fallback for pivoted Cholesky
### is not implemented yet we throw an error
cholesky!(A::RealHermSymComplexHerm{<:Real}, ::RowMaximum; tol = 0.0, check::Bool = true) =
throw(ArgumentError("generic pivoted Cholesky factorization is not implemented yet"))
@deprecate cholesky!(A::RealHermSymComplexHerm{<:Real}, ::Val{true}; kwargs...) cholesky!(A, RowMaximum(); kwargs...) false

### for AbstractMatrix, check that matrix is symmetric/Hermitian
"""
cholesky!(A::AbstractMatrix, RowMaximum(); tol = 0.0, check = true) -> CholeskyPivoted
Expand All @@ -335,7 +467,7 @@ function cholesky!(A::AbstractMatrix, ::RowMaximum; tol = 0.0, check::Bool = tru
if !ishermitian(A)
C = CholeskyPivoted(A, 'U', Vector{BlasInt}(),convert(BlasInt, 1),
tol, convert(BlasInt, -1))
check && chkfullrank(C)
check && checkpositivedefinite(-1)
return C
else
return cholesky!(Hermitian(A), RowMaximum(); tol = tol, check = check)
Expand Down Expand Up @@ -738,7 +870,7 @@ end

function chkfullrank(C::CholeskyPivoted)
if C.rank < size(C.factors, 1)
throw(RankDeficientException(C.info))
throw(RankDeficientException(C.rank))
end
end

Expand Down
109 changes: 52 additions & 57 deletions stdlib/LinearAlgebra/test/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,12 @@ end
end

#pivoted upper Cholesky
if eltya != BigFloat
cpapd = cholesky(apdh, RowMaximum())
unary_ops_tests(apdh, cpapd, ε*κ*n)
@test rank(cpapd) == n
@test all(diff(diag(real(cpapd.factors))).<=0.) # diagonal should be non-increasing
cpapd = cholesky(apdh, RowMaximum())
unary_ops_tests(apdh, cpapd, ε*κ*n)
@test rank(cpapd) == n
@test all(diff(diag(real(cpapd.factors))).<=0.) # diagonal should be non-increasing

@test cpapd.P*cpapd.L*cpapd.U*cpapd.P' apd
end
@test cpapd.P*cpapd.L*cpapd.U*cpapd.P' apd

for eltyb in (Float32, Float64, ComplexF32, ComplexF64, Int)
b = eltyb == Int ? rand(1:5, n, 2) : convert(Matrix{eltyb}, eltyb <: Complex ? complex.(breal, bimg) : breal)
Expand All @@ -167,22 +165,17 @@ end

@test norm(a*(capd\(a'*b)) - b,1)/norm(b,1) <= ε*κ*n # Ad hoc, revisit

if eltya != BigFloat && eltyb != BigFloat
lapd = cholesky(apdhL)
@test norm(apd * (lapd\b) - b)/norm(b) <= ε*κ*n
@test norm(apd * (lapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n
end

if eltya != BigFloat && eltyb != BigFloat # Note! Need to implement pivoted Cholesky decomposition in julia
lapd = cholesky(apdhL)
@test norm(apd * (lapd\b) - b)/norm(b) <= ε*κ*n
@test norm(apd * (lapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n

cpapd = cholesky(apdh, RowMaximum())
@test norm(apd * (cpapd\b) - b)/norm(b) <= ε*κ*n # Ad hoc, revisit
@test norm(apd * (cpapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n
cpapd = cholesky(apdh, RowMaximum())
@test norm(apd * (cpapd\b) - b)/norm(b) <= ε*κ*n # Ad hoc, revisit
@test norm(apd * (cpapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n

lpapd = cholesky(apdhL, RowMaximum())
@test norm(apd * (lpapd\b) - b)/norm(b) <= ε*κ*n # Ad hoc, revisit
@test norm(apd * (lpapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n
end
lpapd = cholesky(apdhL, RowMaximum())
@test norm(apd * (lpapd\b) - b)/norm(b) <= ε*κ*n # Ad hoc, revisit
@test norm(apd * (lpapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n
end
end

Expand All @@ -201,13 +194,11 @@ end
ldiv!(capd, BB)
@test norm(apd \ B - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(apd * BB - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
if eltya != BigFloat
cpapd = cholesky(apdh, RowMaximum())
BB = copy(B)
ldiv!(cpapd, BB)
@test norm(apd \ B - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(apd * BB - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
end
cpapd = cholesky(apdh, RowMaximum())
BB = copy(B)
ldiv!(cpapd, BB)
@test norm(apd \ B - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(apd * BB - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
end
end

Expand All @@ -232,18 +223,16 @@ end
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
if eltya != BigFloat
cpapd = cholesky(eltya <: Complex ? apdh : apds, RowMaximum())
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
cpapd = cholesky(eltya <: Complex ? apdhL : apdsL, RowMaximum())
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
end
cpapd = cholesky(eltya <: Complex ? apdh : apds, RowMaximum())
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
cpapd = cholesky(eltya <: Complex ? apdhL : apdsL, RowMaximum())
BB = copy(B)
rdiv!(BB, cpapd)
@test norm(B / apd - BB, 1) / norm(BB, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
@test norm(BB * apd - B, 1) / norm(B, 1) <= (3n^2 + n + n^3*ε)*ε/(1-(n+1)*ε)*κ
end
end
if eltya <: BlasFloat
Expand Down Expand Up @@ -274,19 +263,29 @@ end
@test !LinearAlgebra.issuccess(cholesky(M; check = false))
@test !LinearAlgebra.issuccess(cholesky!(copy(M); check = false))
end
if T !== BigFloat # generic pivoted cholesky is not implemented
for M in (A, Hermitian(A), B)
@test_throws RankDeficientException cholesky(M, RowMaximum())
@test_throws RankDeficientException cholesky!(copy(M), RowMaximum())
@test_throws RankDeficientException cholesky(M, RowMaximum(); check = true)
@test_throws RankDeficientException cholesky!(copy(M), RowMaximum(); check = true)
@test !LinearAlgebra.issuccess(cholesky(M, RowMaximum(); check = false))
@test !LinearAlgebra.issuccess(cholesky!(copy(M), RowMaximum(); check = false))
C = cholesky(M, RowMaximum(); check = false)
@test_throws RankDeficientException chkfullrank(C)
C = cholesky!(copy(M), RowMaximum(); check = false)
@test_throws RankDeficientException chkfullrank(C)
end
for M in (A, Hermitian(A)) # hermitian, but not semi-positive definite
@test_throws RankDeficientException cholesky(M, RowMaximum())
@test_throws RankDeficientException cholesky!(copy(M), RowMaximum())
@test_throws RankDeficientException cholesky(M, RowMaximum(); check = true)
@test_throws RankDeficientException cholesky!(copy(M), RowMaximum(); check = true)
@test !issuccess(cholesky(M, RowMaximum(); check = false))
@test !issuccess(cholesky!(copy(M), RowMaximum(); check = false))
C = cholesky(M, RowMaximum(); check = false)
@test_throws RankDeficientException chkfullrank(C)
C = cholesky!(copy(M), RowMaximum(); check = false)
@test_throws RankDeficientException chkfullrank(C)
end
for M in (B,) # not hermitian
@test_throws PosDefException(-1) cholesky(M, RowMaximum())
@test_throws PosDefException(-1) cholesky!(copy(M), RowMaximum())
@test_throws PosDefException(-1) cholesky(M, RowMaximum(); check = true)
@test_throws PosDefException(-1) cholesky!(copy(M), RowMaximum(); check = true)
@test !issuccess(cholesky(M, RowMaximum(); check = false))
@test !issuccess(cholesky!(copy(M), RowMaximum(); check = false))
C = cholesky(M, RowMaximum(); check = false)
@test_throws RankDeficientException chkfullrank(C)
C = cholesky!(copy(M), RowMaximum(); check = false)
@test_throws RankDeficientException chkfullrank(C)
end
@test !isposdef(A)
str = sprint((io, x) -> show(io, "text/plain", x), cholesky(A; check = false))
Expand Down Expand Up @@ -368,10 +367,6 @@ end
end
end

@testset "fail for non-BLAS element types" begin
@test_throws ArgumentError cholesky!(Hermitian(rand(Float16, 5,5)), RowMaximum())
end

@testset "cholesky Diagonal" begin
# real
d = abs.(randn(3)) .+ 0.1
Expand Down

0 comments on commit b8a058a

Please sign in to comment.