diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl index a5a6cd25b5..8280ff5c79 100644 --- a/lib/cublas/linalg.jl +++ b/lib/cublas/linalg.jl @@ -173,8 +173,10 @@ end # # GEMV - -function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::StridedCuMatrix, B::StridedCuVector, _add::MulAddMul) +# legacy method +LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::StridedCuMatrix, B::StridedCuVector, _add::MulAddMul) = + LinearAlgebra.generic_matvecmul!(Y, tA, A, B, _add.alpha, _add.beta) +function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::StridedCuMatrix, B::StridedCuVector, alpha::Number, beta::Number) mA, nA = tA == 'N' ? size(A) : reverse(size(A)) if nA != length(B) @@ -194,7 +196,6 @@ function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::Stri end T = eltype(Y) - alpha, beta = _add.alpha, _add.beta if alpha isa Union{Bool,T} && beta isa Union{Bool,T} if T <: CublasFloat && eltype(A) == eltype(B) == T if tA in ('N', 'T', 'C') @@ -206,7 +207,7 @@ function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::Stri end end end - LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta)) + LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta) end if VERSION < v"1.10.0-DEV.1365" @@ -282,10 +283,10 @@ end # VERSION # # GEMM - -function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul) +LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul) = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, alpha::Number, beta::Number) T = eltype(C) - alpha, beta = _add.alpha, _add.beta mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1) mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1) diff --git a/lib/cusparse/interfaces.jl b/lib/cusparse/interfaces.jl index 686a4a9c22..e6bbe22a1b 100644 --- a/lib/cusparse/interfaces.jl +++ b/lib/cusparse/interfaces.jl @@ -61,19 +61,27 @@ op_wrappers = ((identity, T -> 'N', identity), (T -> :(Adjoint{T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))), (T -> :(HermOrSym{T, <:$T}), T -> 'N', A -> :(parent($A)))) -function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} +# legacy methods with final MulAddMul argument +LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} = + LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} = + LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) + +function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA - mv_wrapper(tA, _add.alpha, A, B, _add.beta, C) + mv_wrapper(tA, alpha, A, B, beta, C) end -function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} +function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA - mv_wrapper(tA, _add.alpha, A, CuVector{T}(B), _add.beta, C) + mv_wrapper(tA, alpha, A, CuVector{T}(B), beta, C) end -function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} +function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB - mm_wrapper(tA, tB, _add.alpha, A, B, _add.beta, C) + mm_wrapper(tA, tB, alpha, A, B, beta, C) end for (wrapa, transa, unwrapa) in op_wrappers @@ -87,25 +95,36 @@ for (wrapa, transa, unwrapa) in op_wrappers end end -function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::DenseCuMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: BlasFloat} +# legacy methods with final MulAddMul argument +LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::DenseCuMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: BlasFloat} = + LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) + +LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) + +function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::DenseCuMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: BlasFloat} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA - gemvi!(tA, _add.alpha, A, B, _add.beta, C, 'O') + gemvi!(tA, alpha, A, B, beta, C, 'O') end -function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} +function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB - mm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O') + mm!(tA, tB, alpha, A, B, beta, C, 'O') end -function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} +function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB - mm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O') + mm!(tA, tB, alpha, A, B, beta, C, 'O') end -function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} +function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB - mm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O') + mm!(tA, tB, alpha, A, B, beta, C, 'O') end for (wrapa, transa, unwrapa) in op_wrappers @@ -143,23 +162,31 @@ for (wrapa, transa, unwrapa) in op_wrappers end end -function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSC{T}, tA, tB, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: BlasFloat} +# legacy methods with final MulAddMul argument +LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSC{T}, tA, tB, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: BlasFloat} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSR{T}, tA, tB, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: BlasFloat} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: BlasFloat} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) + +function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSC{T}, tA, tB, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T}, alpha::Number, beta::Number) where {T <: BlasFloat} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB - gemm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O') + gemm!(tA, tB, alpha, A, B, beta, C, 'O') end -function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSR{T}, tA, tB, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: BlasFloat} +function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSR{T}, tA, tB, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T}, alpha::Number, beta::Number) where {T <: BlasFloat} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB - gemm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O') + gemm!(tA, tB, alpha, A, B, beta, C, 'O') end -function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: BlasFloat} +function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <: BlasFloat} tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB A_csr = CuSparseMatrixCSR(A) B_csr = CuSparseMatrixCSR(B) C_csr = CuSparseMatrixCSR(C) - generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, _add.alpha, _add.beta) + generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, alpha, beta) C = CuSparseMatrixCOO(C_csr) # is this in-place of the original C? end