diff --git a/lib/cublas/libcublas.jl b/lib/cublas/libcublas.jl index a026cd5896..e5e2f07daf 100644 --- a/lib/cublas/libcublas.jl +++ b/lib/cublas/libcublas.jl @@ -3262,10 +3262,10 @@ end beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasSgemvBatched(handle::cublasHandle_t, trans::cublasOperation_t, - m::Cint, n::Cint, alpha::Ptr{Cfloat}, - Aarray::Ptr{Ptr{Cfloat}}, lda::Cint, - xarray::Ptr{Ptr{Cfloat}}, incx::Cint, - beta::Ptr{Cfloat}, yarray::Ptr{Ptr{Cfloat}}, + m::Cint, n::Cint, alpha::RefOrCuRef{Cfloat}, + Aarray::CuPtr{Ptr{Cfloat}}, lda::Cint, + xarray::CuPtr{Ptr{Cfloat}}, incx::Cint, + beta::RefOrCuRef{Cfloat}, yarray::CuPtr{Ptr{Cfloat}}, incy::Cint, batchCount::Cint)::cublasStatus_t end @@ -3273,10 +3273,10 @@ end incx, beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasSgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, - m::Int64, n::Int64, alpha::Ptr{Cfloat}, - Aarray::Ptr{Ptr{Cfloat}}, lda::Int64, - xarray::Ptr{Ptr{Cfloat}}, incx::Int64, - beta::Ptr{Cfloat}, yarray::Ptr{Ptr{Cfloat}}, + m::Int64, n::Int64, alpha::RefOrCuRef{Cfloat}, + Aarray::CuPtr{Ptr{Cfloat}}, lda::Int64, + xarray::CuPtr{Ptr{Cfloat}}, incx::Int64, + beta::RefOrCuRef{Cfloat}, yarray::CuPtr{Ptr{Cfloat}}, incy::Int64, batchCount::Int64)::cublasStatus_t end @@ -3284,10 +3284,10 @@ end beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasDgemvBatched(handle::cublasHandle_t, trans::cublasOperation_t, - m::Cint, n::Cint, alpha::Ptr{Cdouble}, - Aarray::Ptr{Ptr{Cdouble}}, lda::Cint, - xarray::Ptr{Ptr{Cdouble}}, incx::Cint, - beta::Ptr{Cdouble}, yarray::Ptr{Ptr{Cdouble}}, + m::Cint, n::Cint, alpha::RefOrCuRef{Cdouble}, + Aarray::CuPtr{Ptr{Cdouble}}, lda::Cint, + xarray::CuPtr{Ptr{Cdouble}}, incx::Cint, + beta::RefOrCuRef{Cdouble}, yarray::CuPtr{Ptr{Cdouble}}, incy::Cint, batchCount::Cint)::cublasStatus_t end @@ -3295,10 +3295,10 @@ end incx, beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasDgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, - m::Int64, n::Int64, alpha::Ptr{Cdouble}, - Aarray::Ptr{Ptr{Cdouble}}, lda::Int64, - xarray::Ptr{Ptr{Cdouble}}, incx::Int64, - beta::Ptr{Cdouble}, yarray::Ptr{Ptr{Cdouble}}, + m::Int64, n::Int64, alpha::RefOrCuRef{Cdouble}, + Aarray::CuPtr{Ptr{Cdouble}}, lda::Int64, + xarray::CuPtr{Ptr{Cdouble}}, incx::Int64, + beta::RefOrCuRef{Cdouble}, yarray::CuPtr{Ptr{Cdouble}}, incy::Int64, batchCount::Int64)::cublasStatus_t end @@ -3306,10 +3306,10 @@ end beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasCgemvBatched(handle::cublasHandle_t, trans::cublasOperation_t, - m::Cint, n::Cint, alpha::Ptr{cuComplex}, - Aarray::Ptr{Ptr{cuComplex}}, lda::Cint, - xarray::Ptr{Ptr{cuComplex}}, incx::Cint, - beta::Ptr{cuComplex}, yarray::Ptr{Ptr{cuComplex}}, + m::Cint, n::Cint, alpha::RefOrCuRef{cuComplex}, + Aarray::CuPtr{Ptr{cuComplex}}, lda::Cint, + xarray::CuPtr{Ptr{cuComplex}}, incx::Cint, + beta::RefOrCuRef{cuComplex}, yarray::CuPtr{Ptr{cuComplex}}, incy::Cint, batchCount::Cint)::cublasStatus_t end @@ -3317,11 +3317,11 @@ end incx, beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasCgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, - m::Int64, n::Int64, alpha::Ptr{cuComplex}, - Aarray::Ptr{Ptr{cuComplex}}, lda::Int64, - xarray::Ptr{Ptr{cuComplex}}, incx::Int64, - beta::Ptr{cuComplex}, - yarray::Ptr{Ptr{cuComplex}}, incy::Int64, + m::Int64, n::Int64, alpha::RefOrCuRef{cuComplex}, + Aarray::CuPtr{Ptr{cuComplex}}, lda::Int64, + xarray::CuPtr{Ptr{cuComplex}}, incx::Int64, + beta::RefOrCuRef{cuComplex}, + yarray::CuPtr{Ptr{cuComplex}}, incy::Int64, batchCount::Int64)::cublasStatus_t end @@ -3329,11 +3329,11 @@ end beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasZgemvBatched(handle::cublasHandle_t, trans::cublasOperation_t, - m::Cint, n::Cint, alpha::Ptr{cuDoubleComplex}, - Aarray::Ptr{Ptr{cuDoubleComplex}}, lda::Cint, - xarray::Ptr{Ptr{cuDoubleComplex}}, incx::Cint, - beta::Ptr{cuDoubleComplex}, - yarray::Ptr{Ptr{cuDoubleComplex}}, incy::Cint, + m::Cint, n::Cint, alpha::RefOrCuRef{cuDoubleComplex}, + Aarray::CuPtr{Ptr{cuDoubleComplex}}, lda::Cint, + xarray::CuPtr{Ptr{cuDoubleComplex}}, incx::Cint, + beta::RefOrCuRef{cuDoubleComplex}, + yarray::CuPtr{Ptr{cuDoubleComplex}}, incy::Cint, batchCount::Cint)::cublasStatus_t end @@ -3341,11 +3341,11 @@ end incx, beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasZgemvBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, - m::Int64, n::Int64, alpha::Ptr{cuDoubleComplex}, - Aarray::Ptr{Ptr{cuDoubleComplex}}, lda::Int64, - xarray::Ptr{Ptr{cuDoubleComplex}}, incx::Int64, - beta::Ptr{cuDoubleComplex}, - yarray::Ptr{Ptr{cuDoubleComplex}}, incy::Int64, + m::Int64, n::Int64, alpha::RefOrCuRef{cuDoubleComplex}, + Aarray::CuPtr{Ptr{cuDoubleComplex}}, lda::Int64, + xarray::CuPtr{Ptr{cuDoubleComplex}}, incx::Int64, + beta::RefOrCuRef{cuDoubleComplex}, + yarray::CuPtr{Ptr{cuDoubleComplex}}, incy::Int64, batchCount::Int64)::cublasStatus_t end @@ -3355,11 +3355,11 @@ end initialize_context() @ccall libcublas.cublasSgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{Cfloat}, A::Ptr{Cfloat}, + alpha::RefOrCuRef{Cfloat}, A::CuPtr{Cfloat}, lda::Cint, strideA::Clonglong, - x::Ptr{Cfloat}, incx::Cint, - stridex::Clonglong, beta::Ptr{Cfloat}, - y::Ptr{Cfloat}, incy::Cint, + x::CuPtr{Cfloat}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{Cfloat}, + y::CuPtr{Cfloat}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -3370,11 +3370,11 @@ end initialize_context() @ccall libcublas.cublasSgemvStridedBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, m::Int64, - n::Int64, alpha::Ptr{Cfloat}, - A::Ptr{Cfloat}, lda::Int64, - strideA::Clonglong, x::Ptr{Cfloat}, + n::Int64, alpha::RefOrCuRef{Cfloat}, + A::CuPtr{Cfloat}, lda::Int64, + strideA::Clonglong, x::CuPtr{Cfloat}, incx::Int64, stridex::Clonglong, - beta::Ptr{Cfloat}, y::Ptr{Cfloat}, + beta::RefOrCuRef{Cfloat}, y::CuPtr{Cfloat}, incy::Int64, stridey::Clonglong, batchCount::Int64)::cublasStatus_t end @@ -3385,11 +3385,11 @@ end initialize_context() @ccall libcublas.cublasDgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{Cdouble}, A::Ptr{Cdouble}, + alpha::RefOrCuRef{Cdouble}, A::CuPtr{Cdouble}, lda::Cint, strideA::Clonglong, - x::Ptr{Cdouble}, incx::Cint, - stridex::Clonglong, beta::Ptr{Cdouble}, - y::Ptr{Cdouble}, incy::Cint, + x::CuPtr{Cdouble}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{Cdouble}, + y::CuPtr{Cdouble}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -3400,11 +3400,11 @@ end initialize_context() @ccall libcublas.cublasDgemvStridedBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, m::Int64, - n::Int64, alpha::Ptr{Cdouble}, - A::Ptr{Cdouble}, lda::Int64, - strideA::Clonglong, x::Ptr{Cdouble}, + n::Int64, alpha::RefOrCuRef{Cdouble}, + A::CuPtr{Cdouble}, lda::Int64, + strideA::Clonglong, x::CuPtr{Cdouble}, incx::Int64, stridex::Clonglong, - beta::Ptr{Cdouble}, y::Ptr{Cdouble}, + beta::RefOrCuRef{Cdouble}, y::CuPtr{Cdouble}, incy::Int64, stridey::Clonglong, batchCount::Int64)::cublasStatus_t end @@ -3415,11 +3415,11 @@ end initialize_context() @ccall libcublas.cublasCgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{cuComplex}, A::Ptr{cuComplex}, + alpha::RefOrCuRef{cuComplex}, A::CuPtr{cuComplex}, lda::Cint, strideA::Clonglong, - x::Ptr{cuComplex}, incx::Cint, - stridex::Clonglong, beta::Ptr{cuComplex}, - y::Ptr{cuComplex}, incy::Cint, + x::CuPtr{cuComplex}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{cuComplex}, + y::CuPtr{cuComplex}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -3430,11 +3430,11 @@ end initialize_context() @ccall libcublas.cublasCgemvStridedBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, m::Int64, - n::Int64, alpha::Ptr{cuComplex}, - A::Ptr{cuComplex}, lda::Int64, - strideA::Clonglong, x::Ptr{cuComplex}, + n::Int64, alpha::RefOrCuRef{cuComplex}, + A::CuPtr{cuComplex}, lda::Int64, + strideA::Clonglong, x::CuPtr{cuComplex}, incx::Int64, stridex::Clonglong, - beta::Ptr{cuComplex}, y::Ptr{cuComplex}, + beta::RefOrCuRef{cuComplex}, y::CuPtr{cuComplex}, incy::Int64, stridey::Clonglong, batchCount::Int64)::cublasStatus_t end @@ -3445,12 +3445,12 @@ end initialize_context() @ccall libcublas.cublasZgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{cuDoubleComplex}, - A::Ptr{cuDoubleComplex}, lda::Cint, - strideA::Clonglong, x::Ptr{cuDoubleComplex}, + alpha::RefOrCuRef{cuDoubleComplex}, + A::CuPtr{cuDoubleComplex}, lda::Cint, + strideA::Clonglong, x::CuPtr{cuDoubleComplex}, incx::Cint, stridex::Clonglong, - beta::Ptr{cuDoubleComplex}, - y::Ptr{cuDoubleComplex}, incy::Cint, + beta::RefOrCuRef{cuDoubleComplex}, + y::CuPtr{cuDoubleComplex}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -3461,13 +3461,13 @@ end initialize_context() @ccall libcublas.cublasZgemvStridedBatched_64(handle::cublasHandle_t, trans::cublasOperation_t, m::Int64, - n::Int64, alpha::Ptr{cuDoubleComplex}, - A::Ptr{cuDoubleComplex}, lda::Int64, + n::Int64, alpha::RefOrCuRef{cuDoubleComplex}, + A::CuPtr{cuDoubleComplex}, lda::Int64, strideA::Clonglong, - x::Ptr{cuDoubleComplex}, incx::Int64, + x::CuPtr{cuDoubleComplex}, incx::Int64, stridex::Clonglong, - beta::Ptr{cuDoubleComplex}, - y::Ptr{cuDoubleComplex}, incy::Int64, + beta::RefOrCuRef{cuDoubleComplex}, + y::CuPtr{cuDoubleComplex}, incy::Int64, stridey::Clonglong, batchCount::Int64)::cublasStatus_t end @@ -5302,10 +5302,10 @@ end incx, beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasHSHgemvBatched(handle::cublasHandle_t, trans::cublasOperation_t, - m::Cint, n::Cint, alpha::Ptr{Cfloat}, - Aarray::Ptr{Ptr{Float16}}, lda::Cint, - xarray::Ptr{Ptr{Float16}}, incx::Cint, - beta::Ptr{Cfloat}, yarray::Ptr{Ptr{Float16}}, + m::Cint, n::Cint, alpha::RefOrCuRef{Cfloat}, + Aarray::CuPtr{Ptr{Float16}}, lda::Cint, + xarray::CuPtr{Ptr{Float16}}, incx::Cint, + beta::RefOrCuRef{Cfloat}, yarray::CuPtr{Ptr{Float16}}, incy::Cint, batchCount::Cint)::cublasStatus_t end @@ -5313,10 +5313,10 @@ end incx, beta, yarray, incy, batchCount) initialize_context() @ccall libcublas.cublasHSSgemvBatched(handle::cublasHandle_t, trans::cublasOperation_t, - m::Cint, n::Cint, alpha::Ptr{Cfloat}, - Aarray::Ptr{Ptr{Float16}}, lda::Cint, - xarray::Ptr{Ptr{Float16}}, incx::Cint, - beta::Ptr{Cfloat}, yarray::Ptr{Ptr{Cfloat}}, + m::Cint, n::Cint, alpha::RefOrCuRef{Cfloat}, + Aarray::CuPtr{Ptr{Float16}}, lda::Cint, + xarray::CuPtr{Ptr{Float16}}, incx::Cint, + beta::RefOrCuRef{Cfloat}, yarray::CuPtr{Ptr{Cfloat}}, incy::Cint, batchCount::Cint)::cublasStatus_t end @@ -5348,11 +5348,11 @@ end initialize_context() @ccall libcublas.cublasHSHgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{Cfloat}, A::Ptr{Float16}, + alpha::RefOrCuRef{Cfloat}, A::CuPtr{Float16}, lda::Cint, strideA::Clonglong, - x::Ptr{Float16}, incx::Cint, - stridex::Clonglong, beta::Ptr{Cfloat}, - y::Ptr{Float16}, incy::Cint, + x::CuPtr{Float16}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{Cfloat}, + y::CuPtr{Float16}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -5363,11 +5363,11 @@ end initialize_context() @ccall libcublas.cublasHSSgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{Cfloat}, A::Ptr{Float16}, + alpha::RefOrCuRef{Cfloat}, A::CuPtr{Float16}, lda::Cint, strideA::Clonglong, - x::Ptr{Float16}, incx::Cint, - stridex::Clonglong, beta::Ptr{Cfloat}, - y::Ptr{Cfloat}, incy::Cint, + x::CuPtr{Float16}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{Cfloat}, + y::CuPtr{Cfloat}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -5378,11 +5378,11 @@ end initialize_context() @ccall libcublas.cublasTSTgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{Cfloat}, A::Ptr{BFloat16}, + alpha::RefOrCuRef{Cfloat}, A::CuPtr{BFloat16}, lda::Cint, strideA::Clonglong, - x::Ptr{BFloat16}, incx::Cint, - stridex::Clonglong, beta::Ptr{Cfloat}, - y::Ptr{BFloat16}, incy::Cint, + x::CuPtr{BFloat16}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{Cfloat}, + y::CuPtr{BFloat16}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end @@ -5393,11 +5393,11 @@ end initialize_context() @ccall libcublas.cublasTSSgemvStridedBatched(handle::cublasHandle_t, trans::cublasOperation_t, m::Cint, n::Cint, - alpha::Ptr{Cfloat}, A::Ptr{BFloat16}, + alpha::RefOrCuRef{Cfloat}, A::CuPtr{BFloat16}, lda::Cint, strideA::Clonglong, - x::Ptr{BFloat16}, incx::Cint, - stridex::Clonglong, beta::Ptr{Cfloat}, - y::Ptr{Cfloat}, incy::Cint, + x::CuPtr{BFloat16}, incx::Cint, + stridex::Clonglong, beta::RefOrCuRef{Cfloat}, + y::CuPtr{Cfloat}, incy::Cint, stridey::Clonglong, batchCount::Cint)::cublasStatus_t end diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 369600707a..90d1469f64 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -321,28 +321,107 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64), function gemv!(trans::Char, alpha::Number, A::StridedCuMatrix{$elty}, - X::StridedCuVector{$elty}, + x::StridedCuVector{$elty}, beta::Number, - Y::StridedCuVector{$elty}) + y::StridedCuVector{$elty}) # handle trans m,n = size(A) # check dimensions - length(X) == (trans == 'N' ? n : m) && length(Y) == (trans == 'N' ? m : n) || throw(DimensionMismatch("")) + length(x) == (trans == 'N' ? n : m) && length(y) == (trans == 'N' ? m : n) || throw(DimensionMismatch("")) # compute increments lda = max(1,stride(A,2)) - incx = stride(X,1) - incy = stride(Y,1) - $fname(handle(), trans, m, n, alpha, A, lda, X, incx, beta, Y, incy) - Y + incx = stride(x,1) + incy = stride(y,1) + $fname(handle(), trans, m, n, alpha, A, lda, x, incx, beta, y, incy) + y end end end function gemv(trans::Char, alpha::Number, - A::StridedCuMatrix{T}, X::StridedCuVector{T}) where T - gemv!(trans, alpha, A, X, zero(T), similar(X, size(A, (trans == 'N' ? 1 : 2)))) + A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T + gemv!(trans, alpha, A, x, zero(T), similar(x, size(A, (trans == 'N' ? 1 : 2)))) end -function gemv(trans::Char, A::StridedCuMatrix{T}, X::StridedCuVector{T}) where T - gemv!(trans, one(T), A, X, zero(T), similar(X, T, size(A, (trans == 'N' ? 1 : 2)))) +function gemv(trans::Char, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T + gemv!(trans, one(T), A, x, zero(T), similar(x, T, size(A, (trans == 'N' ? 1 : 2)))) +end + +for (fname, eltyin, eltyout) in + ((:cublasDgemvBatched,:Float64, :Float64), + (:cublasSgemvBatched,:Float32, :Float32), + (:cublasHSHgemvBatched,:Float16, :Float16), + (:cublasHSSgemvBatched,:Float16, :Float32), + (:cublasZgemvBatched,:ComplexF64, :ComplexF64), + (:cublasCgemvBatched,:ComplexF32, :ComplexF32)) + @eval begin + function gemv_batched!(trans::Char, + alpha::Number, + A::Vector{<:StridedCuMatrix{$eltyin}}, + x::Vector{<:StridedCuVector{$eltyin}}, + beta::Number, + y::Vector{<:StridedCuVector{$eltyout}}) + if length(A) != length(x) || length(A) != length(y) + throw(DimensionMismatch("Lengths of inputs must be the same")) + end + for (i, (As,xs,ys)) in enumerate(zip(A,x,y)) + m,n = size(As) + if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n) + throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))")) + end + end + + m = size(A[1], trans == 'N' ? 1 : 2) + n = size(A[1], trans == 'N' ? 2 : 1) + lda = max(1,stride(A[1],2)) + incx = stride(x[1],1) + incy = stride(y[1],1) + Aptrs = unsafe_batch(A) + xptrs = unsafe_batch(x) + yptrs = unsafe_batch(y) + $fname(handle(), trans, m, n, alpha, Aptrs, lda, xptrs, incx, beta, yptrs, incy, length(A)) + unsafe_free!(yptrs) + unsafe_free!(xptrs) + unsafe_free!(Aptrs) + + y + end + end +end + +for (fname, eltyin, eltyout) in + ((:cublasDgemvStridedBatched,:Float64, :Float64), + (:cublasSgemvStridedBatched,:Float32, :Float32), + (:cublasHSHgemvStridedBatched,:Float16, :Float16), + (:cublasHSSgemvStridedBatched,:Float16, :Float32), + (:cublasZgemvStridedBatched,:ComplexF64, :ComplexF64), + (:cublasCgemvStridedBatched,:ComplexF32, :ComplexF32)) + @eval begin + function gemv_strided_batched!(trans::Char, + alpha::Number, + A::AbstractArray{$eltyin, 3}, + x::AbstractArray{$eltyin, 2}, + beta::Number, + y::AbstractArray{$eltyout, 2}) + if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2) + throw(DimensionMismatch("Batch sizes must be equal for all inputs")) + end + m = size(A, trans == 'N' ? 1 : 2) + n = size(A, trans == 'N' ? 2 : 1) + if m != size(y, 1) || n != size(x, 1) + throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))")) + end + + lda = max(1,stride(A, 2)) + incx = stride(x,1) + incy = stride(y,1) + strideA = size(A, 3) == 1 ? 0 : stride(A, 3) + stridex = size(x, 2) == 1 ? 0 : stride(x, 2) + stridey = stride(y, 2) + batchCount = size(A, 3) + $fname(handle(), trans, m, n, alpha, A, lda, strideA, x, incx, stridex, beta, y, incy, stridey, batchCount) + + y + end + end end ### (GB) general banded matrix-vector multiplication @@ -1030,7 +1109,7 @@ function gemm_strided_batched(transA::Char, transB::Char, A::AbstractArray{T, 3} gemm_strided_batched(transA, transB, one(T), A, B) end -## (SY) symmetric matrix-matrix and matrix-vector multiplication +## (Sy) symmetric matrix-matrix and matrix-vector multiplication for (fname, elty) in ((:cublasDsymm_v2,:Float64), (:cublasSsymm_v2,:Float32), (:cublasZsymm_v2,:ComplexF64), @@ -1993,6 +2072,7 @@ for (fname, elty) in ((:cublasXtZher2k,:ComplexF64), end end end + function xt_her2k(uplo::Char, trans::Char, alpha::Number, A::Union{StridedVecOrMat{T}, StridedCuVecOrMat{T}}, B::Union{StridedVecOrMat{T}, StridedCuVecOrMat{T}}) where T diff --git a/res/wrap/cublas.toml b/res/wrap/cublas.toml index 081a7f45f9..aa8c399031 100644 --- a/res/wrap/cublas.toml +++ b/res/wrap/cublas.toml @@ -505,6 +505,153 @@ needs_context = false 6 = "CuPtr{cuDoubleComplex}" 7 = "CuPtr{cuDoubleComplex}" +[api.cublasSgemvBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Ptr{Cfloat}}" +8 = "CuPtr{Ptr{Cfloat}}" +10 = "RefOrCuRef{Cfloat}" +11 = "CuPtr{Ptr{Cfloat}}" + +[api.cublasSgemvBatched_64.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Ptr{Cfloat}}" +8 = "CuPtr{Ptr{Cfloat}}" +10 = "RefOrCuRef{Cfloat}" +11 = "CuPtr{Ptr{Cfloat}}" + +[api.cublasDgemvBatched.argtypes] +5 = "RefOrCuRef{Cdouble}" +6 = "CuPtr{Ptr{Cdouble}}" +8 = "CuPtr{Ptr{Cdouble}}" +10 = "RefOrCuRef{Cdouble}" +11 = "CuPtr{Ptr{Cdouble}}" + +[api.cublasDgemvBatched_64.argtypes] +5 = "RefOrCuRef{Cdouble}" +6 = "CuPtr{Ptr{Cdouble}}" +8 = "CuPtr{Ptr{Cdouble}}" +10 = "RefOrCuRef{Cdouble}" +11 = "CuPtr{Ptr{Cdouble}}" + +[api.cublasCgemvBatched.argtypes] +5 = "RefOrCuRef{cuComplex}" +6 = "CuPtr{Ptr{cuComplex}}" +8 = "CuPtr{Ptr{cuComplex}}" +10 = "RefOrCuRef{cuComplex}" +11 = "CuPtr{Ptr{cuComplex}}" + +[api.cublasCgemvBatched_64.argtypes] +5 = "RefOrCuRef{cuComplex}" +6 = "CuPtr{Ptr{cuComplex}}" +8 = "CuPtr{Ptr{cuComplex}}" +10 = "RefOrCuRef{cuComplex}" +11 = "CuPtr{Ptr{cuComplex}}" + +[api.cublasZgemvBatched.argtypes] +5 = "RefOrCuRef{cuDoubleComplex}" +6 = "CuPtr{Ptr{cuDoubleComplex}}" +8 = "CuPtr{Ptr{cuDoubleComplex}}" +10 = "RefOrCuRef{cuDoubleComplex}" +11 = "CuPtr{Ptr{cuDoubleComplex}}" + +[api.cublasZgemvBatched_64.argtypes] +5 = "RefOrCuRef{cuDoubleComplex}" +6 = "CuPtr{Ptr{cuDoubleComplex}}" +8 = "CuPtr{Ptr{cuDoubleComplex}}" +10 = "RefOrCuRef{cuDoubleComplex}" +11 = "CuPtr{Ptr{cuDoubleComplex}}" + +[api.cublasHSHgemvBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Ptr{Float16}}" +8 = "CuPtr{Ptr{Float16}}" +10 = "RefOrCuRef{Cfloat}" +11 = "CuPtr{Ptr{Float16}}" + +[api.cublasHSSgemvBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Ptr{Float16}}" +8 = "CuPtr{Ptr{Float16}}" +10 = "RefOrCuRef{Cfloat}" +11 = "CuPtr{Ptr{Cfloat}}" + +[api.cublasSgemvStridedBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Cfloat}" +9 = "CuPtr{Cfloat}" +12 = "RefOrCuRef{Cfloat}" +13 = "CuPtr{Cfloat}" + +[api.cublasSgemvStridedBatched_64.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Cfloat}" +9 = "CuPtr{Cfloat}" +12 = "RefOrCuRef{Cfloat}" +13 = "CuPtr{Cfloat}" + +[api.cublasDgemvStridedBatched.argtypes] +5 = "RefOrCuRef{Cdouble}" +6 = "CuPtr{Cdouble}" +9 = "CuPtr{Cdouble}" +12 = "RefOrCuRef{Cdouble}" +13 = "CuPtr{Cdouble}" + +[api.cublasDgemvStridedBatched_64.argtypes] +5 = "RefOrCuRef{Cdouble}" +6 = "CuPtr{Cdouble}" +9 = "CuPtr{Cdouble}" +12 = "RefOrCuRef{Cdouble}" +13 = "CuPtr{Cdouble}" + +[api.cublasCgemvStridedBatched.argtypes] +5 = "RefOrCuRef{cuComplex}" +6 = "CuPtr{cuComplex}" +9 = "CuPtr{cuComplex}" +12 = "RefOrCuRef{cuComplex}" +13 = "CuPtr{cuComplex}" + +[api.cublasCgemvStridedBatched_64.argtypes] +5 = "RefOrCuRef{cuComplex}" +6 = "CuPtr{cuComplex}" +9 = "CuPtr{cuComplex}" +12 = "RefOrCuRef{cuComplex}" +13 = "CuPtr{cuComplex}" + +[api.cublasZgemvStridedBatched.argtypes] +5 = "RefOrCuRef{cuDoubleComplex}" +6 = "CuPtr{cuDoubleComplex}" +9 = "CuPtr{cuDoubleComplex}" +12 = "RefOrCuRef{cuDoubleComplex}" +13 = "CuPtr{cuDoubleComplex}" + +[api.cublasZgemvStridedBatched_64.argtypes] +5 = "RefOrCuRef{cuDoubleComplex}" +6 = "CuPtr{cuDoubleComplex}" +9 = "CuPtr{cuDoubleComplex}" +12 = "RefOrCuRef{cuDoubleComplex}" +13 = "CuPtr{cuDoubleComplex}" + +[api.cublasHSSgemvStridedBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{Float16}" +9 = "CuPtr{Float16}" +12 = "RefOrCuRef{Cfloat}" +13 = "CuPtr{Cfloat}" + +[api.cublasTSTgemvStridedBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{BFloat16}" +9 = "CuPtr{BFloat16}" +12 = "RefOrCuRef{Cfloat}" +13 = "CuPtr{BFloat16}" + +[api.cublasTSSgemvStridedBatched.argtypes] +5 = "RefOrCuRef{Cfloat}" +6 = "CuPtr{BFloat16}" +9 = "CuPtr{BFloat16}" +12 = "RefOrCuRef{Cfloat}" +13 = "CuPtr{Cfloat}" + [api.cublasStrsv_v2.argtypes] 6 = "CuPtr{Cfloat}" 8 = "CuPtr{Cfloat}" diff --git a/test/libraries/cublas.jl b/test/libraries/cublas.jl index ea0019a67c..7807ee8afe 100644 --- a/test/libraries/cublas.jl +++ b/test/libraries/cublas.jl @@ -125,6 +125,56 @@ end @test hy ≈ A * x end + if CUBLAS.version() >= v"11.9" + @testset "gemv_batched" begin + x = [rand(elty, m) for i=1:10] + A = [rand(elty, n, m) for i=1:10] + y = [rand(elty, n) for i=1:10] + dx = CuArray{elty, 1}[] + dA = CuArray{elty, 2}[] + dy = CuArray{elty, 1}[] + dbad = CuArray{elty, 1}[] + for i=1:length(A) + push!(dA, CuArray(A[i])) + push!(dx, CuArray(x[i])) + push!(dy, CuArray(y[i])) + if i < length(A) - 2 + push!(dbad,CuArray(dx[i])) + end + end + @test_throws DimensionMismatch CUBLAS.gemv_batched!('N', alpha, dA, dx, beta, dbad) + CUBLAS.gemv_batched!('N', alpha, dA, dx, beta, dy) + for i=1:length(A) + hy = collect(dy[i]) + y[i] = alpha * A[i] * x[i] + beta * y[i] + @test y[i] ≈ hy + end + end + end + + if CUBLAS.version() >= v"11.9" + @testset "gemv_strided_batched" begin + x = rand(elty, m, 10) + A = rand(elty, n, m, 10) + y = rand(elty, n, 10) + bad = rand(elty, m, 10) + dx = CuArray(x) + dA = CuArray(A) + dy = CuArray(y) + dbad = CuArray(bad) + @test_throws DimensionMismatch CUBLAS.gemv_strided_batched!('N', alpha, dA, dx, beta, dbad) + bad = rand(elty, n, 2) + dbad = CuArray(bad) + @test_throws DimensionMismatch CUBLAS.gemv_strided_batched!('N', alpha, dA, dx, beta, dbad) + CUBLAS.gemv_strided_batched!('N', alpha, dA, dx, beta, dy) + for i=1:size(A, 3) + hy = collect(dy[:, i]) + y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i] + @test y[:, i] ≈ hy + end + end + end + @testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty) y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5) dy, dA, dx = CuArray(y), CuArray(A), CuArray(x)