From 76e050424d251b8347af6a088f37fa2894125a13 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Sat, 28 Oct 2023 18:25:18 -0500 Subject: [PATCH] [CUSOLVER] Interface sparse Cholesky and QR factorizations --- lib/cusolver/CUSOLVER.jl | 1 + lib/cusolver/libcusolver.jl | 166 +++++----- lib/cusolver/sparse_factorizations.jl | 292 ++++++++++++++++++ res/wrap/cusolver.toml | 47 +++ .../cusolver/sparse_factorizations.jl | 58 ++++ 5 files changed, 481 insertions(+), 83 deletions(-) create mode 100644 lib/cusolver/sparse_factorizations.jl create mode 100644 test/libraries/cusolver/sparse_factorizations.jl diff --git a/lib/cusolver/CUSOLVER.jl b/lib/cusolver/CUSOLVER.jl index 5619fd377c..1b9d5d5f9b 100644 --- a/lib/cusolver/CUSOLVER.jl +++ b/lib/cusolver/CUSOLVER.jl @@ -36,6 +36,7 @@ include("libcusolverRF.jl") include("error.jl") include("base.jl") include("sparse.jl") +include("sparse_factorizations.jl") include("dense.jl") include("dense_generic.jl") include("multigpu.jl") diff --git a/lib/cusolver/libcusolver.jl b/lib/cusolver/libcusolver.jl index b18aad6e22..0b5a84f6d7 100644 --- a/lib/cusolver/libcusolver.jl +++ b/lib/cusolver/libcusolver.jl @@ -5624,8 +5624,8 @@ end @ccall libcusolver.cusolverSpXcsrqrAnalysis(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrqrInfo_t)::cusolverStatus_t end @@ -5636,9 +5636,9 @@ end @ccall libcusolver.cusolverSpScsrqrBufferInfo(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cfloat}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrqrInfo_t, + csrValA::CuPtr{Cfloat}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrqrInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t end @@ -5650,9 +5650,9 @@ end @ccall libcusolver.cusolverSpDcsrqrBufferInfo(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cdouble}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrqrInfo_t, + csrValA::CuPtr{Cdouble}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrqrInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t end @@ -5664,9 +5664,9 @@ end @ccall libcusolver.cusolverSpCcsrqrBufferInfo(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuComplex}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrqrInfo_t, + csrValA::CuPtr{cuComplex}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrqrInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t end @@ -5678,9 +5678,9 @@ end @ccall libcusolver.cusolverSpZcsrqrBufferInfo(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuDoubleComplex}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrqrInfo_t, + csrValA::CuPtr{cuDoubleComplex}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrqrInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t end @@ -5690,8 +5690,8 @@ end initialize_context() @ccall libcusolver.cusolverSpScsrqrSetup(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cfloat}, csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, mu::Cfloat, + csrValA::CuPtr{Cfloat}, csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, mu::Cfloat, info::csrqrInfo_t)::cusolverStatus_t end @@ -5700,8 +5700,8 @@ end initialize_context() @ccall libcusolver.cusolverSpDcsrqrSetup(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cdouble}, csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, mu::Cdouble, + csrValA::CuPtr{Cdouble}, csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, mu::Cdouble, info::csrqrInfo_t)::cusolverStatus_t end @@ -5710,8 +5710,8 @@ end initialize_context() @ccall libcusolver.cusolverSpCcsrqrSetup(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuComplex}, csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, mu::cuComplex, + csrValA::CuPtr{cuComplex}, csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, mu::cuComplex, info::csrqrInfo_t)::cusolverStatus_t end @@ -5720,8 +5720,8 @@ end initialize_context() @ccall libcusolver.cusolverSpZcsrqrSetup(handle::cusolverSpHandle_t, m::Cint, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuDoubleComplex}, - csrRowPtrA::Ptr{Cint}, csrColIndA::Ptr{Cint}, + csrValA::CuPtr{cuDoubleComplex}, + csrRowPtrA::CuPtr{Cint}, csrColIndA::CuPtr{Cint}, mu::cuDoubleComplex, info::csrqrInfo_t)::cusolverStatus_t end @@ -5729,33 +5729,33 @@ end @checked function cusolverSpScsrqrFactor(handle, m, n, nnzA, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpScsrqrFactor(handle::cusolverSpHandle_t, m::Cint, n::Cint, - nnzA::Cint, b::Ptr{Cfloat}, x::Ptr{Cfloat}, + nnzA::Cint, b::CuPtr{Cfloat}, x::CuPtr{Cfloat}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpDcsrqrFactor(handle, m, n, nnzA, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpDcsrqrFactor(handle::cusolverSpHandle_t, m::Cint, n::Cint, - nnzA::Cint, b::Ptr{Cdouble}, x::Ptr{Cdouble}, + nnzA::Cint, b::CuPtr{Cdouble}, x::CuPtr{Cdouble}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpCcsrqrFactor(handle, m, n, nnzA, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpCcsrqrFactor(handle::cusolverSpHandle_t, m::Cint, n::Cint, - nnzA::Cint, b::Ptr{cuComplex}, - x::Ptr{cuComplex}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + nnzA::Cint, b::CuPtr{cuComplex}, + x::CuPtr{cuComplex}, info::csrqrInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpZcsrqrFactor(handle, m, n, nnzA, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpZcsrqrFactor(handle::cusolverSpHandle_t, m::Cint, n::Cint, - nnzA::Cint, b::Ptr{cuDoubleComplex}, - x::Ptr{cuDoubleComplex}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + nnzA::Cint, b::CuPtr{cuDoubleComplex}, + x::CuPtr{cuDoubleComplex}, info::csrqrInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpScsrqrZeroPivot(handle, info, tol, position) @@ -5789,33 +5789,33 @@ end @checked function cusolverSpScsrqrSolve(handle, m, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpScsrqrSolve(handle::cusolverSpHandle_t, m::Cint, n::Cint, - b::Ptr{Cfloat}, x::Ptr{Cfloat}, + b::CuPtr{Cfloat}, x::CuPtr{Cfloat}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpDcsrqrSolve(handle, m, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpDcsrqrSolve(handle::cusolverSpHandle_t, m::Cint, n::Cint, - b::Ptr{Cdouble}, x::Ptr{Cdouble}, + b::CuPtr{Cdouble}, x::CuPtr{Cdouble}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpCcsrqrSolve(handle, m, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpCcsrqrSolve(handle::cusolverSpHandle_t, m::Cint, n::Cint, - b::Ptr{cuComplex}, x::Ptr{cuComplex}, + b::CuPtr{cuComplex}, x::CuPtr{cuComplex}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpZcsrqrSolve(handle, m, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpZcsrqrSolve(handle::cusolverSpHandle_t, m::Cint, n::Cint, - b::Ptr{cuDoubleComplex}, - x::Ptr{cuDoubleComplex}, info::csrqrInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + b::CuPtr{cuDoubleComplex}, + x::CuPtr{cuDoubleComplex}, info::csrqrInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpCreateCsrcholInfoHost(info) @@ -6025,8 +6025,8 @@ end initialize_context() @ccall libcusolver.cusolverSpXcsrcholAnalysis(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t)::cusolverStatus_t end @@ -6036,9 +6036,9 @@ end initialize_context() @ccall libcusolver.cusolverSpScsrcholBufferInfo(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cfloat}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, + csrValA::CuPtr{Cfloat}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t @@ -6050,9 +6050,9 @@ end initialize_context() @ccall libcusolver.cusolverSpDcsrcholBufferInfo(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cdouble}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, + csrValA::CuPtr{Cdouble}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t @@ -6064,9 +6064,9 @@ end initialize_context() @ccall libcusolver.cusolverSpCcsrcholBufferInfo(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuComplex}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, + csrValA::CuPtr{cuComplex}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t @@ -6078,9 +6078,9 @@ end initialize_context() @ccall libcusolver.cusolverSpZcsrcholBufferInfo(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuDoubleComplex}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, + csrValA::CuPtr{cuDoubleComplex}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, internalDataInBytes::Ptr{Csize_t}, workspaceInBytes::Ptr{Csize_t})::cusolverStatus_t @@ -6091,9 +6091,9 @@ end initialize_context() @ccall libcusolver.cusolverSpScsrcholFactor(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cfloat}, csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + csrValA::CuPtr{Cfloat}, csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpDcsrcholFactor(handle, n, nnzA, descrA, csrValA, csrRowPtrA, @@ -6101,10 +6101,10 @@ end initialize_context() @ccall libcusolver.cusolverSpDcsrcholFactor(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{Cdouble}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + csrValA::CuPtr{Cdouble}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpCcsrcholFactor(handle, n, nnzA, descrA, csrValA, csrRowPtrA, @@ -6112,10 +6112,10 @@ end initialize_context() @ccall libcusolver.cusolverSpCcsrcholFactor(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuComplex}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + csrValA::CuPtr{cuComplex}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpZcsrcholFactor(handle, n, nnzA, descrA, csrValA, csrRowPtrA, @@ -6123,10 +6123,10 @@ end initialize_context() @ccall libcusolver.cusolverSpZcsrcholFactor(handle::cusolverSpHandle_t, n::Cint, nnzA::Cint, descrA::cusparseMatDescr_t, - csrValA::Ptr{cuDoubleComplex}, - csrRowPtrA::Ptr{Cint}, - csrColIndA::Ptr{Cint}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + csrValA::CuPtr{cuDoubleComplex}, + csrRowPtrA::CuPtr{Cint}, + csrColIndA::CuPtr{Cint}, info::csrcholInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpScsrcholZeroPivot(handle, info, tol, position) @@ -6160,59 +6160,59 @@ end @checked function cusolverSpScsrcholSolve(handle, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpScsrcholSolve(handle::cusolverSpHandle_t, n::Cint, - b::Ptr{Cfloat}, x::Ptr{Cfloat}, + b::CuPtr{Cfloat}, x::CuPtr{Cfloat}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpDcsrcholSolve(handle, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpDcsrcholSolve(handle::cusolverSpHandle_t, n::Cint, - b::Ptr{Cdouble}, x::Ptr{Cdouble}, + b::CuPtr{Cdouble}, x::CuPtr{Cdouble}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpCcsrcholSolve(handle, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpCcsrcholSolve(handle::cusolverSpHandle_t, n::Cint, - b::Ptr{cuComplex}, x::Ptr{cuComplex}, + b::CuPtr{cuComplex}, x::CuPtr{cuComplex}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpZcsrcholSolve(handle, n, b, x, info, pBuffer) initialize_context() @ccall libcusolver.cusolverSpZcsrcholSolve(handle::cusolverSpHandle_t, n::Cint, - b::Ptr{cuDoubleComplex}, - x::Ptr{cuDoubleComplex}, info::csrcholInfo_t, - pBuffer::Ptr{Cvoid})::cusolverStatus_t + b::CuPtr{cuDoubleComplex}, + x::CuPtr{cuDoubleComplex}, info::csrcholInfo_t, + pBuffer::CuPtr{Cvoid})::cusolverStatus_t end @checked function cusolverSpScsrcholDiag(handle, info, diag) initialize_context() @ccall libcusolver.cusolverSpScsrcholDiag(handle::cusolverSpHandle_t, info::csrcholInfo_t, - diag::Ptr{Cfloat})::cusolverStatus_t + diag::CuPtr{Cfloat})::cusolverStatus_t end @checked function cusolverSpDcsrcholDiag(handle, info, diag) initialize_context() @ccall libcusolver.cusolverSpDcsrcholDiag(handle::cusolverSpHandle_t, info::csrcholInfo_t, - diag::Ptr{Cdouble})::cusolverStatus_t + diag::CuPtr{Cdouble})::cusolverStatus_t end @checked function cusolverSpCcsrcholDiag(handle, info, diag) initialize_context() @ccall libcusolver.cusolverSpCcsrcholDiag(handle::cusolverSpHandle_t, info::csrcholInfo_t, - diag::Ptr{Cfloat})::cusolverStatus_t + diag::CuPtr{Cfloat})::cusolverStatus_t end @checked function cusolverSpZcsrcholDiag(handle, info, diag) initialize_context() @ccall libcusolver.cusolverSpZcsrcholDiag(handle::cusolverSpHandle_t, info::csrcholInfo_t, - diag::Ptr{Cdouble})::cusolverStatus_t + diag::CuPtr{Cdouble})::cusolverStatus_t end diff --git a/lib/cusolver/sparse_factorizations.jl b/lib/cusolver/sparse_factorizations.jl new file mode 100644 index 0000000000..bff078be78 --- /dev/null +++ b/lib/cusolver/sparse_factorizations.jl @@ -0,0 +1,292 @@ +mutable struct SparseQRInfo + info::csrqrInfo_t + + function SparseQRInfo() + info_ref = Ref{csrqrInfo_t}() + cusolverSpCreateCsrqrInfo(info_ref) + obj = new(info_ref[]) + finalizer(cusolverSpDestroyCsrqrInfo, obj) + obj + end +end + +Base.unsafe_convert(::Type{csrqrInfo_t}, info::SparseQRInfo) = info.info + +mutable struct SparseQR{T <: BlasFloat} + n::Cint + m::Cint + nnzA::Cint + mu::T + handle::cusolverSpHandle_t + descA::CuMatrixDescriptor + info::SparseQRInfo + buffer::Union{CuPtr{Cvoid},CuVector{UInt8}} +end + +function SparseQR(A::CuSparseMatrixCSR{T,Cint}, index::Char='O') where T <: BlasFloat + m,n = size(A) + nnzA = nnz(A) + mu = zero(T) + handle = sparse_handle() + descA = CuMatrixDescriptor('G', 'L', 'N', index) + handle = sparse_handle() + info = SparseQRInfo() + buffer = CU_NULL + F = SparseQR{T}(n, m, nnzA, mu, handle, descA, info, buffer) + spqr_analyse(F, A) + return F +end + +# csrqrAnalysis +# +# cusolverStatus_t cusolverSpXcsrqrAnalysis( +# cusolverSpHandle_t handle, +# int m, +# int n, +# int nnzA, +# const cusparseMatDescr_t descrA, +# const int * csrRowPtrA, +# const int * csrColIndA, +# csrqrInfo_t info); +function spqr_analyse(F::SparseQR{T}, A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat + cusolverSpXcsrqrAnalysis(F.handle, F.m, F.n, F.nnzA, F.descA, A.rowPtr, A.colVal, F.info) + return F +end + +#csrqrSetup +for (fname, elty) in ((:cusolverSpScsrqrSetup, :Float32), + (:cusolverSpDcsrqrSetup, :Float64), + (:cusolverSpCcsrqrSetup, :ComplexF32), + (:cusolverSpZcsrqrSetup, :ComplexF64)) + @eval begin + # cusolverStatus_t cusolverSpScsrqrSetup( + # cusolverSpHandle_t handle, + # int m, + # int n, + # int nnzA, + # const cusparseMatDescr_t descrA, + # const float * csrValA, + # const int * csrRowPtrA, + # const int * csrColIndA, + # float mu, + # csrqrInfo_t info); + function spqr_setup(F::SparseQR{$elty}, A::CuSparseMatrixCSR{$elty,Cint}) + $fname(F.handle, F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.mu, F.info) + return F + end + end +end + +for (bname, fname, sname, pname, elty, relty) in + ((:cusolverSpScsrqrBufferInfo, :cusolverSpScsrqrFactor, :cusolverSpScsrqrSolve, :cusolverSpScsrqrZeroPivot, :Float32 , :Float32), + (:cusolverSpDcsrqrBufferInfo, :cusolverSpDcsrqrFactor, :cusolverSpDcsrqrSolve, :cusolverSpDcsrqrZeroPivot, :Float64 , :Float64), + (:cusolverSpCcsrqrBufferInfo, :cusolverSpCcsrqrFactor, :cusolverSpCcsrqrSolve, :cusolverSpCcsrqrZeroPivot, :ComplexF32, :Float32), + (:cusolverSpZcsrqrBufferInfo, :cusolverSpZcsrqrFactor, :cusolverSpZcsrqrSolve, :cusolverSpZcsrqrZeroPivot, :ComplexF64, :Float64)) + @eval begin + # csrqrBufferInfo + # + # cusolverStatus_t cusolverSpScsrqrBufferInfo( + # cusolverSpHandle_t handle, + # int m, + # int n, + # int nnzA, + # const cusparseMatDescr_t descrA, + # const float * csrValA, + # const int * csrRowPtrA, + # const int * csrColIndA, + # csrqrInfo_t info, + # size_t * internalDataInBytes, + # size_t * workspaceInBytes); + function spqr_buffer(F::SparseQR{$elty}, A::CuSparseMatrixCSR{$elty,Cint}) + internalDataInBytes = Ref{Csize_t}(0) + workspaceInBytes = Ref{Csize_t}(0) + $bname(F.handle, F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, internalDataInBytes, workspaceInBytes) + # TODO: allocate buffer? + F.buffer = CuVector{UInt8}(undef, workspaceInBytes[]) + return F + end + + # csrqrFactor + # + # cusolverStatus_t cusolverSpScsrqrFactor( + # cusolverSpHandle_t handle, + # int m, + # int n, + # int nnzA, + # float * b, + # float * x, + # csrqrInfo_t info, + # void * pBuffer); + # + # csrqrZeroPivot + # + # cusolverStatus_t cusolverSpDcsrqrZeroPivot( + # cusolverSpHandle_t handle, + # csrqrInfo_t info, + # double tol, + # int * position); + function spqr_factorise(F::SparseQR{$elty}, tol::$relty) + $fname(F.handle, F.m, F.n, F.nnzA, CU_NULL, CU_NULL, F.info, F.buffer) + singularity = Ref{Cint}(0) + $pname(F.handle, F.info, tol, singularity) + (singularity[] ≥ 0) && throw(SingularException(singularity[])) + return F + end + + function spqr_factorise_solve(F::SparseQR{$elty}, b::CuVecOrMat{$elty}, x::CuVecOrMat{$elty}, tol::$relty) + $fname(F.handle, F.m, F.n, F.nnzA, b, x, F.info, F.buffer) + singularity = Ref{Cint}(0) + $pname(F.handle, F.info, tol, singularity) + (singularity[] ≥ 0) && throw(SingularException(singularity[])) + return F + end + + # csrqrSolve + # + # cusolverStatus_t CUSOLVERAPI cusolverSpScsrqrSolve( + # cusolverSpHandle_t handle, + # int m, + # int n, + # float * b, + # float * x, + # csrqrInfo_t info, + # void * pBuffer); + function spqr_solve(F::SparseQR{$elty}, b::CuVecOrMat{$elty}, x::CuVecOrMat{$elty}) + $sname(F.handle, F.m, F.n, b, x, F.info, F.buffer) + return x + end + end +end + +mutable struct SparseCholeskyInfo + info::csrcholInfo_t + + function SparseCholeskyInfo() + info_ref = Ref{csrcholInfo_t}() + cusolverSpCreateCsrcholInfo(info_ref) + obj = new(info_ref[]) + finalizer(cusolverSpDestroyCsrcholInfo, obj) + obj + end +end + +Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo) = info.info + +mutable struct SparseCholesky{T <: BlasFloat} + n::Cint + nnzA::Cint + handle::cusolverSpHandle_t + descA::CuMatrixDescriptor + info::SparseCholeskyInfo + buffer::Union{CuPtr{Cvoid},CuVector{UInt8}} +end + +function SparseCholesky(A::CuSparseMatrixCSR{T,Cint}, index::Char='O') where T <: BlasFloat + n = checksquare(A) + nnzA = nnz(A) + handle = sparse_handle() + descA = CuMatrixDescriptor('G', 'L', 'N', index) + info = SparseCholeskyInfo() + buffer = CU_NULL + F = SparseCholesky{T}(n, nnzA, handle, descA, info, buffer) + spcholesky_analyse(F, A) + return F +end + +# csrcholAnalysis +# +# cusolverStatus_t cusolverSpXcsrcholAnalysis( +# cusolverSpHandle_t handle, +# int n, +# int nnzA, +# const cusparseMatDescr_t descrA, +# const int * csrRowPtrA, +# const int * csrColIndA, +# csrcholInfo_t info); +function spcholesky_analyse(F::SparseCholesky{T}, A::CuSparseMatrixCSR{T}) where T <: BlasFloat + cusolverSpXcsrcholAnalysis(F.handle, F.n, F.nnzA, F.descA, A.rowPtr, A.colVal, F.info) + return F +end + +for (bname, fname, pname, sname, dname, elty, relty) in + ((:cusolverSpScsrcholBufferInfo, :cusolverSpScsrcholFactor, :cusolverSpScsrcholZeroPivot, :cusolverSpScsrcholSolve, :cusolverSpScsrcholDiag, :Float32 , :Float32), + (:cusolverSpDcsrcholBufferInfo, :cusolverSpDcsrcholFactor, :cusolverSpDcsrcholZeroPivot, :cusolverSpDcsrcholSolve, :cusolverSpDcsrcholDiag, :Float64 , :Float64), + (:cusolverSpCcsrcholBufferInfo, :cusolverSpCcsrcholFactor, :cusolverSpCcsrcholZeroPivot, :cusolverSpCcsrcholSolve, :cusolverSpCcsrcholDiag, :ComplexF32, :Float32), + (:cusolverSpZcsrcholBufferInfo, :cusolverSpZcsrcholFactor, :cusolverSpZcsrcholZeroPivot, :cusolverSpZcsrcholSolve, :cusolverSpZcsrcholDiag, :ComplexF64, :Float64)) + @eval begin + # csrcholBufferInfo + # + # cusolverStatus_t cusolverSpScsrcholBufferInfo( + # cusolverSpHandle_t handle, + # int n, + # int nnzA, + # const cusparseMatDescr_t descrA, + # const float * csrValA, + # const int * csrRowPtrA, + # const int * csrColIndA, + # csrcholInfo_t info, + # size_t * internalDataInBytes, + # size_t * workspaceInBytes); + function spcholesky_buffer(F::SparseCholesky{$elty}, A::CuSparseMatrixCSR{$elty}) + internalDataInBytes = Ref{Csize_t}(0) + workspaceInBytes = Ref{Csize_t}(0) + $bname(F.handle, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, internalDataInBytes, workspaceInBytes) + # TODO: allocate buffer? + F.buffer = CuVector{UInt8}(undef, workspaceInBytes[]) + return F + end + + # csrcholFactor + # + # cusolverStatus_t cusolverSpScsrcholFactor( + # cusolverSpHandle_t handle, + # int n, + # int nnzA, + # const cusparseMatDescr_t descrA, + # const float * csrValA, + # const int * csrRowPtrA, + # const int * csrColIndA, + # csrcholInfo_t info, + # void * pBuffer); + # + # csrcholZeroPivot + # + # cusolverStatus_t cusolverSpScsrcholZeroPivot( + # cusolverSpHandle_t handle, + # csrcholInfo_t info, + # float tol, + # int * position); + function spcholesky_factorise(F::SparseCholesky{$elty}, A::CuSparseMatrixCSR{$elty}, tol::$relty) + $fname(F.handle, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, F.buffer) + singularity = Ref{Cint}(0) + $pname(F.handle, F.info, tol, singularity) + (singularity[] ≥ 0) && throw(SingularException(singularity[])) + return F + end + + # csrcholSolve + # + # cusolverStatus_t cusolverSpZcsrcholSolve( + # cusolverSpHandle_t handle, + # int n, + # const cuDoubleComplex *b, + # cuDoubleComplex * x, + # csrcholInfo_t info, + # void * pBuffer); + function spcholesky_solve(F::SparseCholesky{$elty}, b::CuVecOrMat{$elty}, x::CuVecOrMat{$elty}) + $sname(F.handle, F.n, b, x, F.info, F.buffer) + return x + end + + # csrcholDiag + # + # cusolverStatus_t cusolverSpCcsrcholDiag( + # cusolverSpHandle_t handle, + # csrcholInfo_t info, + # float * diag); + function spcholesky_diag(F::SparseCholesky{$elty}, diag::CuVector{$relty}) + $dname(F.handle, F.info, diag) + return diag + end + end +end diff --git a/res/wrap/cusolver.toml b/res/wrap/cusolver.toml index ce89ee99d9..f133697031 100644 --- a/res/wrap/cusolver.toml +++ b/res/wrap/cusolver.toml @@ -2121,3 +2121,50 @@ needs_context = false 14 = "CuPtr{Cvoid}" 16 = "CuPtr{Cvoid}" 19 = "CuPtr{Cvoid}" + +[api.cusolverSpXcsrcholDiag.argtypes] +3 = "CuPtr{T}" + +[api.cusolverSpXcsrcholSolve.argtypes] +3 = "CuPtr{T}" +4 = "CuPtr{T}" +6 = "CuPtr{Cvoid}" + +[api.cusolverSpXcsrcholFactor.argtypes] +5 = "CuPtr{T}" +6 = "CuPtr{Cint}" +7 = "CuPtr{Cint}" +9 = "CuPtr{Cvoid}" + +[api.cusolverSpXcsrcholBufferInfo.argtypes] +5 = "CuPtr{T}" +6 = "CuPtr{Cint}" +7 = "CuPtr{Cint}" + +[api.cusolverSpXcsrcholAnalysis.argtypes] +5 = "CuPtr{Cint}" +6 = "CuPtr{Cint}" + +[api.cusolverSpXcsrqrSolve.argtypes] +4 = "CuPtr{T}" +5 = "CuPtr{T}" +7 = "CuPtr{Cvoid}" + +[api.cusolverSpXcsrqrSetup.argtypes] +6 = "CuPtr{T}" +7 = "CuPtr{Cint}" +8 = "CuPtr{Cint}" + +[api.cusolverSpXcsrqrBufferInfo.argtypes] +6 = "CuPtr{T}" +7 = "CuPtr{Cint}" +8 = "CuPtr{Cint}" + +[api.cusolverSpXcsrqrAnalysis.argtypes] +6 = "CuPtr{Cint}" +7 = "CuPtr{Cint}" + +[api.cusolverSpXcsrqrFactor.argtypes] +5 = "CuPtr{T}" +6 = "CuPtr{T}" +8 = "CuPtr{Cvoid}" diff --git a/test/libraries/cusolver/sparse_factorizations.jl b/test/libraries/cusolver/sparse_factorizations.jl new file mode 100644 index 0000000000..870b0815bb --- /dev/null +++ b/test/libraries/cusolver/sparse_factorizations.jl @@ -0,0 +1,58 @@ +using CUDA.CUSOLVER, CUDA.CUSPARSE + +m = 200 +n = 100 +density = 0.05 + +@testset "SparseCholesky -- $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64] + R = real(elty) + A = sprand(elty, n, n, density) + A = A * A' + I + d_A = CuSparseMatrixCSR{elty}(A) + F = CUSOLVER.SparseCholesky(d_A) + CUSOLVER.spcholesky_buffer(F, d_A) + tol = R == Float32 ? R(1e-6) : R(1e-12) + CUSOLVER.spcholesky_factorise(F, d_A, tol) + b = rand(elty, n) + d_b = CuVector(b) + x = zeros(elty, n) + d_x = CuVector(x) + CUSOLVER.spcholesky_solve(F, d_b, d_x) + d_r = d_b - d_A * d_x + @test norm(d_A' * d_r) ≤ √eps(R) + diag = zeros(elty, n) + d_diag = CuVector{R}(diag) + CUSOLVER.spcholesky_diag(F, d_diag) + det_A = mapreduce(x -> x^2, *, d_diag) + @test det_A ≈ det(Matrix(A)) +end + +@testset "SparseQR -- $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64] + R = real(elty) + A = sprand(elty, m, n, density) + d_A = CuSparseMatrixCSR{elty}(A) + F = CUSOLVER.SparseQR(d_A) + CUSOLVER.spqr_setup(F, d_A) + CUSOLVER.spqr_buffer(F, d_A) + tol = R == Float32 ? R(1e-6) : R(1e-12) + CUSOLVER.spqr_factorise(F, tol) + b = rand(elty, m) + d_b = CuVector(b) + x = zeros(elty, n) + d_x = CuVector(x) + CUSOLVER.spqr_solve(F, copy(d_b), d_x) + d_r = d_b - d_A * d_x + @test norm(d_A' * d_r) ≤ √eps(R) + + d_B = copy(d_A) + nnz_B = rand(elty, nnz(d_B)) + d_B.nzVal = CuVector{elty}(nnz_B) + CUSOLVER.spqr_setup(F, d_B) + b = rand(elty, m) + d_b = CuVector(b) + x = zeros(elty, n) + d_x = CuVector(x) + CUSOLVER.spqr_factorise_solve(F, copy(d_b), d_x, tol) + d_r = d_b - d_B * d_x + @test norm(d_B' * d_r) ≤ √eps(R) +end