Skip to content

Commit

Permalink
Merge pull request #450 from JuliaGPU/tb/gemv
Browse files Browse the repository at this point in the history
Allow use of strided vectors with mul! (gemv! and gemm!)
  • Loading branch information
maleadt authored Sep 28, 2020
2 parents b529985 + 06466f6 commit 75f7d30
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 57 deletions.
100 changes: 58 additions & 42 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# interfacing with LinearAlgebra standard library

cublas_size(t::Char, M::CuVecOrMat) = (size(M, t=='N' ? 1 : 2), size(M, t=='N' ? 2 : 1))



#
# BLAS 1
Expand Down Expand Up @@ -74,40 +71,52 @@ end

# GEMV

function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
alpha::Number = true, beta::Number = false) where T<:CublasFloat
mA, nA = cublas_size(tA, A)
if nA != length(x)
throw(DimensionMismatch("second dimension of A, $nA, does not match length of x, $(length(x))"))
function gemv_dispatch!(Y::CuVector, A, B, alpha::Number=true, beta::Number=false)
mA, nA = size(A)

if nA != length(B)
throw(DimensionMismatch("second dimension of A, $nA, does not match length of B, $(length(B))"))
end
if mA != length(y)
throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))"))

if mA != length(Y)
throw(DimensionMismatch("first dimension of A, $mA, does not match length of Y, $(length(Y))"))
end

if mA == 0
return y
return Y
end

if nA == 0
return rmul!(y, 0)
return rmul!(Y, 0)
end

tA, dA = if A isa Transpose
'T', parent(A)
elseif A isa Adjoint
'C', parent(A)
else
'N', A
end

T = eltype(Y)
if T <: CublasFloat && A isa StridedCuArray{T} && B isa StridedCuArray{T}
gemv!(tA, alpha, dA, B, beta, Y)
else
gemm_dispatch!(Y, A, B, alpha, beta)
end
gemv!(tA, alpha, A, x, beta, y)
end

LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasReal =
gemv_wrapper!(Y, 'T', A.parent, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, a, b)

# ambiguity hacks: Base and GPUArrays has mul! with a::Real, b::Real
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Real, b::Real) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, a, b)
for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
LinearAlgebra.mul!(Y::CuVector, A::StridedCuMatrix, B::StridedCuVector, a::$NT, b::$NT) =
gemv_dispatch!(Y, A, B, a, b)
LinearAlgebra.mul!(Y::CuVector, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::StridedCuVector, a::$NT, b::$NT) =
gemv_dispatch!(Y, A, B, a, b)
LinearAlgebra.mul!(Y::CuVector, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::StridedCuVector, a::$NT, b::$NT) =
gemv_dispatch!(Y, A, B, a, b)
end
end

# TRSV

Expand Down Expand Up @@ -162,8 +171,13 @@ end
# GEMM

function gemm_dispatch!(C::CuVecOrMat, A, B, alpha::Number=true, beta::Number=false)
mA, nA = size(A)
mB, nB = size(B)
if ndims(A) > 2
throw(ArgumentError("A has more than 2 dimensions"))
elseif ndims(B) > 2
throw(ArgumentError("B has more than 2 dimensions"))
end
mA, nA = size(A,1), size(A,2)
mB, nB = size(B,1), size(B,2)

if nA != mB
throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
Expand Down Expand Up @@ -196,9 +210,11 @@ function gemm_dispatch!(C::CuVecOrMat, A, B, alpha::Number=true, beta::Number=fa
'N', B
end

if gemmExComputeType(eltype(A), eltype(B), eltype(C), mA, nA, nB) !== nothing
T = eltype(C)
if dA isa DenseCuArray && dB isa DenseCuArray &&
gemmExComputeType(eltype(A), eltype(B), eltype(C), mA, nA, nB) !== nothing
gemmEx!(tA, tB, alpha, dA, dB, beta, C)
elseif eltype(A) === eltype(B) === eltype(C) && eltype(A) <: CublasFloat
elseif T <: CublasFloat && dA isa DenseCuArray{T} && dB isa DenseCuArray{T}
gemm!(tA, tB, alpha, dA, dB, beta, C)
else
GPUArrays.generic_matmatmul!(C, A, B, alpha, beta)
Expand All @@ -208,26 +224,26 @@ end
for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
LinearAlgebra.mul!(C::CuMatrix, A::CuVecOrMat, B::CuVecOrMat, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::StridedCuVecOrMat, B::StridedCuVecOrMat, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)

LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::CuMatrix, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::StridedCuMatrix, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::CuMatrix, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::StridedCuMatrix, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)

LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::CuMatrix, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::StridedCuMatrix, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::CuMatrix, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::StridedCuMatrix, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)

LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
end
end
Expand Down
30 changes: 15 additions & 15 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64),
@eval begin
function gemv!(trans::Char,
alpha::Number,
A::CuMatrix{$elty},
X::CuVector{$elty},
A::StridedCuMatrix{$elty},
X::StridedCuVector{$elty},
beta::Number,
Y::CuVector{$elty})
Y::DenseCuVector{$elty})
# handle trans
m,n = size(A)
# check dimensions
Expand All @@ -296,10 +296,10 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64),
$fname(handle(), trans, m, n, alpha, A, lda, X, incx, beta, Y, incy)
Y
end
function gemv(trans::Char, alpha::Number, A::CuMatrix{$elty}, X::CuVector{$elty})
function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{$elty}, X::StridedCuVector{$elty})
gemv!(trans, alpha, A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::Char, A::CuMatrix{$elty}, X::CuVector{$elty})
function gemv(trans::Char, A::StridedCuMatrix{$elty}, X::StridedCuVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
Expand Down Expand Up @@ -708,10 +708,10 @@ for (fname, elty) in
function gemm!(transA::Char,
transB::Char,
alpha::Number,
A::CuVecOrMat{$elty},
B::CuVecOrMat{$elty},
A::DenseCuVecOrMat{$elty},
B::DenseCuVecOrMat{$elty},
beta::Number,
C::CuVecOrMat{$elty})
C::DenseCuVecOrMat{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
Expand All @@ -727,16 +727,16 @@ for (fname, elty) in
function gemm(transA::Char,
transB::Char,
alpha::Number,
A::CuMatrix{$elty},
B::CuMatrix{$elty})
A::DenseCuMatrix{$elty},
B::DenseCuMatrix{$elty})
gemm!(transA, transB, alpha, A, B, zero($elty),
similar(B, $elty, (size(A, transA == 'N' ? 1 : 2),
size(B, transB == 'N' ? 2 : 1))))
end
function gemm(transA::Char,
transB::Char,
A::CuMatrix{$elty},
B::CuMatrix{$elty})
A::DenseCuMatrix{$elty},
B::DenseCuMatrix{$elty})
gemm(transA, transB, one($elty), A, B)
end
end
Expand Down Expand Up @@ -810,10 +810,10 @@ end

function gemmEx!(transA::Char, transB::Char,
@nospecialize(alpha::Number),
@nospecialize(A::CuVecOrMat),
@nospecialize(B::CuVecOrMat),
@nospecialize(A::DenseCuVecOrMat),
@nospecialize(B::DenseCuVecOrMat),
@nospecialize(beta::Number),
@nospecialize(C::CuVecOrMat);
@nospecialize(C::DenseCuVecOrMat);
algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
Expand Down
5 changes: 5 additions & 0 deletions src/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{
Base._memory_offset(V.parent, map(first, V.indices)...)
end

# from reshaped subarrays
function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{Base.RangeIndex,Base.ReshapedUnitRange}}}}) where {T,N,P}
return Base. unsafe_convert(CuPtr{T}, parent(V)) +
(Base.first_index(V)-1)*sizeof(T)
end

## limited pointer arithmetic & comparison

Expand Down
16 changes: 16 additions & 0 deletions test/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ end
dA = CuArray(A)
@test_throws DimensionMismatch mul!(dy, dA, dx)
end

@testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty)
y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5)
dy, dA, dx = CuArray(y), CuArray(A), CuArray(x)
Expand Down Expand Up @@ -419,6 +420,13 @@ end
end
end
end

@testset "gemv! with strided inputs" begin # JuliaGPU/CUDA.jl#445
testf(rand(16), rand(4)) do p, b
W = @view p[reshape(1:(16),4,4)]
W*b
end
end
end

############################################################################################
Expand Down Expand Up @@ -1284,6 +1292,14 @@ end
@test C Array(dC) rtol=rtol
end
end

@testset "gemm! with strided inputs" begin # JuliaGPU/CUDA.jl#78
inn = 784; out = 32
testf(randn(784*100), rand(Float32, 784, 100)) do p, x
p[reshape(1:(out*inn),out,inn)] * x
@view(p[reshape(1:(out*inn),out,inn)]) * x
end
end
end

############################################################################################
Expand Down

0 comments on commit 75f7d30

Please sign in to comment.