Skip to content

Commit

Permalink
Construct MulAddMul at gemm_wrapper! call sites (#34601)
Browse files Browse the repository at this point in the history
* Construct MulAddMul at gemm_wrapper! call sites

* Add branches manually in MulAddMul constructor

This is suggested by chethega in:
#29634 (comment)

* Update stdlib/LinearAlgebra/src/generic.jl

Co-Authored-By: Kristoffer Carlsson <[email protected]>

Co-authored-by: Kristoffer Carlsson <[email protected]>
  • Loading branch information
tkf and KristofferC committed Feb 1, 2020
1 parent ab9410d commit 2da42e0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
19 changes: 16 additions & 3 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,23 @@ struct MulAddMul{ais1, bis0, TA, TB}
beta::TB
end

MulAddMul(alpha::TA, beta::TB) where {TA, TB} =
MulAddMul{isone(alpha), iszero(beta), TA, TB}(alpha, beta)
# Build a `MulAddMul` whose first two type parameters record, as Bool flags,
# whether `alpha` is one and whether `beta` is zero.
#
# Each of the four flag combinations has its own construction site with the
# Bool parameters spelled out as literals, rather than splicing the results
# of `isone`/`iszero` into the type.
# NOTE(review): presumably this keeps every return site concretely typed so
# the compiler can branch over the four cases -- confirm against the
# discussion that motivated the manual branches.
@inline function MulAddMul(alpha::TA, beta::TB) where {TA,TB}
    alpha_is_one = isone(alpha)
    beta_is_zero = iszero(beta)
    if alpha_is_one && beta_is_zero
        return MulAddMul{true,true,TA,TB}(alpha, beta)
    elseif alpha_is_one
        return MulAddMul{true,false,TA,TB}(alpha, beta)
    elseif beta_is_zero
        return MulAddMul{false,true,TA,TB}(alpha, beta)
    else
        return MulAddMul{false,false,TA,TB}(alpha, beta)
    end
end

MulAddMul() = MulAddMul(true, false)
# Zero-argument default: alpha = true, beta = false. The flags are written
# out by hand; `{true,true}` is consistent with the values because
# `isone(true)` and `iszero(false)` both hold.
MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)

# One-argument call, first flag `true`: return `x` unchanged; the stored
# `alpha` is not read.
@inline function (::MulAddMul{true})(x)
    return x
end

# One-argument call, first flag `false`: scale `x` by the stored `alpha`.
# `alpha` is the right-hand operand, which matters for non-commutative
# element types.
@inline function (scale::MulAddMul{false})(x)
    return x * scale.alpha
end
Expand Down
34 changes: 17 additions & 17 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ end

@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasFloat}
return gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta)
return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
end
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
# first matrix as a real matrix and carry out real matrix matrix multiply
Expand Down Expand Up @@ -307,7 +307,7 @@ lmul!(A, B)
if A===B
return syrk_wrapper!(C, 'T', A, alpha, beta)
else
return gemm_wrapper!(C, 'T', 'N', A, B, alpha, beta)
return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
Expand All @@ -322,7 +322,7 @@ end
if A===B
return syrk_wrapper!(C, 'N', A, alpha, beta)
else
return gemm_wrapper!(C, 'N', 'T', A, B, alpha, beta)
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
end
end
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
Expand All @@ -349,7 +349,7 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat}
A = transA.parent
B = transB.parent
return gemm_wrapper!(C, 'T', 'T', A, B, alpha, beta)
return gemm_wrapper!(C, 'T', 'T', A, B, MulAddMul(alpha, beta))
end
@inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number)
Expand All @@ -362,7 +362,7 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat}
A = transA.parent
B = transB.parent
return gemm_wrapper!(C, 'T', 'C', A, B, alpha, beta)
return gemm_wrapper!(C, 'T', 'C', A, B, MulAddMul(alpha, beta))
end
@inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number)
Expand All @@ -382,7 +382,7 @@ end
if A===B
return herk_wrapper!(C, 'C', A, alpha, beta)
else
return gemm_wrapper!(C, 'C', 'N', A, B, alpha, beta)
return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
Expand All @@ -402,7 +402,7 @@ end
if A === B
return herk_wrapper!(C, 'N', A, alpha, beta)
else
return gemm_wrapper!(C, 'N', 'C', A, B, alpha, beta)
return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
Expand All @@ -415,7 +415,7 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat}
A = adjA.parent
B = adjB.parent
return gemm_wrapper!(C, 'C', 'C', A, B, alpha, beta)
return gemm_wrapper!(C, 'C', 'C', A, B, MulAddMul(alpha, beta))
end
@inline function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number)
Expand Down Expand Up @@ -508,7 +508,7 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
end
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
return gemm_wrapper!(C, tA, tAt, A, A, MulAddMul(α, β))
end

function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
Expand Down Expand Up @@ -547,7 +547,7 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
end
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
return gemm_wrapper!(C, tA, tAt, A, A, MulAddMul(α, β))
end

function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
Expand All @@ -561,7 +561,7 @@ end

function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number=true, β::Number=false) where {T<:BlasFloat}
_add = MulAddMul()) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)

Expand All @@ -573,21 +573,21 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end

if mA == 0 || nA == 0 || nB == 0 || iszero(α)
if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
if size(C) != (mA, nB)
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
end
return _rmul_or_fill!(C, β)
return _rmul_or_fill!(C, _add.beta)
end

if mA == 2 && nA == 2 && nB == 2
return matmul2x2!(C, tA, tB, A, B, MulAddMul(α, β))
return matmul2x2!(C, tA, tB, A, B, _add)
end
if mA == 3 && nA == 3 && nB == 3
return matmul3x3!(C, tA, tB, A, B, MulAddMul(α, β))
return matmul3x3!(C, tA, tB, A, B, _add)
end

alpha, beta = promote(α, β, zero(T))
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
Expand All @@ -596,7 +596,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
stride(C, 2) >= size(C, 1))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
generic_matmatmul!(C, tA, tB, A, B, MulAddMul(α, β))
generic_matmatmul!(C, tA, tB, A, B, _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down

0 comments on commit 2da42e0

Please sign in to comment.