Skip to content

Commit

Permalink
Reduce compile time for generic matmatmul (JuliaLang#52038)
Browse files Browse the repository at this point in the history
This is another attempt at improving the compile time issue with generic
matmatmul, hopefully improving runtime performance also.

@chriselrod @jishnub

There seems to be a little typo/oversight somewhere, but it shows how it
could work. Locally, this reduces benchmark times from
JuliaLang#51812 (comment) by
more than 50%.

---------

Co-authored-by: Chris Elrod <[email protected]>
  • Loading branch information
dkarrasch and chriselrod committed Nov 14, 2023
1 parent 4bc45a7 commit 0cf2bf1
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 211 deletions.
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ adjoint(A::Adjoint) = A.parent
transpose(A::Transpose) = A.parent
adjoint(A::Transpose{<:Real}) = A.parent
transpose(A::Adjoint{<:Real}) = A.parent
adjoint(A::Transpose{<:Any,<:Adjoint}) = transpose(A.parent.parent)
transpose(A::Adjoint{<:Any,<:Transpose}) = adjoint(A.parent.parent)
# disambiguation
adjoint(A::Transpose{<:Real,<:Adjoint}) = transpose(A.parent.parent)
transpose(A::Adjoint{<:Real,<:Transpose}) = A.parent

# printing
function Base.showarg(io::IO, v::Adjoint, toplevel)
Expand Down Expand Up @@ -395,11 +400,16 @@ map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)...
map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...))
quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
quasiparentc(x) = parent(parent(x)); quasiparentc(x::Number) = conj(x) # to handle numbers in the defs below
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# Hack to preserve behavior after #32122; this needs to be done with a broadcast style instead to support dotted fusion
Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Transpose{<:Any,<:AdjointAbsVec}}...) =
transpose(adjoint(broadcast((xs...) -> adjoint(transpose(f(conj.(xs)...))), quasiparentc.(tvs)...)))
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Adjoint{<:Any,<:TransposeAbsVec}}...) =
adjoint(transpose(broadcast((xs...) -> transpose(adjoint(f(conj.(xs)...))), quasiparentc.(tvs)...)))
# TODO unify and allow mixed combinations with a broadcast style


Expand Down
256 changes: 62 additions & 194 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ julia> lmul!(F.Q, B)
lmul!(A, B)

# THE one big BLAS dispatch
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if all(in(('N', 'T', 'C')), (tA, tB))
if tA == 'T' && tB == 'N' && A === B
Expand All @@ -364,16 +364,16 @@ lmul!(A, B)
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
end
end
return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
if all(in(('N', 'T', 'C')), (tA, tB))
gemm_wrapper!(C, tA, tB, A, B, _add)
else
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
end

Expand Down Expand Up @@ -563,11 +563,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
if all(in(('N', 'T', 'C')), (tA, tB))
gemm_wrapper!(C, tA, tB, A, B)
else
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
end

function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -604,10 +604,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
stride(C, 2) >= size(C, 1))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
_generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -647,7 +647,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
_generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down Expand Up @@ -764,197 +764,65 @@ end

const tilebufsize = 10800 # Approximately 32k/3

function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
mC, nC = size(C)

if iszero(_add.alpha)
return _rmul_or_fill!(C, _add.beta)
end
if mA == nA == mB == nB == mC == nC == 2
return matmul2x2!(C, tA, tB, A, B, _add)
end
if mA == nA == mB == nB == mC == nC == 3
return matmul3x3!(C, tA, tB, A, B, _add)
end
A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
_generic_matmatmul!(C, tA, tB, A, B, _add)
end
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)

function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
_add::MulAddMul) where {T,S,R}
@assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
require_one_based_indexing(C, A, B)

mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
end
if size(C,1) != mA || size(C,2) != nB
throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))
end

if iszero(_add.alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, _add.beta)
end

tile_size = 0
if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N')
tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))
end
@inbounds begin
if tile_size > 0
sz = (tile_size, tile_size)
Atile = Array{T}(undef, sz)
Btile = Array{S}(undef, sz)

z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
z = convert(promote_type(typeof(z1), R), z1)

if mA < tile_size && nA < tile_size && nB < tile_size
copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
for j = 1:nB
boff = (j-1)*tile_size
for i = 1:mA
aoff = (i-1)*tile_size
s = z
for k = 1:nA
s += Atile[aoff+k] * Btile[boff+k]
end
_modify!(_add, s, C, (i,j))
end
end
else
Ctile = Array{R}(undef, sz)
for jb = 1:tile_size:nB
jlim = min(jb+tile_size-1,nB)
jlen = jlim-jb+1
for ib = 1:tile_size:mA
ilim = min(ib+tile_size-1,mA)
ilen = ilim-ib+1
fill!(Ctile, z)
for kb = 1:tile_size:nA
klim = min(kb+tile_size-1,mB)
klen = klim-kb+1
copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
for j=1:jlen
bcoff = (j-1)*tile_size
for i = 1:ilen
aoff = (i-1)*tile_size
s = z
for k = 1:klen
s += Atile[aoff+k] * Btile[bcoff+k]
end
Ctile[bcoff+i] += s
end
end
end
if isone(_add.alpha) && iszero(_add.beta)
copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
else
C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
end
end
AxM = axes(A, 1)
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
BxK = axes(B, 1)
BxN = axes(B, 2)
CxM = axes(C, 1)
CxN = axes(C, 2)
if AxM != CxM
throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
end
if AxK != BxK
throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)"))
end
if BxN != CxN
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
end
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
@inbounds for n in BxN, k in BxK
Balpha = B[k,n]*_add.alpha
@simd for m in AxM
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
end
end
elseif isbitstype(R) && sizeof(R) 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
t = wrapperop(A)
pB = parent(B)
pA = parent(A)
tmp = similar(C, CxN)
ci = first(CxM)
ta = t(_add.alpha)
for i in AxM
mul!(tmp, pB, view(pA, :, i))
C[ci,:] .+= t.(ta .* tmp)
ci += 1
end
else
# Multiplication for non-plain-data uses the naive algorithm
if tA == 'N'
if tB == 'N'
for i = 1:mA, j = 1:nB
z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[i, k]*B[k, j]
end
_modify!(_add, Ctmp, C, (i,j))
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1]))
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[i, k] * transpose(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
else
for i = 1:mA, j = 1:nB
z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]')
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[i, k]*B[j, k]'
end
_modify!(_add, Ctmp, C, (i,j))
end
end
elseif tA == 'T'
if tB == 'N'
for i = 1:mA, j = 1:nB
z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += transpose(A[k, i]) * B[k, j]
end
_modify!(_add, Ctmp, C, (i,j))
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1]))
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += transpose(A[k, i]) * transpose(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
else
for i = 1:mA, j = 1:nB
z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]')
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += transpose(A[k, i]) * adjoint(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
end
else
if tB == 'N'
for i = 1:mA, j = 1:nB
z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[k, i]'B[k, j]
end
_modify!(_add, Ctmp, C, (i,j))
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1]))
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += adjoint(A[k, i]) * transpose(B[j, k])
end
_modify!(_add, Ctmp, C, (i,j))
end
else
for i = 1:mA, j = 1:nB
z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]')
Ctmp = convert(promote_type(R, typeof(z2)), z2)
for k = 1:nA
Ctmp += A[k, i]'B[j, k]'
end
_modify!(_add, Ctmp, C, (i,j))
end
if iszero(_add.alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, _add.beta)
end
a1 = first(AxK)
b1 = first(BxK)
@inbounds for i in AxM, j in BxN
z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
@simd for k in AxK
Ctmp = muladd(A[i, k], B[k, j], Ctmp)
end
_modify!(_add, Ctmp, C, (i,j))
end
end
end # @inbounds
C
return C
end


Expand All @@ -963,7 +831,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if !(size(A) == size(B) == size(C) == (2,2))
Expand Down Expand Up @@ -1030,7 +898,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if !(size(A) == size(B) == size(C) == (3,3))
Expand Down
Loading

0 comments on commit 0cf2bf1

Please sign in to comment.