Skip to content

Commit

Permalink
alg keyword for svd and svd! (JuliaLang#31057)
Browse files Browse the repository at this point in the history
* alg keyword for LinearAlgebra.svd

* SVDAlgorithms -> Algorithms

* default_svd_alg

* refined docstring

Co-Authored-By: Andreas Noack <[email protected]>

* rename to QRIteration; _svd! dispatch

* compat annotation
  • Loading branch information
carstenbauer authored and andreasnoack committed Aug 15, 2019
1 parent 0eabe22 commit 5e584fb
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 16 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Standard library changes
* The BLAS submodule no longer exports `dot`, which conflicts with that in LinearAlgebra ([#31838]).
* `diagm` and `spdiagm` now accept optional `m,n` initial arguments to specify a size ([#31654]).
* `Hessenberg` factorizations `H` now support efficient shifted solves `(H+µI) \ b` and determinants, and use a specialized tridiagonal factorization for Hermitian matrices. There is also a new `UpperHessenberg` matrix type ([#31853]).
* Added keyword argument `alg` to `svd` and `svd!` that allows one to switch between different SVD algorithms ([#31057]).
* Five-argument `mul!(C, A, B, α, β)` now implements inplace multiplication fused with addition _C = A B α + C β_ ([#23919]).

#### SparseArrays
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ else
const BlasInt = Int32
end


abstract type Algorithm end
struct DivideAndConquer <: Algorithm end
struct QRIteration <: Algorithm end


# Check that stride of matrix/vector is 1
# Writing like this to avoid splatting penalty when called with multiple arguments,
# see PR 16416
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ function svd!(M::Bidiagonal{<:BlasReal}; full::Bool = false)
d, e, U, Vt, Q, iQ = LAPACK.bdsdc!(M.uplo, 'I', M.dv, M.ev)
SVD(U, d, Vt)
end
function svd(M::Bidiagonal; full::Bool = false)
svd!(copy(M), full = full)
function svd(M::Bidiagonal; kw...)
svd!(copy(M), kw...)
end

####################
Expand Down
47 changes: 34 additions & 13 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,19 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher
convert(AbstractArray{T}, Vt))
end


# iteration for destructuring into components
Base.iterate(S::SVD) = (S.U, Val(:S))
Base.iterate(S::SVD, ::Val{:S}) = (S.S, Val(:V))
Base.iterate(S::SVD, ::Val{:V}) = (S.V, Val(:done))
Base.iterate(S::SVD, ::Val{:done}) = nothing


default_svd_alg(A) = DivideAndConquer()


"""
svd!(A; full::Bool = false) -> SVD
svd!(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD
`svd!` is the same as [`svd`](@ref), but saves space by
overwriting the input `A`, instead of creating a copy.
Expand Down Expand Up @@ -92,18 +97,28 @@ julia> A
0.0 0.0 -2.0 0.0 0.0
```
"""
function svd!(A::StridedMatrix{T}; full::Bool = false) where T<:BlasFloat
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T<:BlasFloat
m,n = size(A)
if m == 0 || n == 0
u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
else
u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A)
u,s,vt = _svd!(A,full,alg)
end
SVD(u,s,vt)
end


_svd!(A::StridedMatrix{T}, full::Bool, alg::Algorithm) where T<:BlasFloat = throw(ArgumentError("Unsupported value for `alg` keyword."))
_svd!(A::StridedMatrix{T}, full::Bool, alg::DivideAndConquer) where T<:BlasFloat = LAPACK.gesdd!(full ? 'A' : 'S', A)
function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where T<:BlasFloat
c = full ? 'A' : 'S'
u,s,vt = LAPACK.gesvd!(c, c, A)
end



"""
svd(A; full::Bool = false) -> SVD
svd(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD
Compute the singular value decomposition (SVD) of `A` and return an `SVD` object.
Expand All @@ -120,6 +135,12 @@ and `V` is `N \\times N`, while in the thin factorization `U` is `M
\\times K` and `V` is `N \\times K`, where `K = \\min(M,N)` is the
number of singular values.
If `alg = DivideAndConquer()` a divide-and-conquer algorithm is used to calculate the SVD.
Another (typically slower but more accurate) option is `alg = QRIteration()`.
!!! compat "Julia 1.3"
The `alg` keyword argument requires Julia 1.3 or later.
# Examples
```jldoctest
julia> A = [1. 0. 0. 0. 2.; 0. 0. 3. 0. 0.; 0. 0. 0. 0. 0.; 0. 2. 0. 0. 0.]
Expand All @@ -144,21 +165,21 @@ julia> u == F.U && s == F.S && v == F.V
true
```
"""
function svd(A::StridedVecOrMat{T}; full::Bool = false) where T
svd!(copy_oftype(A, eigtype(T)), full = full)
function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T
svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg)
end
function svd(x::Number; full::Bool = false)
function svd(x::Number; full::Bool = false, alg::Algorithm = default_svd_alg(x))
SVD(x == 0 ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1))
end
function svd(x::Integer; full::Bool = false)
svd(float(x), full = full)
function svd(x::Integer; full::Bool = false, alg::Algorithm = default_svd_alg(x))
svd(float(x), full = full, alg = alg)
end
function svd(A::Adjoint; full::Bool = false)
s = svd(A.parent, full = full)
function svd(A::Adjoint; full::Bool = false, alg::Algorithm = default_svd_alg(A))
s = svd(A.parent, full = full, alg = alg)
return SVD(s.Vt', s.S, s.U')
end
function svd(A::Transpose; full::Bool = false)
s = svd(A.parent, full = full)
function svd(A::Transpose; full::Bool = false, alg::Algorithm = default_svd_alg(A))
s = svd(A.parent, full = full, alg = alg)
return SVD(transpose(s.Vt), s.S, transpose(s.U))
end

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,7 @@ eigen(A::AbstractTriangular) = Eigen(eigvals(A), eigvecs(A))
# Generic singular systems
for func in (:svd, :svd!, :svdvals)
@eval begin
($func)(A::AbstractTriangular) = ($func)(copyto!(similar(parent(A)), A))
($func)(A::AbstractTriangular; kwargs...) = ($func)(copyto!(similar(parent(A)), A); kwargs...)
end
end

Expand Down
24 changes: 24 additions & 0 deletions stdlib/LinearAlgebra/test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,28 @@ aimg = randn(n,n)/2
end
end



@testset "SVD Algorithms" begin
(x,y) = isapprox(x,y,rtol=1e-15)

x = [0.1 0.2; 0.3 0.4]

for alg in [LinearAlgebra.QRIteration(), LinearAlgebra.DivideAndConquer()]
sx1 = svd(x, alg = alg)
@test sx1.U * Diagonal(sx1.S) * sx1.Vt x
@test sx1.V * sx1.Vt I
@test sx1.U * sx1.U' I
@test all(sx1.S .≥ 0)

sx2 = svd!(copy(x), alg = alg)
@test sx2.U * Diagonal(sx2.S) * sx2.Vt x
@test sx2.V * sx2.Vt I
@test sx2.U * sx2.U' I
@test all(sx2.S .≥ 0)
end
end



end # module TestSVD

0 comments on commit 5e584fb

Please sign in to comment.