diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index bec89c21cda61..b90928c967c81 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -25,10 +25,23 @@ struct MulAddMul{ais1, bis0, TA, TB} beta::TB end -MulAddMul(alpha::TA, beta::TB) where {TA, TB} = - MulAddMul{isone(alpha), iszero(beta), TA, TB}(alpha, beta) +@inline function MulAddMul(alpha::TA, beta::TB) where {TA,TB} + if isone(alpha) + if iszero(beta) + return MulAddMul{true,true,TA,TB}(alpha, beta) + else + return MulAddMul{true,false,TA,TB}(alpha, beta) + end + else + if iszero(beta) + return MulAddMul{false,true,TA,TB}(alpha, beta) + else + return MulAddMul{false,false,TA,TB}(alpha, beta) + end + end +end -MulAddMul() = MulAddMul(true, false) +MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false) @inline (::MulAddMul{true})(x) = x @inline (p::MulAddMul{false})(x) = x * p.alpha diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 1bbd2230c315d..e5c33388a9af0 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -166,7 +166,7 @@ end @inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, alpha::Number, beta::Number) where {T<:BlasFloat} - return gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta) + return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) end # Complex Matrix times real matrix: We use that it is generally faster to reinterpret the # first matrix as a real matrix and carry out real matrix matrix multiply @@ -307,7 +307,7 @@ lmul!(A, B) if A===B return syrk_wrapper!(C, 'T', A, alpha, beta) else - return gemm_wrapper!(C, 'T', 'N', A, B, alpha, beta) + return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta)) end end @inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat, @@ -322,7 +322,7 @@ end if A===B return syrk_wrapper!(C, 'N', A, alpha, beta) else - return gemm_wrapper!(C, 'N', 'T', A, B, alpha, beta) + return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta)) end end # Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency. @@ -349,7 +349,7 @@ end alpha::Number, beta::Number) where {T<:BlasFloat} A = transA.parent B = transB.parent - return gemm_wrapper!(C, 'T', 'T', A, B, alpha, beta) + return gemm_wrapper!(C, 'T', 'T', A, B, MulAddMul(alpha, beta)) end @inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) @@ -362,7 +362,7 @@ end alpha::Number, beta::Number) where {T<:BlasFloat} A = transA.parent B = transB.parent - return gemm_wrapper!(C, 'T', 'C', A, B, alpha, beta) + return gemm_wrapper!(C, 'T', 'C', A, B, MulAddMul(alpha, beta)) end @inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) @@ -382,7 +382,7 @@ end if A===B return herk_wrapper!(C, 'C', A, alpha, beta) else - return gemm_wrapper!(C, 'C', 'N', A, B, alpha, beta) + return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta)) end end @inline function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat, @@ -402,7 +402,7 @@ end if A === B return herk_wrapper!(C, 'N', A, alpha, beta) else - return gemm_wrapper!(C, 'N', 'C', A, B, alpha, beta) + return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta)) end end @inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat}, @@ -415,7 +415,7 @@ end alpha::Number, beta::Number) where {T<:BlasFloat} A = adjA.parent B = adjB.parent - return gemm_wrapper!(C, 'C', 'C', A, B, alpha, beta) + return gemm_wrapper!(C, 'C', 'C', A, B, MulAddMul(alpha, beta)) end @inline function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) @@ -508,7 +508,7 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') end end - return gemm_wrapper!(C, tA, tAt, A, A, α, β) + return gemm_wrapper!(C, tA, tAt, A, A, MulAddMul(α, β)) end function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, @@ -547,7 +547,7 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) end end - return gemm_wrapper!(C, tA, tAt, A, A, α, β) + return gemm_wrapper!(C, tA, tAt, A, A, MulAddMul(α, β)) end function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, @@ -561,7 +561,7 @@ end function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number=true, β::Number=false) where {T<:BlasFloat} + _add = MulAddMul()) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) @@ -573,21 +573,21 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar throw(ArgumentError("output matrix must not be aliased with input matrix")) end - if mA == 0 || nA == 0 || nB == 0 || iszero(α) + if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) if size(C) != (mA, nB) throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)")) end - return _rmul_or_fill!(C, β) + return _rmul_or_fill!(C, _add.beta) end if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, MulAddMul(α, β)) + return matmul2x2!(C, tA, tB, A, B, _add) end if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, MulAddMul(α, β)) + return matmul3x3!(C, tA, tB, A, B, _add) end - alpha, beta = promote(α, β, zero(T)) + alpha, beta = promote(_add.alpha, _add.beta, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && @@ -596,7 +596,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar stride(C, 2) >= size(C, 1)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - generic_matmatmul!(C, tA, tB, A, B, MulAddMul(α, β)) + generic_matmatmul!(C, tA, tB, A, B, _add) end # blas.jl defines matmul for floats; other integer and mixed precision