# Patch overview (text ahead of the first `diff --git` header is ignored by patch tools).
#
# lib/cusolver/dense.jl
#   * ormqr!: the A and C arguments are widened from CuMatrix / CuVecOrMat to
#     StridedCuMatrix / StridedCuVecOrMat.
#
# lib/cusolver/dense_generic.jl
#   * Drops the "TO DO: cusolverDnSetAdvOptions" header comment.
#   * Xpotrf!, Xpotrs!, Xgetrf!, Xgetrs!, Xgeqrf!, Xgesvd! and Xsyevd! (plus three
#     commented-out function bodies) now allocate a `Ref{cusolverDnParams_t}`, fill it
#     with cusolverDnCreateParams before the solver call, and release it with
#     cusolverDnDestroyParams(params[]) before `info` is read.
#   * Xsytrs! / Xtrtri! are renamed to sytrs! / trtri!, and the two
#     LinearAlgebra.LAPACK forwarding methods are updated to the new names.
#   * trtri!: `chktrans(trans)` becomes `chkuplo(uplo)`; the signature has no `trans`.
#   * Xgesvd!: U and Vt are now bound from `if` expressions; the removed jobvt branch
#     assigned its result to U while Vt is what the following lines read.
#
# test/libraries/cusolver/dense.jl
#   * The `CUDA.CUSOLVER.version() >= v"10.1"` guard on the potri testset is removed.
#
# test/libraries/cusolver/dense_generic.jl (new file)
#   * Adds sytrs! and trtri! testsets over Float32, Float64, ComplexF32, ComplexF64.
#
# NOTE(review): the Xpotrs!/Xgetrs! hunks pass `params` (the Ref) to the solver call but
#   `params[]` to cusolverDnDestroyParams -- confirm the wrapper converts the Ref as intended.
# NOTE(review): no try/finally is visible around the solver calls, so a throwing call
#   would presumably skip cusolverDnDestroyParams -- verify against the full file.
# NOTE(review): sytrs! takes `B::StridedCuMatrix{T}` while the LAPACK.sytrs! forwarder
#   accepts `StridedCuVecOrMat`; a vector B would need another method defined outside
#   this diff -- confirm.
# NOTE(review): the new test calls `LinearAlgebra.sytrf!` / `LinearAlgebra.sytrs!` without
#   the `LAPACK` qualifier -- confirm these names resolve.
diff --git a/lib/cusolver/dense.jl b/lib/cusolver/dense.jl index 4a77e34dcf..5ccf3c928e 100644 --- a/lib/cusolver/dense.jl +++ b/lib/cusolver/dense.jl @@ -258,9 +258,9 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, : @eval begin function ormqr!(side::Char, trans::Char, - A::CuMatrix{$elty}, + A::StridedCuMatrix{$elty}, tau::CuVector{$elty}, - C::CuVecOrMat{$elty}) + C::StridedCuVecOrMat{$elty}) # Support transa = 'C' for real matrices trans = $elty <: Real && trans == 'C' ? 'T' : trans diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl index f761d1171b..c7e19bccbc 100644 --- a/lib/cusolver/dense_generic.jl +++ b/lib/cusolver/dense_generic.jl @@ -1,11 +1,11 @@ -# TO DO: cusolverDnSetAdvOptions - # Xpotrf function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) info = CuArray{Cint}(undef, 1) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) function bufferSize() out_cpu = Ref{Csize_t}(0) @@ -20,6 +20,7 @@ function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} sizeof(buffer_cpu), info) end + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chkargsok(BlasInt(flag)) @@ -35,8 +36,12 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) info = CuArray{Cint}(undef, 1) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) + cusolverDnXpotrs(dense_handle(), params, uplo, n, nrhs, T, A, lda, T, B, ldb, info) + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chkargsok(BlasInt(flag)) @@ -48,6 +53,8 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{T}) where {T <: BlasFloat m, n = size(A) lda = max(1, stride(A, 2)) info = CuArray{Cint}(undef, 1) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) 
function bufferSize() out_cpu = Ref{Csize_t}(0) @@ -62,6 +69,7 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{T}) where {T <: BlasFloat sizeof(buffer_cpu), info) end + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chkargsok(BlasInt(flag)) @@ -82,8 +90,12 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) info = CuArray{Cint}(undef, 1) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) + cusolverDnXgetrs(dense_handle(), params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info) + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chkargsok(BlasInt(flag)) @@ -95,6 +107,8 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat} m, n = size(A) lda = max(1, stride(A, 2)) info = CuArray{Cint}(undef, 1) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) function bufferSize() out_cpu = Ref{Csize_t}(0) @@ -109,6 +123,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat} buffer_cpu, sizeof(buffer_cpu), info) end + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chkargsok(BlasInt(flag)) @@ -122,7 +137,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat} end # Xsytrs -function Xsytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat} +function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat} chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) @@ -149,8 +164,8 @@ function Xsytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Strid end # Xtrtri -function Xtrtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} - chktrans(trans) +function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} + 
chkuplo(uplo) chkdiag(diag) n = checksquare(A) lda = max(1, stride(A, 2)) @@ -177,22 +192,22 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla m, n = size(A) R = real(T) (m < n) && throw(ArgumentError("The number of rows of A ($m) must be greater or equal to the number of columns of A ($n)")) - if jobu == 'A' - U = CuMatrix{T}(undef, m, m) + U = if jobu == 'A' + CuMatrix{T}(undef, m, m) elseif jobu == 'S' || jobu == 'O' - U = CuMatrix{T}(undef, m, min(m, n)) + CuMatrix{T}(undef, m, min(m, n)) elseif jobu == 'N' - U = CU_NULL + CU_NULL else throw(ArgumentError("jobu is incorrect. The values accepted are 'A', 'S', 'O' and 'N'.")) end Σ = CuVector{R}(undef, min(m, n)) - if jobvt == 'A' - U = CuMatrix{T}(undef, n, n) + Vt = if jobvt == 'A' + CuMatrix{T}(undef, n, n) elseif jobvt == 'S' || jobvt == 'O' - U = CuMatrix{T}(undef, min(m, n), n) + CuMatrix{T}(undef, min(m, n), n) elseif jobvt == 'N' - U = CU_NULL + CU_NULL else throw(ArgumentError("jobvt is incorrect. The values accepted are 'A', 'S', 'O' and 'N'.")) end @@ -200,6 +215,8 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla ldu = U == CU_NULL ? 1 : max(1, stride(U, 2)) ldvt = Vt == CU_NULL ? 
1 : max(1, stride(Vt, 2)) info = CuArray{Cint}(undef, 1) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) function bufferSize() out_cpu = Ref{Csize_t}(0) @@ -215,6 +232,7 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info) end + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chklapackerror(BlasInt(flag)) @@ -227,6 +245,8 @@ end # lda = max(1, stride(A, 2)) # info = CuArray{Cint}(undef, 1) # h_err_sigma = Ref{Cdouble}(0) +# params = Ref{cusolverDnParams_t}(C_NULL) +# cusolverDnCreateParams(params) # # function bufferSize() # out_cpu = Ref{Csize_t}(0) @@ -243,6 +263,7 @@ end # buffer_cpu, sizeof(buffer_cpu), info, h_err_sigma) # end # +# cusolverDnDestroyParams(params[]) # flag = @allowscalar info[1] # unsafe_free!(info) # chklapackerror(BlasInt(flag)) @@ -254,6 +275,8 @@ end # m, n = size(A) # lda = max(1, stride(A, 2)) # info = CuArray{Cint}(undef, 1) +# params = Ref{cusolverDnParams_t}(C_NULL) +# cusolverDnCreateParams(params) # # function bufferSize() # out_cpu = Ref{Csize_t}(0) @@ -272,6 +295,7 @@ end # sizeof(buffer_cpu), info) # end # +# cusolverDnDestroyParams(params[]) # flag = @allowscalar info[1] # unsafe_free!(info) # chklapackerror(BlasInt(flag)) @@ -286,6 +310,8 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas lda = max(1, stride(A, 2)) info = CuArray{Cint}(undef, 1) W = CuArray{R}(undef, n) + params = Ref{cusolverDnParams_t}(C_NULL) + cusolverDnCreateParams(params) function bufferSize() out_cpu = Ref{Csize_t}(0) @@ -300,6 +326,7 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas buffer_cpu, sizeof(buffer_cpu), info) end + cusolverDnDestroyParams(params[]) flag = @allowscalar info[1] unsafe_free!(info) chkargsok(BlasInt(flag)) @@ -318,6 +345,8 @@ end # lda = max(1, stride(A, 2)) # info = CuArray{Cint}(undef, 1) # ... 
+# params = Ref{cusolverDnParams_t}(C_NULL) +# cusolverDnCreateParams(params) # # function bufferSize() # out_cpu = Ref{Csize_t}(0) @@ -333,6 +362,7 @@ end # sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info) # end # +# cusolverDnDestroyParams(params[]) # flag = @allowscalar info[1] # unsafe_free!(info) # chkargsok(BlasInt(flag)) @@ -347,7 +377,7 @@ end # LAPACK for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) @eval begin - LinearAlgebra.LAPACK.sytrs!(uplo::Char, A::StridedCuMatrix{$elty}, p::CuVector{Int64}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.Xsytrs!(uplo, A, p, B) - LinearAlgebra.LAPACK.trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.Xtrtri!(uplo, diag, A) + LinearAlgebra.LAPACK.sytrs!(uplo::Char, A::StridedCuMatrix{$elty}, p::CuVector{Int64}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.sytrs!(uplo, A, p, B) + LinearAlgebra.LAPACK.trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.trtri!(uplo, diag, A) end end diff --git a/test/libraries/cusolver/dense.jl b/test/libraries/cusolver/dense.jl index 82de8cf2ff..3326825ccb 100644 --- a/test/libraries/cusolver/dense.jl +++ b/test/libraries/cusolver/dense.jl @@ -47,7 +47,7 @@ k = 1 @test_throws LinearAlgebra.PosDefException cholesky(d_A) end - CUDA.CUSOLVER.version() >= v"10.1" && @testset "Cholesky inverse (potri)" begin + @testset "Cholesky inverse (potri)" begin # test lower A = rand(elty,n,n) A = A*A'+I #posdef diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl new file mode 100644 index 0000000000..a978ede04f --- /dev/null +++ b/test/libraries/cusolver/dense_generic.jl @@ -0,0 +1,37 @@ +using CUDA.CUSOLVER +using LinearAlgebra + +n = 10 +p = 5 + +@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64] + @testset "sytrs!" 
begin + for uplo in ('L', 'U') + A = rand(elty,n,n) + B = rand(elty,n,p) + A = A + A' + d_A = CuMatrix(A) + d_B = CuMatrix(B) + d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A) + d_ipiv = CuVector{Int64}(d_ipiv) + A, ipiv, _ = LinearAlgebra.sytrf!(uplo, A) + CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B) + LinearAlgebra.sytrs!(uplo, A, ipiv, B) + @test B ≈ collect(d_B) + end + end + + @testset "trtri!" begin + for uplo in ('L', 'U') + for diag in ('N', 'U') + A = rand(elty,n,n) + A = uplo == 'L' ? tril(A) : triu(A) + A = diag == 'N' ? A : A - Diagonal(A) + I + d_A = CuMatrix(A) + d_B = copy(d_A) + CUSOLVER.trtri!(uplo, diag, d_B) + @test collect(d_A * d_B) ≈ I + end + end + end +end