diff --git a/stdlib/LinearAlgebra/src/svd.jl b/stdlib/LinearAlgebra/src/svd.jl index 513dd408df03d..68bce4793661f 100644 --- a/stdlib/LinearAlgebra/src/svd.jl +++ b/stdlib/LinearAlgebra/src/svd.jl @@ -89,22 +89,36 @@ default_svd_alg(A) = DivideAndConquer() `svd!` is the same as [`svd`](@ref), but saves space by overwriting the input `A`, instead of creating a copy. See documentation of [`svd`](@ref) for details. """ -function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T<:BlasFloat - m,n = size(A) +function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T<:BlasFloat} + m, n = size(A) if m == 0 || n == 0 - u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n)) + u, s, vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n)) else - u,s,vt = _svd!(A,full,alg) + u, s, vt = _svd!(A, full, alg) + end + SVD(u, s, vt) +end +function svd!(A::StridedVector{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T<:BlasFloat} + m = length(A) + normA = norm(A) + if iszero(normA) + return SVD(Matrix{T}(I, m, full ? m : 1), [normA], ones(T, 1, 1)) + elseif !full + normalize!(A) + return SVD(reshape(A, (m, 1)), [normA], ones(T, 1, 1)) + else + u, s, vt = _svd!(reshape(A, (m, 1)), full, alg) + return SVD(u, s, vt) end - SVD(u,s,vt) end - -_svd!(A::StridedMatrix{T}, full::Bool, alg::Algorithm) where T<:BlasFloat = throw(ArgumentError("Unsupported value for `alg` keyword.")) -_svd!(A::StridedMatrix{T}, full::Bool, alg::DivideAndConquer) where T<:BlasFloat = LAPACK.gesdd!(full ? 'A' : 'S', A) -function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where T<:BlasFloat +_svd!(A::StridedMatrix{T}, full::Bool, alg::Algorithm) where {T<:BlasFloat} = + throw(ArgumentError("Unsupported value for `alg` keyword.")) +_svd!(A::StridedMatrix{T}, full::Bool, alg::DivideAndConquer) where {T<:BlasFloat} = + LAPACK.gesdd!(full ? 'A' : 'S', A) +function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where {T<:BlasFloat} c = full ? 'A' : 'S' - u,s,vt = LAPACK.gesvd!(c, c, A) + u, s, vt = LAPACK.gesvd!(c, c, A) end @@ -153,7 +167,7 @@ julia> Uonly == U true ``` """ -function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T +function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T} svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg) end function svd(x::Number; full::Bool = false, alg::Algorithm = default_svd_alg(x)) @@ -190,7 +204,7 @@ See also [`svdvals`](@ref) and [`svd`](@ref). ``` """ svdvals!(A::StridedMatrix{T}) where {T<:BlasFloat} = isempty(A) ? zeros(real(T), 0) : LAPACK.gesdd!('N', A)[2] -svdvals(A::AbstractMatrix{<:BlasFloat}) = svdvals!(copy(A)) +svdvals!(A::StridedVector{T}) where {T<:BlasFloat} = svdvals!(reshape(A, (length(A), 1))) """ svdvals(A) @@ -214,7 +228,10 @@ julia> svdvals(A) 0.0 ``` """ -svdvals(A::AbstractMatrix{T}) where T = svdvals!(copy_oftype(A, eigtype(T))) +svdvals(A::AbstractMatrix{T}) where {T} = svdvals!(copy_oftype(A, eigtype(T))) +svdvals(A::AbstractVector{T}) where {T} = [convert(eigtype(T), norm(A))] +svdvals(A::AbstractMatrix{<:BlasFloat}) = svdvals!(copy(A)) +svdvals(A::AbstractVector{<:BlasFloat}) = [norm(A)] svdvals(x::Number) = abs(x) svdvals(S::SVD{<:Any,T}) where {T} = (S.S)::Vector{T} diff --git a/stdlib/LinearAlgebra/test/svd.jl b/stdlib/LinearAlgebra/test/svd.jl index d83d2de0f3c88..30dd6db300eb9 100644 --- a/stdlib/LinearAlgebra/test/svd.jl +++ b/stdlib/LinearAlgebra/test/svd.jl @@ -8,6 +8,22 @@ using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, QRPivoted @testset "Simple svdvals / svd tests" begin ≊(x,y) = isapprox(x,y,rtol=1e-15) + m = [2, 0] + @test @inferred(svdvals(m)) ≊ [2] + @test @inferred(svdvals!(float(m))) ≊ [2] + for sf in (@inferred(svd(m)), @inferred(svd!(float(m)))) + @test sf.S ≊ [2] + @test sf.U'sf.U ≊ [1] + @test sf.Vt'sf.Vt ≊ [1] + @test sf.U*Diagonal(sf.S)*sf.Vt' ≊ m + end + F = @inferred svd(m, full=true) + @test size(F.U) == (2, 2) + @test F.S ≊ [2] + @test F.U'F.U ≊ Matrix(I, 2, 2) + @test F.Vt'*F.Vt ≊ [1] + @test @inferred(svdvals(3:4)) ≊ [5] + m1 = [2 0; 0 0] m2 = [2 -2; 1 1]/sqrt(2) m2c = Complex.([2 -2; 1 1]/sqrt(2)) @@ -15,8 +31,8 @@ using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, QRPivoted @test @inferred(svdvals(m2)) ≊ [2, 1] @test @inferred(svdvals(m2c)) ≊ [2, 1] - sf1 = svd(m1) - sf2 = svd(m2) + sf1 = @inferred svd(m1) + sf2 = @inferred svd(m2) @test sf1.S ≊ [2, 0] @test sf2.S ≊ [2, 1] # U & Vt are unitary