Skip to content

Commit

Permalink
Merge pull request JuliaGPU#179 from erathorn/master
Browse files Browse the repository at this point in the history
include potri and test
  • Loading branch information
maleadt committed May 29, 2020
2 parents 9ab9986 + b4fef36 commit 6ee3739
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
37 changes: 37 additions & 0 deletions lib/cusolver/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,43 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
end
end

## potri
for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :Float32),
(:cusolverDnDpotri_bufferSize, :cusolverDnDpotri, :Float64),
(:cusolverDnCpotri_bufferSize, :cusolverDnCpotri, :ComplexF32),
(:cusolverDnZpotri_bufferSize, :cusolverDnZpotri, :ComplexF64))
@eval begin
function LinearAlgebra.LAPACK.potri!(uplo::Char,
A::CuMatrix{$elty})

cuuplo = cublasfill(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
devinfo = CuArray{Cint}(undef, 1)

@workspace eltyp=$elty size=@argout(
$bname(dense_handle(), cuuplo, n, A, lda, out(Ref{Cint}(0)))
)[] buffer->begin
$fname(dense_handle(), cuuplo, n, A, lda, buffer, length(buffer), devinfo)
end

info = @allowscalar devinfo[1]
unsafe_free!(devinfo)
chkargsok(BlasInt(info))

A
end
end
end
"""
potri!(uplo::Char, A::CuMatrix)
!!! note
`potri!` requires CUDA >= 10.1
"""
LinearAlgebra.LAPACK.potri!(uplo::Char, A::CuMatrix)

#getrf
for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :Float32),
(:cusolverDnDgetrf_bufferSize, :cusolverDnDgetrf, :Float64),
Expand Down
25 changes: 25 additions & 0 deletions test/cusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ k = 1
@test_throws LinearAlgebra.PosDefException cholesky(d_A)
end

CUDA.CUSOLVER.version() >= v"10.1" && @testset "Cholesky inverse (potri)" begin
# test lower
A = rand(elty,n,n)
A = A*A' #posdef
d_A = CuArray(A)

LinearAlgebra.LAPACK.potrf!('L', A)
LinearAlgebra.LAPACK.potrf!('L', d_A)

LinearAlgebra.LAPACK.potri!('L', A)
LinearAlgebra.LAPACK.potri!('L', d_A)
@test A collect(d_A)

# test upper
A = rand(elty,n,n)
A = A*A' #posdef
d_A = CuArray(A)

LinearAlgebra.LAPACK.potrf!('U', A)
LinearAlgebra.LAPACK.potrf!('U', d_A)
LinearAlgebra.LAPACK.potri!('U', A)
LinearAlgebra.LAPACK.potri!('U', d_A)
@test A collect(d_A)
end

@testset "getrf!" begin
A = rand(elty,m,n)
d_A = CuArray(A)
Expand Down

0 comments on commit 6ee3739

Please sign in to comment.