Skip to content

Commit

Permalink
Merge pull request #8661 from JuliaLang/anj/bdsqr
Browse files Browse the repository at this point in the history
Fix bdsqr, add and reorganize tests
  • Loading branch information
andreasnoack committed Oct 14, 2014
2 parents 39024a7 + 06f4fc4 commit e4028ff
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 47 deletions.
40 changes: 28 additions & 12 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3102,25 +3102,41 @@ for (bdsqr, relty, elty) in
#*> a real N-by-N (upper or lower) bidiagonal matrix B using the implicit
#*> zero-shift QR algorithm.
function bdsqr!(uplo::BlasChar, d::Vector{$relty}, e_::Vector{$relty},
vt::StridedMatrix{$elty}, u::StridedMatrix{$elty}, c::StridedMatrix{$elty})
@chkuplo
Vt::StridedMatrix{$elty}, U::StridedMatrix{$elty}, C::StridedMatrix{$elty})

# Extract number
n = length(d)
if length(e_) != n-1 throw(DimensionMismatch("bdsqr!")) end
ncvt, nru, ncc = size(vt, 2), size(u, 1), size(c, 2)
ldvt, ldu, ldc = max(1,stride(vt,2)), max(1,stride(u,2)), max(1,stride(c,2))
ncvt, nru, ncc = size(Vt, 2), size(U, 1), size(C, 2)
ldvt, ldu, ldc = max(1, stride(Vt,2)), max(1, stride(U, 2)), max(1, stride(C,2))

# Do checks
@chkuplo
length(e_) == n - 1 || throw(DimensionMismatch("off-diagonal has length $(length(e_)) but should have length $(n - 1)"))
if ncvt > 0
ldvt >= n || throw(DimensionMismatch("leading dimension of Vt must be at least $n"))
end
ldu >= nru || throw(DimensionMismatch("leading dimension of U must be at least $nru"))
size(U, 2) == n || throw(DimensionMismatch("U must have $n columns but have $(size(U, 2))"))
if ncc > 0
ldc >= n || throw(DimensionMismatch("leading dimension of C must be at least $n"))
end

# Allocate
work = Array($elty, 4n)
info = Array(BlasInt,1)

ccall(($(string(bdsqr)),liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, ncvt, &nru, &ncc,
d, e_, vt, &ldvt, u,
&ldu, c, &ldc, work, info)
(Ptr{BlasChar}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &ncvt, &nru,
&ncc, d, e_, Vt,
&ldvt, U, &ldu, C, &ldc,
work, info)

@lapackerror
d, vt, u, c #singular values in descending order, P**T * VT, U * Q, Q**T * C
d, Vt, U, C #singular values in descending order, P**T * VT, U * Q, Q**T * C
end
end
end
Expand Down
54 changes: 54 additions & 0 deletions test/linalg/lapack.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using Base.Test
using Base.LAPACK.bdsqr!

let # syevr
srand(123)
Ainit = randn(5,5)
for elty in (Float32, Float64, Complex64, Complex128)
if elty == Complex64 || elty == Complex128
A = complex(Ainit, Ainit)
else
A = Ainit
end
A = convert(Array{elty, 2}, A)
Asym = A'A
vals, Z = LAPACK.syevr!('V', copy(Asym))
@test_approx_eq Z*scale(vals, Z') Asym
@test all(vals .> 0.0)
@test_approx_eq LAPACK.syevr!('N','V','U',copy(Asym),0.0,1.0,4,5,-1.0)[1] vals[vals .< 1.0]
@test_approx_eq LAPACK.syevr!('N','I','U',copy(Asym),0.0,1.0,4,5,-1.0)[1] vals[4:5]
@test_approx_eq vals LAPACK.syev!('N','U',copy(Asym))
end
end

let # Test gglse
for elty in (Float32, Float64, Complex64, Complex128)
A = convert(Array{elty, 2}, [1 1 1 1; 1 3 1 1; 1 -1 3 1; 1 1 1 3; 1 1 1 -1])
c = convert(Array{elty, 1}, [2, 1, 6, 3, 1])
B = convert(Array{elty, 2}, [1 1 1 -1; 1 -1 1 1; 1 1 -1 1])
d = convert(Array{elty, 1}, [1, 3, -1])
@test_approx_eq LAPACK.gglse!(A, c, B, d)[1] convert(Array{elty}, [0.5, -0.5, 1.5, 0.5])
end
end

let # xbdsqr
n = 10
for elty in (Float32, Float64)
d, e = convert(Vector{elty}, randn(n)), convert(Vector{elty}, randn(n - 1))
U, Vt, C = eye(elty, n), eye(elty, n), eye(elty, n)
s, _ = bdsqr!('U', copy(d), copy(e), Vt, U, C)
@test_approx_eq full(Bidiagonal(d, e, true)) U*Diagonal(s)*Vt

@test_throws ArgumentError bdsqr!('A', d, e, Vt, U, C)
@test_throws DimensionMismatch bdsqr!('U', d, [e, 1], Vt, U, C)
@test_throws DimensionMismatch bdsqr!('U', d, e, Vt[1:end - 1, :], U, C)
@test_throws DimensionMismatch bdsqr!('U', d, e, Vt, U[:,1:end - 1], C)
@test_throws DimensionMismatch bdsqr!('U', d, e, Vt, U, C[1:end - 1, :])
end
end

let # Issue #7886
x, r = LAPACK.gelsy!([0 1; 0 2; 0 3.], [2, 4, 6.])
@test_approx_eq x [0,2]
@test r == 1
end
29 changes: 0 additions & 29 deletions test/linalg2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,35 +170,6 @@ for elty in (Complex64, Complex128)
@test_approx_eq triu(LinAlg.BLAS.her2k('U','C',U,V)) triu(U'*V + V'*U)
end

# LAPACK tests
srand(123)
Ainit = randn(5,5)
for elty in (Float32, Float64, Complex64, Complex128)
# syevr!
if elty == Complex64 || elty == Complex128
A = complex(Ainit, Ainit)
else
A = Ainit
end
A = convert(Array{elty, 2}, A)
Asym = A'A
vals, Z = LinAlg.LAPACK.syevr!('V', copy(Asym))
@test_approx_eq Z*scale(vals, Z') Asym
@test all(vals .> 0.0)
@test_approx_eq LinAlg.LAPACK.syevr!('N','V','U',copy(Asym),0.0,1.0,4,5,-1.0)[1] vals[vals .< 1.0]
@test_approx_eq LinAlg.LAPACK.syevr!('N','I','U',copy(Asym),0.0,1.0,4,5,-1.0)[1] vals[4:5]
@test_approx_eq vals LinAlg.LAPACK.syev!('N','U',copy(Asym))
end

# Test gglse
for elty in (Float32, Float64, Complex64, Complex128)
A = convert(Array{elty, 2}, [1 1 1 1; 1 3 1 1; 1 -1 3 1; 1 1 1 3; 1 1 1 -1])
c = convert(Array{elty, 1}, [2, 1, 6, 3, 1])
B = convert(Array{elty, 2}, [1 1 1 -1; 1 -1 1 1; 1 1 -1 1])
d = convert(Array{elty, 1}, [1, 3, -1])
@test_approx_eq LinAlg.LAPACK.gglse!(A, c, B, d)[1] convert(Array{elty}, [0.5, -0.5, 1.5, 0.5])
end

# Test givens rotations
for elty in (Float32, Float64, Complex64, Complex128)
if elty <: Real
Expand Down
5 changes: 0 additions & 5 deletions test/linalg4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,6 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64})
end
end

# Issue #7886
x, r = LAPACK.gelsy!([0 1; 0 2; 0 3.], [2, 4, 6.])
@test_approx_eq x [0,2]
@test r == 1

# Issue #7933
A7933 = [1 2; 3 4]
B7933 = copy(A7933)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ tests = (ARGS==["all"] || isempty(ARGS)) ? testnames : ARGS
if "linalg" in tests
# specifically selected case
filter!(x -> x != "linalg", tests)
prepend!(tests, ["linalg1", "linalg2", "linalg3", "linalg4", "linalg/triangular", "linalg/tridiag"])
prepend!(tests, ["linalg1", "linalg2", "linalg3", "linalg4", "linalg/lapack", "linalg/triangular", "linalg/tridiag"])
end

net_required_for = ["socket", "parallel"]
Expand Down

0 comments on commit e4028ff

Please sign in to comment.