Skip to content

Commit

Permalink
Test sytrs and trtri
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 14, 2023
1 parent f6107e8 commit 9e5c369
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 18 deletions.
4 changes: 2 additions & 2 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
@eval begin
function ormqr!(side::Char,
trans::Char,
A::CuMatrix{$elty},
A::StridedCuMatrix{$elty},
tau::CuVector{$elty},
C::CuVecOrMat{$elty})
C::StridedCuVecOrMat{$elty})

# Support transa = 'C' for real matrices
trans = $elty <: Real && trans == 'C' ? 'T' : trans
Expand Down
60 changes: 45 additions & 15 deletions lib/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# TO DO: cusolverDnSetAdvOptions

# Xpotrf
function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
Expand All @@ -20,6 +20,7 @@ function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
sizeof(buffer_cpu), info)
end

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
Expand All @@ -35,8 +36,12 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
info = CuArray{Cint}(undef, 1)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

cusolverDnXpotrs(dense_handle(), params, uplo, n, nrhs, T, A, lda, T, B, ldb, info)

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
Expand All @@ -48,6 +53,8 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{T}) where {T <: BlasFloat
m, n = size(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
Expand All @@ -62,6 +69,7 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{T}) where {T <: BlasFloat
sizeof(buffer_cpu), info)
end

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
Expand All @@ -82,8 +90,12 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
info = CuArray{Cint}(undef, 1)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

cusolverDnXgetrs(dense_handle(), params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info)

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
Expand All @@ -95,6 +107,8 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
m, n = size(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
Expand All @@ -109,6 +123,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
buffer_cpu, sizeof(buffer_cpu), info)
end

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
Expand All @@ -122,7 +137,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
end

# Xsytrs
function Xsytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
Expand All @@ -149,8 +164,8 @@ function Xsytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Strid
end

# Xtrtri
function Xtrtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chktrans(trans)
function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
chkdiag(diag)
n = checksquare(A)
lda = max(1, stride(A, 2))
Expand All @@ -177,29 +192,31 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
m, n = size(A)
R = real(T)
(m < n) && throw(ArgumentError("The number of rows of A ($m) must be greater or equal to the number of columns of A ($n)"))
if jobu == 'A'
U = CuMatrix{T}(undef, m, m)
U = if jobu == 'A'
CuMatrix{T}(undef, m, m)
elseif jobu == 'S' || jobu == 'O'
U = CuMatrix{T}(undef, m, min(m, n))
CuMatrix{T}(undef, m, min(m, n))
elseif jobu == 'N'
U = CU_NULL
CU_NULL
else
throw(ArgumentError("jobu is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
end
Σ = CuVector{R}(undef, min(m, n))
if jobvt == 'A'
U = CuMatrix{T}(undef, n, n)
Vt = if jobvt == 'A'
CuMatrix{T}(undef, n, n)
elseif jobvt == 'S' || jobvt == 'O'
U = CuMatrix{T}(undef, min(m, n), n)
CuMatrix{T}(undef, min(m, n), n)
elseif jobvt == 'N'
U = CU_NULL
CU_NULL
else
throw(ArgumentError("jobvt is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
end
lda = max(1, stride(A, 2))
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
info = CuArray{Cint}(undef, 1)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
Expand All @@ -215,6 +232,7 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info)
end

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chklapackerror(BlasInt(flag))
Expand All @@ -227,6 +245,8 @@ end
# lda = max(1, stride(A, 2))
# info = CuArray{Cint}(undef, 1)
# h_err_sigma = Ref{Cdouble}(0)
# params = Ref{cusolverDnParams_t}(C_NULL)
# cusolverDnCreateParams(params)
#
# function bufferSize()
# out_cpu = Ref{Csize_t}(0)
Expand All @@ -243,6 +263,7 @@ end
# buffer_cpu, sizeof(buffer_cpu), info, h_err_sigma)
# end
#
# cusolverDnDestroyParams(params[])
# flag = @allowscalar info[1]
# unsafe_free!(info)
# chklapackerror(BlasInt(flag))
Expand All @@ -254,6 +275,8 @@ end
# m, n = size(A)
# lda = max(1, stride(A, 2))
# info = CuArray{Cint}(undef, 1)
# params = Ref{cusolverDnParams_t}(C_NULL)
# cusolverDnCreateParams(params)
#
# function bufferSize()
# out_cpu = Ref{Csize_t}(0)
Expand All @@ -272,6 +295,7 @@ end
# sizeof(buffer_cpu), info)
# end
#
# cusolverDnDestroyParams(params[])
# flag = @allowscalar info[1]
# unsafe_free!(info)
# chklapackerror(BlasInt(flag))
Expand All @@ -286,6 +310,8 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
W = CuArray{R}(undef, n)
params = Ref{cusolverDnParams_t}(C_NULL)
cusolverDnCreateParams(params)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
Expand All @@ -300,6 +326,7 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas
buffer_cpu, sizeof(buffer_cpu), info)
end

cusolverDnDestroyParams(params[])
flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
Expand All @@ -318,6 +345,8 @@ end
# lda = max(1, stride(A, 2))
# info = CuArray{Cint}(undef, 1)
# ...
# params = Ref{cusolverDnParams_t}(C_NULL)
# cusolverDnCreateParams(params)
#
# function bufferSize()
# out_cpu = Ref{Csize_t}(0)
Expand All @@ -333,6 +362,7 @@ end
# sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info)
# end
#
# cusolverDnDestroyParams(params[])
# flag = @allowscalar info[1]
# unsafe_free!(info)
# chkargsok(BlasInt(flag))
Expand All @@ -347,7 +377,7 @@ end
# LAPACK
# For each listed element type, add LinearAlgebra.LAPACK.sytrs! and
# LinearAlgebra.LAPACK.trtri! methods that accept strided CuArray arguments and
# forward the call, with unchanged arguments, to the CUSOLVER function.
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
@eval begin
LinearAlgebra.LAPACK.sytrs!(uplo::Char, A::StridedCuMatrix{$elty}, p::CuVector{Int64}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.Xsytrs!(uplo, A, p, B)
LinearAlgebra.LAPACK.trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.Xtrtri!(uplo, diag, A)
LinearAlgebra.LAPACK.sytrs!(uplo::Char, A::StridedCuMatrix{$elty}, p::CuVector{Int64}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.sytrs!(uplo, A, p, B)
LinearAlgebra.LAPACK.trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.trtri!(uplo, diag, A)
end
end
2 changes: 1 addition & 1 deletion test/libraries/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ k = 1
@test_throws LinearAlgebra.PosDefException cholesky(d_A)
end

CUDA.CUSOLVER.version() >= v"10.1" && @testset "Cholesky inverse (potri)" begin
@testset "Cholesky inverse (potri)" begin
# test lower
A = rand(elty,n,n)
A = A*A'+I #posdef
Expand Down
37 changes: 37 additions & 0 deletions test/libraries/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using CUDA.CUSOLVER
using LinearAlgebra

n = 10
p = 5

# Exercise the generic CUSOLVER wrappers `sytrs!` and `trtri!` for every
# element type in the list, using the sizes `n` and `p` defined above.
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
    @testset "sytrs!" begin
        for uplo in ('L', 'U')
            A = rand(elty, n, n)
            B = rand(elty, n, p)
            # sytrf!/sytrs! work on a *symmetric* matrix. `A'` is the adjoint and
            # would produce a Hermitian matrix for complex element types, so the
            # symmetrization has to use `transpose`.
            A = A + transpose(A)
            d_A = CuMatrix(A)
            d_B = CuMatrix(B)
            d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A)
            # CUSOLVER.sytrs! takes its pivot vector as a CuVector{Int64}.
            d_ipiv = CuVector{Int64}(d_ipiv)
            # The CPU reference routines live in LinearAlgebra.LAPACK, not in
            # LinearAlgebra itself.
            A, ipiv, _ = LinearAlgebra.LAPACK.sytrf!(uplo, A)
            CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B)
            LinearAlgebra.LAPACK.sytrs!(uplo, A, ipiv, B)
            # Both solves overwrite their right-hand side with the solution.
            @test B ≈ collect(d_B)
        end
    end

    @testset "trtri!" begin
        for uplo in ('L', 'U')
            for diag in ('N', 'U')
                A = rand(elty, n, n)
                A = uplo == 'L' ? tril(A) : triu(A)
                # For diag == 'U' replace the diagonal by ones (unit triangular).
                A = diag == 'N' ? A : A - Diagonal(A) + I
                d_A = CuMatrix(A)
                d_B = copy(d_A)
                # trtri! inverts in place, so d_B becomes the inverse of d_A.
                CUSOLVER.trtri!(uplo, diag, d_B)
                @test collect(d_A * d_B) ≈ I
            end
        end
    end
end

0 comments on commit 9e5c369

Please sign in to comment.