Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 14, 2023
1 parent 8e75786 commit f6107e8
Showing 1 changed file with 139 additions and 113 deletions.
252 changes: 139 additions & 113 deletions lib/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Xpotrf
function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
LinearAlgebra.BLAS.chkuplo(uplo)
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
Expand All @@ -26,13 +26,12 @@ function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
A, flag
end

Xpotrf(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} = Xpotrf!(uplo, copy(A))

# Xpotrs
function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
LinearAlgebra.BLAS.chkuplo(uplo)
function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) where {T <: BlasFloat}
chkuplo(uplo)
n = checksquare(A)
nrhs = size(B, 2)
p, nrhs = size(B)
(p n) && throw(DimensionMismatch("first dimension of B, $p, must match second dimension of A, $n"))
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
info = CuArray{Cint}(undef, 1)
Expand All @@ -44,13 +43,10 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where
B
end

Xpotrs(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where {T <: BlasFloat} = Xpotrs!(uplo, A, copy(B))

# Xgetrf
function Xgetrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{T}) where {T <: BlasFloat}
m, n = size(A)
lda = max(1, stride(A, 2))
p = CuVector{Int64}(undef, min(m,n))
info = CuArray{Cint}(undef, 1)

function bufferSize()
Expand All @@ -61,41 +57,43 @@ function Xgetrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
out_gpu[], out_cpu[]
end
with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
cusolverDnXgetrf(dense_handle(), params, m, n, T, A, lda, p,
cusolverDnXgetrf(dense_handle(), params, m, n, T, A, lda, ipiv,
T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
sizeof(buffer_cpu), info)
end

flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
A, p, flag
A, ipiv, flag
end

Xgetrf(A::StridedCuMatrix{T}) where {T <: BlasFloat} = Xgetrf!(copy(A))
"""
    Xgetrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}

Convenience method: allocate the pivot vector and forward to `Xgetrf!(A, ipiv)`.

The pivot vector is an uninitialized `CuArray{Int64}` of length `min(m, n)`, where
`(m, n) = size(A)`. Returns whatever `Xgetrf!(A, ipiv)` returns.
"""
function Xgetrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
    # `$elty` is only meaningful inside `@eval`/quoted code; this method is defined
    # directly, so the element type is bound with a regular type parameter instead.
    m, n = size(A)
    ipiv = CuArray{Int64}(undef, min(m, n))
    # NOTE(review): this call only dispatches if the two-argument method accepts an
    # `Int64` pivot vector — confirm its `ipiv` annotation.
    return Xgetrf!(A, ipiv)
end

# Xgetrs
function Xgetrs!(transa::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::StridedCuVecOrMat{T}) where {T <: BlasFloat}
chktrans(trans)
n = checksquare(A)
nrhs = size(B, 2)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
info = CuArray{Cint}(undef, 1)
cusolverDnXgetrs(dense_handle(), params, transa, n, nrhs, T, A, lda, p, T, B, ldb, info)
cusolverDnXgetrs(dense_handle(), params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info)

flag = @allowscalar info[1]
unsafe_free!(info)
chkargsok(BlasInt(flag))
B
end

Xgetrs(transa::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat} = Xgetrs!(transa, A, p, copy(B))

# Xgeqrf
function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
m, n = size(A)
lda = max(1, stride(A, 2))
tau = CuVector{T}(undef, min(m,n))
info = CuArray{Cint}(undef, 1)

function bufferSize()
Expand All @@ -117,10 +115,15 @@ function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
A, tau
end

Xgeqrf(A::StridedCuMatrix{T}) where {T <: BlasFloat} = Xgeqrf!(copy(A))
"""
    Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}

Convenience method: allocate `tau` and forward to `Xgeqrf!(A, tau)`.

`tau` is an uninitialized `CuArray{T}` whose length is the smaller of the two
dimensions of `A`. Returns whatever `Xgeqrf!(A, tau)` returns.
"""
function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
    # A matrix has exactly two dimensions, so splatting `size(A)` yields `min(m, n)`.
    tau = CuArray{T}(undef, min(size(A)...))
    return Xgeqrf!(A, tau)
end

# Xsytrs
function Xsytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
Expand All @@ -145,10 +148,10 @@ function Xsytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Strid
B
end

Xsytrs(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat} = Xsytrs!(uplo, A, p, copy(B))

# Xtrtri
function Xtrtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chktrans(trans)
chkdiag(diag)
n = checksquare(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
Expand All @@ -169,33 +172,33 @@ function Xtrtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: Blas
A
end

Xtrtri(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} = Xtrtri!(uplo, diag, copy(A))

# Xgesvd
function Xgesvd!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
m, n = size(A)
R = real(T)
(m < n) && throw(ArgumentError("The number of rows of A ($m) must be greater or equal to the number of columns of A ($n)"))
if jobu == 'A'
U = CuMatrix{T}(undef, m, m)
elseif jobu == 'S' || jubu == 'O'
U = CuMatrix{T}(undef, m, min(m,n))
elseif jobu == 'S' || jobu == 'O'
U = CuMatrix{T}(undef, m, min(m, n))
elseif jobu == 'N'
U = CUNULL
U = CU_NULL
else
throw(ArgumentError("jobu is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
end
Σ = CuVector{T}(undef, min(m,n))
Σ = CuVector{R}(undef, min(m, n))
if jobvt == 'A'
U = CuMatrix{T}(undef, n, n)
elseif jobvt == 'S' || jubu == 'O'
U = CuMatrix{T}(undef, min(m,n), n)
elseif jobvt == 'S' || jobvt == 'O'
U = CuMatrix{T}(undef, min(m, n), n)
elseif jobvt == 'N'
U = CUNULL
U = CU_NULL
else
throw(ArgumentError("jobvt is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
end
lda = max(1, stride(A, 2))
ldu = U == CUNULL ? 0 : max(1, stride(U, 2))
ldvt = Vt == CUNULL ? 0 : max(1, stride(Vt, 2))
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
info = CuArray{Cint}(undef, 1)

function bufferSize()
Expand All @@ -219,67 +222,70 @@ function Xgesvd!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
end

# Xgesvdp
function Xgesvdp!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
m, n = size(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
h_err_sigma = Ref{Cdouble}(0)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
cusolverDnXgesvdp_bufferSize(dense_handle(), params, jobz, econ, m,
n, T, A, lda, T, S, T, U, ldu, T, V,
ldv, T, out_gpu, out_cpu)

out_gpu[], out_cpu[]
end
with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
cusolverDnXgesvdp(dense_handle(), params, jobz, econ, m, n, T, A, lda, T,
S, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof(buffer_gpu),
buffer_cpu, sizeof(buffer_cpu), info, h_err_sigma)
end

flag = @allowscalar info[1]
unsafe_free!(info)
chklapackerror(BlasInt(flag))
...
end
# function Xgesvdp!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
# m, n = size(A)
# lda = max(1, stride(A, 2))
# info = CuArray{Cint}(undef, 1)
# h_err_sigma = Ref{Cdouble}(0)
#
# function bufferSize()
# out_cpu = Ref{Csize_t}(0)
# out_gpu = Ref{Csize_t}(0)
# cusolverDnXgesvdp_bufferSize(dense_handle(), params, jobz, econ, m,
# n, T, A, lda, T, S, T, U, ldu, T, V,
# ldv, T, out_gpu, out_cpu)
#
# out_gpu[], out_cpu[]
# end
# with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
# cusolverDnXgesvdp(dense_handle(), params, jobz, econ, m, n, T, A, lda, T,
# S, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof(buffer_gpu),
# buffer_cpu, sizeof(buffer_cpu), info, h_err_sigma)
# end
#
# flag = @allowscalar info[1]
# unsafe_free!(info)
# chklapackerror(BlasInt(flag))
# ...
# end

# Xgesvdr
function Xgesvdr!(A::StridedCuMatrix{T}, k::Integer) where {T <: BlasFloat}
m, n = size(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
cusolverDnXgesvdr_bufferSize(dense_handle(), params, jobu, jobv, m, n, k, p,
niters, T, A, lda, T,
Srand, T, Urand, ldUrand,
T, Vrand, ldVrand, T, out_gpu, out_cpu)
out_gpu[], out_cpu[]
end
with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
cusolverDnXgesvdr(dense_handle(), params, jobu, jobv, m, n, k, p, niters,
T, A, lda, T, Srand, T,
Urand, ldUrand, T, Vrand, ldVrand,
T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
sizeof(buffer_cpu), info)
end

flag = @allowscalar info[1]
unsafe_free!(info)
chklapackerror(BlasInt(flag))
...
end
# function Xgesvdr!(A::StridedCuMatrix{T}, k::Integer) where {T <: BlasFloat}
# m, n = size(A)
# lda = max(1, stride(A, 2))
# info = CuArray{Cint}(undef, 1)
#
# function bufferSize()
# out_cpu = Ref{Csize_t}(0)
# out_gpu = Ref{Csize_t}(0)
# cusolverDnXgesvdr_bufferSize(dense_handle(), params, jobu, jobv, m, n, k, p,
# niters, T, A, lda, T,
# Srand, T, Urand, ldUrand,
# T, Vrand, ldVrand, T, out_gpu, out_cpu)
# out_gpu[], out_cpu[]
# end
# with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
# cusolverDnXgesvdr(dense_handle(), params, jobu, jobv, m, n, k, p, niters,
# T, A, lda, T, Srand, T,
# Urand, ldUrand, T, Vrand, ldVrand,
# T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
# sizeof(buffer_cpu), info)
# end
#
# flag = @allowscalar info[1]
# unsafe_free!(info)
# chklapackerror(BlasInt(flag))
# ...
# end

# Xsyevd
function Xsyevd!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
m, n = size(A)
function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
n = checksquare(A)
R = real(T)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)
W = CuArray{R}(undef, n)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
Expand All @@ -296,32 +302,52 @@ function Xsyevd!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}

flag = @allowscalar info[1]
unsafe_free!(info)
chklapackerror(BlasInt(flag))
...
chkargsok(BlasInt(flag))

if jobz == 'N'
return W
elseif jobz == 'V'
return W, A
end
end

# Xsyevdx
function Xsyevdx!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
m, n = size(A)
lda = max(1, stride(A, 2))
info = CuArray{Cint}(undef, 1)

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
cusolverDnXsyevdx_bufferSize(dense_handle(), params, jobz, range, uplo, n,
T, A, lda, vl, vu, il, iu, h_meig,
T, W, T, out_gpu, out_cpu)
out_gpu[], out_cpu[]
end
with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
cusolverDnXsyevdx(dense_handle(), params, jobz, range, uplo, n, T, A,
lda, vl, vu, il, iu, meig64, T, W, T, buffer_gpu,
sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info)
# function Xsyevdx!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
# chkuplo(uplo)
# n = checksquare(A)
# lda = max(1, stride(A, 2))
# info = CuArray{Cint}(undef, 1)
# ...
#
# function bufferSize()
# out_cpu = Ref{Csize_t}(0)
# out_gpu = Ref{Csize_t}(0)
# cusolverDnXsyevdx_bufferSize(dense_handle(), params, jobz, range, uplo, n,
# T, A, lda, vl, vu, il, iu, h_meig,
# T, W, T, out_gpu, out_cpu)
# out_gpu[], out_cpu[]
# end
# with_workspaces(bufferSize) do buffer_gpu, buffer_cpu
# cusolverDnXsyevdx(dense_handle(), params, jobz, range, uplo, n, T, A,
# lda, vl, vu, il, iu, meig64, T, W, T, buffer_gpu,
# sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info)
# end
#
# flag = @allowscalar info[1]
# unsafe_free!(info)
# chkargsok(BlasInt(flag))
#
# if jobz == 'N'
# return W
# elseif jobz == 'V'
# return W, A
# end
# end

# LAPACK
# Define, for each supported element type, `LinearAlgebra.LAPACK` methods on CUDA
# arrays that forward to the corresponding `CUSOLVER.X*!` wrappers.
#
# The loop body contains only the `@eval` block: the statements that previously
# followed it (`flag = @allowscalar info[1]`, `unsafe_free!(info)`,
# `chklapackerror(...)` and a bare `...`) referenced an `info` that does not exist
# in this scope and did not parse, so they are removed.
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
    @eval begin
        # NOTE(review): `B` is accepted as a vector or a matrix here — confirm that
        # `CUSOLVER.Xsytrs!` has a method for a vector right-hand side.
        function LinearAlgebra.LAPACK.sytrs!(uplo::Char, A::StridedCuMatrix{$elty},
                                             p::CuVector{Int64},
                                             B::StridedCuVecOrMat{$elty})
            return CUSOLVER.Xsytrs!(uplo, A, p, B)
        end

        function LinearAlgebra.LAPACK.trtri!(uplo::Char, diag::Char,
                                             A::StridedCuMatrix{$elty})
            return CUSOLVER.Xtrtri!(uplo, diag, A)
        end
    end
end

0 comments on commit f6107e8

Please sign in to comment.