Skip to content

Commit

Permalink
Merge JuliaGPU#180
Browse files Browse the repository at this point in the history
180: Fix sparse-dense matmul, with transposed dense r=maleadt a=irhum

Continuation of the pull request in JuliaGPU/CuArrays.jl#728 

As previously mentioned, the patch makes sure the correct method is dispatched:
```julia
# before the patch
densematrix_T = transpose(densematrix)
@which mul!(similar(densematrix_T, Float32, (size(sparsecscmatrix,1), size(densematrix_T,2))), sparsecscmatrix, densematrix_T)
# mul!(C, A, B) at stdlib/v1.4/LinearAlgebra/src/matmul.jl:208

# after the proposed patch
@which mul!(similar(densematrix_T, Float32, (size(sparsecscmatrix,1), size(densematrix_T,2))), sparsecscmatrix, densematrix_T)
# mul!(C::CuArray{T,2,P} where P, A::Union{CuArrays.CUSPARSE.CuSparseMatrixBSR{T}, CuArrays.CUSPARSE.CuSparseMatrixCSC{T}, CuArrays.CUSPARSE.CuSparseMatrixCSR{T}, CuArrays.CUSPARSE.CuSparseMatrixHYB{T}}, transB::Transpose{#s243,#s242} where #s242<:(CuArray{T,2,P} where P) where #s243) where T at CuArrays/l0gXB/src/sparse/interfaces.jl:20
```

Tests have also been added for the sparse × transpose(dense) case.

Co-authored-by: Irhum Shafkat <[email protected]>
  • Loading branch information
bors[bot] and irhum committed May 27, 2020
2 parents f265c94 + 54c012a commit 3af3e31
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSpa
# Adjoint of a Hermitian/Symmetric-wrapped sparse matrix times a dense vector:
# unwraps the Adjoint with `parent` and forwards to `mv!` with the 'C' flag,
# alpha = one(T) and beta = zero(T).
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')

# Sparse matrix times dense matrix, forwarded to `mm2!`. The first two character
# flags are chosen per operand: 'N' for a plain argument, 'T' when the argument
# arrives wrapped in `Transpose`, 'C' when wrapped in `Adjoint`; the wrapper is
# stripped with `parent` before the call. alpha = one(T), beta = zero(T).
# NOTE(review): the meaning of the trailing 'O' argument is not visible here —
# confirm against the `mm2!` definition.
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::CuMatrix{T}) where {T} = mm2!('N','N',one(T),A,B,zero(T),C,'O')
# NOTE(review): presumably the pre-patch signature (the diff markers are not
# visible). `Transpose{<:Any, CuMatrix{T}}` demands that the parent type
# parameter be exactly `CuMatrix{T}`; per the `@which` output in the commit
# message this method was not selected for `transpose(::CuArray)`.
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, CuMatrix{T}}) where {T} = mm2!('N','T',one(T),A,parent(transB),zero(T),C,'O')
# Same method with `<:CuMatrix{T}`, so any parent type that is a subtype of
# `CuMatrix{T}` matches (the form reported by `@which` after the patch).
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:CuMatrix{T}}) where {T} = mm2!('N','T',one(T),A,parent(transB),zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T}) where {T} = mm2!('T','N',one(T),parent(transA),B,zero(T),C,'O')
# NOTE(review): presumably the pre-patch signature for the transposed-sparse ×
# transposed-dense case; same exact-parameter issue as described above.
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, CuMatrix{T}}) where {T} = mm2!('T','T',one(T),parent(transA),parent(transB),zero(T),C,'O')
# Same method with the dense parent constrained by `<:CuMatrix{T}`.
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:CuMatrix{T}}) where {T} = mm2!('T','T',one(T),parent(transA),parent(transB),zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T}) where {T} = mm2!('C','N',one(T),parent(adjA),B,zero(T),C,'O')

# Hermitian/Symmetric-wrapped sparse matrix times dense matrix: forwarded to
# `mm!` (not `mm2!`) with the 'N' flag; `T` is taken from the output `C` only,
# since `A` and `B` are not constrained to share it in this signature.
LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::CuMatrix) where {T} = mm!('N',one(T),A,B,zero(T),C,'O')
Expand Down
10 changes: 10 additions & 0 deletions test/cusparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ end
A = sparse(rand(elty,m,k))
B = rand(elty,k,n)
C = rand(elty,m,n)
Bᵀ = collect(transpose(B))
alpha = rand(elty)
beta = rand(elty)
@testset "csr" begin
Expand All @@ -1744,6 +1745,12 @@ end
h_C = collect(d_C)
D = A * B
@test D ≈ h_C
d_Bᵀ = CuArray(Bᵀ)
d_C = CuArray(C)
mul!(d_C, d_A, transpose(d_Bᵀ))
h_C = collect(d_C)
D = A * transpose(Bᵀ)
@test D ≈ h_C
end
@testset "csc" begin
d_B = CuArray(B)
Expand All @@ -1758,6 +1765,9 @@ end
h_C = collect(d_C)
D = A * B
@test D ≈ h_C
d_Bᵀ = CuArray(Bᵀ)
d_C = CuArray(C)
@test_throws CUSPARSEError mul!(d_C, d_A, transpose(d_Bᵀ))
end
end
end
Expand Down

0 comments on commit 3af3e31

Please sign in to comment.