Skip to content

Commit

Permalink
Avoid constructing MulAddMuls on Julia v1.12+ (#2277)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed May 23, 2024
1 parent 7629209 commit cbbf19a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 27 deletions.
15 changes: 8 additions & 7 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,10 @@ end
#

# GEMV

function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::StridedCuMatrix, B::StridedCuVector, _add::MulAddMul)
# legacy method
# Accepts the scaling pair packed in a `MulAddMul` and forwards to the method of
# the same function that takes the two scalars as trailing positional arguments,
# passing `_add.alpha` and `_add.beta`. No other work is done here.
LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::StridedCuMatrix, B::StridedCuVector, _add::MulAddMul) =
LinearAlgebra.generic_matvecmul!(Y, tA, A, B, _add.alpha, _add.beta)
function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::StridedCuMatrix, B::StridedCuVector, alpha::Number, beta::Number)
mA, nA = tA == 'N' ? size(A) : reverse(size(A))

if nA != length(B)
Expand All @@ -194,7 +196,6 @@ function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::Stri
end

T = eltype(Y)
alpha, beta = _add.alpha, _add.beta
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
if T <: CublasFloat && eltype(A) == eltype(B) == T
if tA in ('N', 'T', 'C')
Expand All @@ -206,7 +207,7 @@ function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::Stri
end
end
end
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta))
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta)
end

if VERSION < v"1.10.0-DEV.1365"
Expand Down Expand Up @@ -282,10 +283,10 @@ end # VERSION
#

# GEMM

function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul)
# Legacy entry point: takes the scalars wrapped in a `MulAddMul` and forwards to
# the method with explicit trailing scalars, passing `_add.alpha` and `_add.beta`.
LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, alpha::Number, beta::Number)
T = eltype(C)
alpha, beta = _add.alpha, _add.beta
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)

Expand Down
67 changes: 47 additions & 20 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,27 @@ op_wrappers = ((identity, T -> 'N', identity),
(T -> :(Adjoint{T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))),
(T -> :(HermOrSym{T, <:$T}), T -> 'N', A -> :(parent($A))))

function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}}
# legacy methods with final MulAddMul argument
# Sparse matrix, dense vector: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments.
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
# Sparse matrix, sparse vector: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments.
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
# Sparse matrix, dense matrix: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments.
LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

# Sparse-matrix times dense-vector product with the scalars `alpha` and `beta`
# given directly. The operation tag is normalised first: 'S', 's', 'H' and 'h'
# are replaced by 'N' (presumably the symmetric/Hermitian wrapper tags -- TODO
# confirm against LinearAlgebra); any other tag is kept. The computation itself
# is delegated to `mv_wrapper`, which is defined outside this excerpt.
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
# NOTE(review): the next line reads `_add`, which is not a parameter of the
# signature above; it looks like the removed side of the diff hunk, superseded
# by the `alpha`/`beta` call that follows it.
mv_wrapper(tA, _add.alpha, A, B, _add.beta, C)
mv_wrapper(tA, alpha, A, B, beta, C)
end
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}}
# Sparse-matrix times sparse-vector product with explicit `alpha` and `beta`.
# 'S', 's', 'H' and 'h' are replaced by 'N' before the call; other tags are kept.
# `B` is converted with `CuVector{T}(B)` and the result is handed to `mv_wrapper`
# (defined outside this excerpt) together with the other arguments.
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
mv_wrapper(tA, _add.alpha, A, CuVector{T}(B), _add.beta, C)
mv_wrapper(tA, alpha, A, CuVector{T}(B), beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}}
# Sparse-matrix times dense-matrix product with explicit `alpha` and `beta`.
# Both operation tags are normalised the same way: 'S', 's', 'H' and 'h' become
# 'N', anything else is kept. The computation is delegated to `mm_wrapper`,
# which is defined outside this excerpt.
function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
mm_wrapper(tA, tB, _add.alpha, A, B, _add.beta, C)
mm_wrapper(tA, tB, alpha, A, B, beta, C)
end

for (wrapa, transa, unwrapa) in op_wrappers
Expand All @@ -87,25 +95,36 @@ for (wrapa, transa, unwrapa) in op_wrappers
end
end

function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::DenseCuMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: BlasFloat}
# legacy methods with final MulAddMul argument
# Dense matrix, sparse vector: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments. Restricted to `T <: BlasFloat`.
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::DenseCuMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: BlasFloat} =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)

# Dense matrix times CSC sparse matrix: unpack `_add` and forward to the method
# that takes `alpha` and `beta` as separate trailing arguments.
LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
# Dense matrix times CSR sparse matrix: unpack `_add` and forward to the method
# that takes `alpha` and `beta` as separate trailing arguments.
LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
# Dense matrix times COO sparse matrix: unpack `_add` and forward to the method
# that takes `alpha` and `beta` as separate trailing arguments.
LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

# Dense-matrix times sparse-vector product with explicit `alpha` and `beta`,
# for `T <: BlasFloat`. 'S', 's', 'H' and 'h' are replaced by 'N' before the
# call; other tags are kept. The computation is delegated to `gemvi!`; the
# trailing 'O' is passed through unchanged and its meaning is defined by that
# callee, which is not visible in this excerpt.
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::DenseCuMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
gemvi!(tA, _add.alpha, A, B, _add.beta, C, 'O')
gemvi!(tA, alpha, A, B, beta, C, 'O')
end

function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}}
# Dense-matrix times CSC-sparse-matrix product with explicit `alpha` and `beta`.
# Both operation tags are normalised: 'S', 's', 'H' and 'h' become 'N', anything
# else is kept. The computation is delegated to `mm!`; the trailing 'O' is passed
# through unchanged and its meaning is defined by that callee (not visible here).
function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
mm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O')
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}}
# Dense-matrix times CSR-sparse-matrix product with explicit `alpha` and `beta`.
# Both operation tags are normalised: 'S', 's', 'H' and 'h' become 'N', anything
# else is kept. The computation is delegated to `mm!`; the trailing 'O' is passed
# through unchanged and its meaning is defined by that callee (not visible here).
function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
mm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O')
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}}
# Dense-matrix times COO-sparse-matrix product with explicit `alpha` and `beta`.
# Both operation tags are normalised: 'S', 's', 'H' and 'h' become 'N', anything
# else is kept. The computation is delegated to `mm!`; the trailing 'O' is passed
# through unchanged and its meaning is defined by that callee (not visible here).
function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
mm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O')
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end

for (wrapa, transa, unwrapa) in op_wrappers
Expand Down Expand Up @@ -143,23 +162,31 @@ for (wrapa, transa, unwrapa) in op_wrappers
end
end

function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSC{T}, tA, tB, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: BlasFloat}
# legacy methods with final MulAddMul argument
# CSC times CSC into CSC: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments. Restricted to `T <: BlasFloat`.
LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSC{T}, tA, tB, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T}, _add::MulAddMul) where {T <: BlasFloat} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
# CSR times CSR into CSR: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments. Restricted to `T <: BlasFloat`.
LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSR{T}, tA, tB, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: BlasFloat} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
# COO times COO into COO: unpack `_add` and forward to the method that takes
# `alpha` and `beta` as separate trailing arguments. Restricted to `T <: BlasFloat`.
LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: BlasFloat} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

# Sparse times sparse product where `C`, `A` and `B` are all CSC, with explicit
# `alpha` and `beta`, for `T <: BlasFloat`. Both operation tags are normalised:
# 'S', 's', 'H' and 'h' become 'N', anything else is kept. The computation is
# delegated to `gemm!`; the trailing 'O' is passed through unchanged and its
# meaning is defined by that callee (not visible here).
function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSC{T}, tA, tB, A::CuSparseMatrixCSC{T}, B::CuSparseMatrixCSC{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
gemm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O')
gemm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSR{T}, tA, tB, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T}, _add::MulAddMul) where {T <: BlasFloat}
# Sparse times sparse product where `C`, `A` and `B` are all CSR, with explicit
# `alpha` and `beta`, for `T <: BlasFloat`. Both operation tags are normalised:
# 'S', 's', 'H' and 'h' become 'N', anything else is kept. The computation is
# delegated to `gemm!`; the trailing 'O' is passed through unchanged and its
# meaning is defined by that callee (not visible here).
function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCSR{T}, tA, tB, A::CuSparseMatrixCSR{T}, B::CuSparseMatrixCSR{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
# NOTE(review): `_add` is not a parameter of the signature above; this line looks
# like the removed side of the diff hunk, replaced by the line after it.
gemm!(tA, tB, _add.alpha, A, B, _add.beta, C, 'O')
gemm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, _add::MulAddMul) where {T <: BlasFloat}
# Sparse times sparse product where `C`, `A` and `B` are all COO, with explicit
# `alpha` and `beta`, for `T <: BlasFloat`. Both operation tags are normalised
# ('S', 's', 'H', 'h' become 'N'), then all three matrices are converted with
# `CuSparseMatrixCSR(...)`, the product is computed on the converted values, and
# the result is converted back with `CuSparseMatrixCOO(C_csr)`. That last
# expression is also the value the function returns.
function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
A_csr = CuSparseMatrixCSR(A)
B_csr = CuSparseMatrixCSR(B)
C_csr = CuSparseMatrixCSR(C)
# NOTE(review): the two calls below are written without the `LinearAlgebra.`
# prefix used in the signature; whether the bare name is in scope is not visible
# in this excerpt -- TODO confirm. The first call reads `_add`, which is not a
# parameter of the signature above; it looks like the removed side of the diff
# hunk, replaced by the call after it.
generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, _add.alpha, _add.beta)
generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, alpha, beta)
# NOTE(review): this assignment rebinds the local name `C` to a new value; the
# statement by itself does not write into the object the caller passed in. Whether
# the caller's matrix ends up updated some other way is not established here, so
# the question in the trailing comment still stands -- TODO confirm.
C = CuSparseMatrixCOO(C_csr) # is this in-place of the original C?
end

Expand Down

0 comments on commit cbbf19a

Please sign in to comment.