Skip to content

Commit

Permalink
Allow SVD of vectors (JuliaLang#39087)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Jan 14, 2021
1 parent 006f90d commit bbe4cfa
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
43 changes: 30 additions & 13 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,36 @@ default_svd_alg(A) = DivideAndConquer()
`svd!` is the same as [`svd`](@ref), but saves space by
overwriting the input `A`, instead of creating a copy. See documentation of [`svd`](@ref) for details.
"""
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T<:BlasFloat
m,n = size(A)
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T<:BlasFloat}
m, n = size(A)
if m == 0 || n == 0
u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
u, s, vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
else
u,s,vt = _svd!(A,full,alg)
u, s, vt = _svd!(A, full, alg)
end
SVD(u, s, vt)
end
function svd!(A::StridedVector{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T<:BlasFloat}
m = length(A)
normA = norm(A)
if iszero(normA)
return SVD(Matrix{T}(I, m, full ? m : 1), [normA], ones(T, 1, 1))
elseif !full
normalize!(A)
return SVD(reshape(A, (m, 1)), [normA], ones(T, 1, 1))
else
u, s, vt = _svd!(reshape(A, (m, 1)), full, alg)
return SVD(u, s, vt)
end
SVD(u,s,vt)
end


_svd!(A::StridedMatrix{T}, full::Bool, alg::Algorithm) where T<:BlasFloat = throw(ArgumentError("Unsupported value for `alg` keyword."))
_svd!(A::StridedMatrix{T}, full::Bool, alg::DivideAndConquer) where T<:BlasFloat = LAPACK.gesdd!(full ? 'A' : 'S', A)
function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where T<:BlasFloat
_svd!(A::StridedMatrix{T}, full::Bool, alg::Algorithm) where {T<:BlasFloat} =
throw(ArgumentError("Unsupported value for `alg` keyword."))
_svd!(A::StridedMatrix{T}, full::Bool, alg::DivideAndConquer) where {T<:BlasFloat} =
LAPACK.gesdd!(full ? 'A' : 'S', A)
function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where {T<:BlasFloat}
c = full ? 'A' : 'S'
u,s,vt = LAPACK.gesvd!(c, c, A)
u, s, vt = LAPACK.gesvd!(c, c, A)
end


Expand Down Expand Up @@ -153,7 +167,7 @@ julia> Uonly == U
true
```
"""
function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T
function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T}
svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg)
end
function svd(x::Number; full::Bool = false, alg::Algorithm = default_svd_alg(x))
Expand Down Expand Up @@ -190,7 +204,7 @@ See also [`svdvals`](@ref) and [`svd`](@ref).
```
"""
svdvals!(A::StridedMatrix{T}) where {T<:BlasFloat} = isempty(A) ? zeros(real(T), 0) : LAPACK.gesdd!('N', A)[2]
svdvals(A::AbstractMatrix{<:BlasFloat}) = svdvals!(copy(A))
svdvals!(A::StridedVector{T}) where {T<:BlasFloat} = svdvals!(reshape(A, (length(A), 1)))

"""
svdvals(A)
Expand All @@ -214,7 +228,10 @@ julia> svdvals(A)
0.0
```
"""
svdvals(A::AbstractMatrix{T}) where T = svdvals!(copy_oftype(A, eigtype(T)))
svdvals(A::AbstractMatrix{T}) where {T} = svdvals!(copy_oftype(A, eigtype(T)))
svdvals(A::AbstractVector{T}) where {T} = [convert(eigtype(T), norm(A))]
svdvals(A::AbstractMatrix{<:BlasFloat}) = svdvals!(copy(A))
svdvals(A::AbstractVector{<:BlasFloat}) = [norm(A)]
svdvals(x::Number) = abs(x)
svdvals(S::SVD{<:Any,T}) where {T} = (S.S)::Vector{T}

Expand Down
20 changes: 18 additions & 2 deletions stdlib/LinearAlgebra/test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,31 @@ using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, QRPivoted
@testset "Simple svdvals / svd tests" begin
(x,y) = isapprox(x,y,rtol=1e-15)

m = [2, 0]
@test @inferred(svdvals(m)) [2]
@test @inferred(svdvals!(float(m))) [2]
for sf in (@inferred(svd(m)), @inferred(svd!(float(m))))
@test sf.S [2]
@test sf.U'sf.U [1]
@test sf.Vt'sf.Vt [1]
@test sf.U*Diagonal(sf.S)*sf.Vt' m
end
F = @inferred svd(m, full=true)
@test size(F.U) == (2, 2)
@test F.S [2]
@test F.U'F.U Matrix(I, 2, 2)
@test F.Vt'*F.Vt [1]
@test @inferred(svdvals(3:4)) [5]

m1 = [2 0; 0 0]
m2 = [2 -2; 1 1]/sqrt(2)
m2c = Complex.([2 -2; 1 1]/sqrt(2))
@test @inferred(svdvals(m1)) [2, 0]
@test @inferred(svdvals(m2)) [2, 1]
@test @inferred(svdvals(m2c)) [2, 1]

sf1 = svd(m1)
sf2 = svd(m2)
sf1 = @inferred svd(m1)
sf2 = @inferred svd(m2)
@test sf1.S [2, 0]
@test sf2.S [2, 1]
# U & Vt are unitary
Expand Down

0 comments on commit bbe4cfa

Please sign in to comment.