Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multiply-add interface in LinearAlgebra #29634

Merged
merged 46 commits into from
Aug 14, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
9cef41a
Multiply-add interface for BLAS.gemm!
tkf Oct 13, 2018
4013d8d
Multiply-add interface for BLAS.syrk!
tkf Oct 13, 2018
93aa9b6
Multiply-add interface for BLAS.herk!
tkf Oct 13, 2018
95deaaf
Multiply-add interface for gemv!
tkf Oct 13, 2018
302396b
Fix UndefRefError from C[i,j]
tkf Oct 13, 2018
3b72e3b
Do not assume *(::Bool, ::eltype(C)) exists
tkf Oct 13, 2018
a27f5f5
Implement mul! in terms of addmul!
tkf Nov 18, 2018
04333aa
Test multiply-add interface
tkf Nov 18, 2018
ae38931
Document addmul!
tkf Nov 19, 2018
87b41b8
Use lmul! for beta * C; eltype may not be commutative
tkf Nov 19, 2018
8f41412
Add _lmul_or_fill!
tkf Nov 19, 2018
b97af31
Add multiply-add interface for symmetric matrices
tkf Nov 19, 2018
cd169b3
Add multiply-add interface for Number and UniformScaling
tkf Nov 19, 2018
a655ff8
Add multiply-add interface for diagonal matrices
tkf Nov 19, 2018
84f009b
Add multiply-add interface for bi- and tri-diagonal matrices
tkf Nov 19, 2018
b0ab7b2
Add multiply-add interface for triangular matrices
tkf Nov 19, 2018
92a7e86
Test multiply-add interface in test/generic.jl
tkf Nov 19, 2018
00cfc91
Fix addmul!(C, s::Number, X, alpha, beta)
tkf Nov 19, 2018
ed6821f
Special-case alpha=1 beta=0 using type parameter
tkf Nov 20, 2018
602fb7b
Test multiply-add interface in test/uniformscaling.jl
tkf Nov 20, 2018
c33767b
Test multiply-add interface in test/diagonal.jl
tkf Nov 20, 2018
1309cc7
Use addmul! in SparseArrays
tkf Nov 20, 2018
4d61edf
Systematically test addmul!
tkf Nov 20, 2018
d88a182
Make MulAddMul benchmark-friendly
tkf Nov 20, 2018
8957d0b
Fix _modify! docstring
tkf Nov 20, 2018
66e04c7
Use MulAddMul in A_mul_B_td!
tkf Nov 20, 2018
d0a7672
Comment out broken test_broken
tkf Nov 20, 2018
96c5550
Relax rtol based on eltype of matrices A, B, C
tkf Nov 20, 2018
57a6077
Pass around MulAddMul instead of alpha and beta for type stability
tkf Nov 21, 2018
46d6614
Inline functions between *(::Matrix, ::Matrix) and gemm_wrapper!
tkf Nov 21, 2018
b0642c5
Annotate argument type MulAddMul
tkf Nov 21, 2018
eb1e87f
Inline all addmul!
tkf Nov 21, 2018
ce7b862
Construct MulAddMul outside A_mul_B_td!
tkf Nov 21, 2018
0a09ec1
Add multiply-add interface in test/tridiag.jl
tkf Nov 21, 2018
474c9f7
Mention combined multiply-add in NEWS.md [ci skip]
tkf Nov 22, 2018
82c5749
Mention that mul!(C, A, B, α, β) is deprecated [ci skip]
tkf Nov 22, 2018
29683db
Change API definition to C = ABα + Cβ
tkf Nov 30, 2018
07058ba
Fix indentation
tkf Dec 1, 2018
fae1a7a
Merge branch 'master' into matmuladd
tkf Dec 1, 2018
84b3be7
Merge branch 'master' into matmuladd
tkf Aug 13, 2019
49fed2a
Fix triangular.jl
tkf Aug 13, 2019
32853b9
Fix UndefRefError in multiplication with Diagonal
tkf Aug 13, 2019
ae925c2
Define Base.convert(::Type{Quaternion{T}}, s::Real)
tkf Aug 13, 2019
7be001f
Rename: addmul! -> mul!
tkf Aug 13, 2019
22121d0
Fix doctest
tkf Aug 13, 2019
c7e3efd
Workaround broadcast error with e.g., triangular matrix of BigFloat
tkf Aug 13, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Multiply-add interface for BLAS.syrk!
  • Loading branch information
tkf committed Nov 18, 2018
commit 4013d8d92aaea5369cefcd032cb5ce8715f7a0e7
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1305,8 +1305,8 @@ for (fname, elty) in ((:dsyrk_,:Float64),
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
function syrk!(uplo::AbstractChar, trans::AbstractChar,
alpha::($elty), A::AbstractVecOrMat{$elty},
beta::($elty), C::AbstractMatrix{$elty})
alpha::Union{($elty), Bool}, A::AbstractVecOrMat{$elty},
beta::Union{($elty), Bool}, C::AbstractMatrix{$elty})
@assert !has_offset_axes(A, C)
n = checksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
Expand Down
42 changes: 26 additions & 16 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,34 +221,38 @@ julia> B
"""
lmul!(A, B)

function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasFloat}
function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
A = transA.parent
if A===B
return syrk_wrapper!(C, 'T', A)
return syrk_wrapper!(C, 'T', A, alpha, beta)
else
return gemm_wrapper!(C, 'T', 'N', A, B)
return gemm_wrapper!(C, 'T', 'N', A, B, alpha, beta)
end
end
function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat)
function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false)
A = transA.parent
return generic_matmatmul!(C, 'T', 'N', A, B)
return generic_matmatmul!(C, 'T', 'N', A, B, alpha, beta)
end

function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat}
function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
B = transB.parent
if A===B
return syrk_wrapper!(C, 'N', A)
return syrk_wrapper!(C, 'N', A, alpha, beta)
else
return gemm_wrapper!(C, 'N', 'T', A, B)
return gemm_wrapper!(C, 'N', 'T', A, B, alpha, beta)
end
end
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
for elty in (Float32,Float64)
@eval begin
function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, transB::Transpose{<:Any,<:StridedVecOrMat{$elty}})
function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, transB::Transpose{<:Any,<:StridedVecOrMat{$elty}},
alpha::Union{$elty, Bool} = true, beta::Union{$elty, Bool} = false)
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
mul!(Cfl,Afl,transB)
mul!(Cfl, Afl, transB, alpha, beta)
return C
end
end
Expand Down Expand Up @@ -379,7 +383,9 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
return generic_matvecmul!(y, tA, A, x)
end

function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}) where T<:BlasFloat
function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
alpha::Union{Bool, T} = true,
beta::Union{Bool, T} = false) where T<:BlasFloat
nC = checksquare(C)
if tA == 'T'
(nA, mA) = size(A,1), size(A,2)
Expand All @@ -392,19 +398,23 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
throw(DimensionMismatch("output matrix has size: $(nC), but should have size $(mA)"))
end
if mA == 0 || nA == 0
return fill!(C,0)
if iszero(beta)
return fill!(C, 0)
else
return rmul!(C, beta)
end
end
if mA == 2 && nA == 2
return matmul2x2!(C,tA,tAt,A,A)
return matmul2x2!(C, tA, tAt, A, A, alpha, beta)
end
if mA == 3 && nA == 3
return matmul3x3!(C,tA,tAt,A,A)
return matmul3x3!(C, tA, tAt, A, A, alpha, beta)
end

if stride(A, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(C, 2) >= size(C, 1)
return copytri!(BLAS.syrk!('U', tA, one(T), A, zero(T), C), 'U')
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
end
return generic_matmatmul!(C, tA, tAt, A, A)
return generic_matmatmul!(C, tA, tAt, A, A, alpha, beta)
end

function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}) where T<:BlasReal
Expand Down