Skip to content

Commit

Permalink
Add 3-arg * methods (JuliaLang#37898)
Browse files Browse the repository at this point in the history
This addresses the simplest part of JuliaLang#12065 (optimizing * for optimal matrix order), by adding some methods for * with 3 arguments, where this can be done more efficiently than working left-to-right.

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
mcabbott and dkarrasch committed Jun 7, 2021
1 parent 708729b commit 51f5740
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 0 deletions.
138 changes: 138 additions & 0 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1081,3 +1081,141 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
end # inbounds
C
end

# Union of real and complex scalars. The methods below that take a *leading* scalar and
# hand it on as the *trailing* argument of `mat_vec_scalar`, `mat_mat_scalar` or
# `_tri_matmul` are restricted to this union (for both the scalar and the array eltypes).
const RealOrComplex = Union{Real,Complex}

# Three-argument *
"""
    *(A, B::AbstractMatrix, C)
    A * B * C * D

Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
based on the sizes of the arrays. That is, the number of scalar multiplications needed
for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
to choose which of these to execute.

If the last factor is a vector, or the first a transposed vector, then it is efficient
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
allocates an intermediate array.

If the first or last factor is a number, this will be fused with the matrix
multiplication, using 5-arg [`mul!`](@ref).

See also [`muladd`](@ref), [`dot`](@ref).

!!! compat "Julia 1.7"
    These optimisations require at least Julia 1.7.
"""
# matrix * matrix * vector: form the matrix-vector product `B*x` first.
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector) = A * (B*x)

# row vector * matrix * vector: multiply the row vector into `B` first ...
*(tu::AdjOrTransAbsVec, B::AbstractMatrix, v::AbstractVector) = (tu*B) * v
# ... unless `B` is itself an adjoint/transpose wrapper, in which case `B*v` is formed first.
*(tu::AdjOrTransAbsVec, B::AdjOrTransAbsMat, v::AbstractVector) = tu * (B*v)

# Two arrays and one scalar: delegate to `mat_vec_scalar` / `mat_mat_scalar`, which fuse
# the scaling into the multiplication where a fast path exists.
*(A::AbstractMatrix, x::AbstractVector, γ::Number) = mat_vec_scalar(A,x,γ)
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number) = mat_mat_scalar(A,B,γ)
# A leading scalar `α` is handed to the helpers as the trailing argument, so these two
# methods only accept real/complex scalars and element types.
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractVector{<:RealOrComplex}) =
    mat_vec_scalar(B,C,α)
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}) =
    mat_mat_scalar(B,C,α)

# Column vector times row vector (outer product), with a scalar on either side:
# a single broadcast keeps the factors in their written order.
*(α::Number, u::AbstractVector, tv::AdjOrTransAbsVec) = broadcast(*, α, u, tv)
*(u::AbstractVector, tv::AdjOrTransAbsVec, γ::Number) = broadcast(*, u, tv, γ)
# Outer product followed by a matrix: form the row vector `tv*C` first.
*(u::AbstractVector, tv::AdjOrTransAbsVec, C::AbstractMatrix) = u * (tv*C)

# Three matrices: the order is chosen from the sizes by `_tri_matmul`.
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) = _tri_matmul(A,B,C)
# Leading row vector: multiply it into `B` first.
*(tv::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix) = (tv*B) * C

# Multiply three matrices `A*B*C`, optionally scaled by `δ`, picking whichever of
# `(A*B)*C` and `A*(B*C)` needs fewer scalar multiplications. Ties go to `(A*B)*C`.
# When `δ` is given, the scaling is attached to one of the two products via
# `mat_mat_scalar`.
function _tri_matmul(A, B, C, δ=nothing)
    rows_a, cols_a = size(A)
    # size(B) is expected to be (cols_a, rows_c)
    rows_c, cols_c = size(C)
    # Multiplication counts for the two association orders.
    # (A*B)*C allocates rows_a*rows_c + rows_a*cols_c intermediates,
    # A*(B*C) allocates cols_a*cols_c + rows_a*cols_c.
    cost_left = rows_a*cols_a*rows_c + rows_a*rows_c*cols_c
    cost_right = cols_a*rows_c*cols_c + rows_a*cols_a*cols_c
    right_first = cost_right < cost_left
    if δ === nothing
        return right_first ? A * (B*C) : (A*B) * C
    end
    return right_first ? A * mat_mat_scalar(B,C,δ) : mat_mat_scalar(A*B, C, δ)
end

# The fused "two arrays times one scalar" path is opt-in: only argument types with a
# dedicated `mat_vec_scalar` / `mat_mat_scalar` method avoid the generic fallback.

# Fallback: scale the vector, then multiply.
function mat_vec_scalar(A, x, γ)
    return A * (x .* γ)
end

# Strided (possibly adjoint/transpose-wrapped) matrix with a strided vector:
# forward to `_mat_vec_scalar`.
function mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ)
    return _mat_vec_scalar(A, x, γ)
end

# Row vector times vector: compute `A * x` first and scale the result.
function mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ)
    return (A * x) * γ
end

# Compute `A * x * γ` with one 5-argument `mul!` call into a freshly allocated vector
# whose element type is the promotion of both eltypes and the type of `γ`.
function _mat_vec_scalar(A, x, γ)
    elty = promote_type(eltype(A), eltype(x), typeof(γ))
    out = similar(A, elty, axes(A,1))
    # `γ` is the multiplier of `A*x`; the trailing `false` is the multiplier of `out`.
    return mul!(out, A, x, γ, false)
end

# Fallback: multiply the two arrays, then scale the result.
function mat_mat_scalar(A, B, γ)
    return (A*B) .* γ
end

# Both factors strided (possibly adjoint/transpose-wrapped): forward to `_mat_mat_scalar`.
function mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ)
    return _mat_mat_scalar(A, B, γ)
end

# Compute `A * B * γ` with one 5-argument `mul!` call into a freshly allocated matrix
# whose element type is the promotion of both eltypes and the type of `γ`.
function _mat_mat_scalar(A, B, γ)
    elty = promote_type(eltype(A), eltype(B), typeof(γ))
    out = similar(A, elty, axes(A,1), axes(B,2))
    # `γ` is the multiplier of `A*B`; the trailing `false` is the multiplier of `out`.
    return mul!(out, A, B, γ, false)
end

# Row vector (adjoint) times matrix times scalar.
# Generic case: scale the adjoint of the product by `γ'` and take the adjoint again,
# so the factors end up in their written order.
mat_mat_scalar(A::AdjointAbsVec, B, γ) = (γ' .* (A * B)')' # preserving order, adjoint reverses
# Real/complex strided case: rewrite as a matrix-vector-scalar product of the adjoints.
mat_mat_scalar(A::AdjointAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
    mat_vec_scalar(B', A', γ')'

# Row vector (transpose) times matrix times scalar: the same two methods, with
# `transpose` in place of `'`.
mat_mat_scalar(A::TransposeAbsVec, B, γ) = transpose(γ .* transpose(A * B))
mat_mat_scalar(A::TransposeAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
    transpose(mat_vec_scalar(transpose(B), transpose(A), γ))


# Four-argument *, by type
# Each method regroups the factors and falls through to the 2- and 3-argument methods.

# Scalar(s) in front: combine two scalars first; do matrix-vector products first.
*(α::Number, β::Number, C::AbstractMatrix, x::AbstractVector) = (α*β) * C * x
*(α::Number, β::Number, C::AbstractMatrix, D::AbstractMatrix) = (α*β) * C * D
*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)
*(α::RealOrComplex, vt::AdjOrTransAbsVec{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
    (α*vt*C) * D # solves an ambiguity

# Scalar(s) at the end: mirror images of the methods above.
*(A::AbstractMatrix, x::AbstractVector, γ::Number, δ::Number) = A * x * (γ*δ)
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number, δ::Number) = A * B * (γ*δ)
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector, δ::Number) = A * (B*x*δ)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, x::AbstractVector, δ::Number) = (vt*B*x) * δ
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = (vt*B) * C * δ

# No scalars: peel off the vector (or row vector) at the end it sits on.
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = A * B * (C*x)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) = (vt*B) * C * D
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = vt * B * (C*x)

# Four-argument *, by size
# Three matrices and a scalar: `_tri_matmul` picks the order and applies the scalar.
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = _tri_matmul(A,B,C,δ)
# A leading scalar is passed to `_tri_matmul` as its trailing `δ` argument, so this
# method is restricted to real/complex scalars and element types.
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
    _tri_matmul(B,C,D,α)
# Four matrices: `_quad_matmul` compares all five parenthesisations.
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) =
    _quad_matmul(A,B,C,D)

# Multiply four matrices using whichever of the five possible parenthesisations has the
# lowest scalar-multiplication count according to `_mul_cost`. On ties the earliest
# candidate in the order listed below is used.
function _quad_matmul(A,B,C,D)
    cost_pairs      = _mul_cost((A,B),(C,D))    # (A*B) * (C*D)
    cost_left       = _mul_cost(((A,B),C),D)    # ((A*B) * C) * D
    cost_right      = _mul_cost(A,(B,(C,D)))    # A * (B * (C*D))
    cost_mid_left   = _mul_cost((A,(B,C)),D)    # (A * (B*C)) * D
    cost_mid_right  = _mul_cost(A,((B,C),D))    # A * ((B*C) * D)
    best = min(cost_pairs, cost_left, cost_right, cost_mid_left, cost_mid_right)

    cost_pairs == best && return (A*B) * (C*D)
    cost_left == best && return ((A*B) * C) * D
    cost_right == best && return A * (B * (C*D))
    cost_mid_left == best && return (A * (B*C)) * D
    return A * ((B*C) * D)
end

# Cost of a product tree written as nested 2-tuples: a bare matrix costs nothing ...
@inline _mul_cost(A::AbstractMatrix) = 0
# ... a tuple `(A, B)` costs whatever the product `A*B` costs ...
@inline _mul_cost(pair::Tuple) = _mul_cost(pair[1], pair[2])
# ... and a product costs both operands plus rows(A) * cols(A) * cols(B).
@inline function _mul_cost(A, B)
    rows, inner = _mul_sizes(A)
    cols = last(_mul_sizes(B))
    return _mul_cost(A) + _mul_cost(B) + rows * inner * cols
end

# Size of the result of a product tree: a matrix's own size, or for a tuple the
# first dimension of its left part and the last dimension of its right part.
@inline _mul_sizes(A::AbstractMatrix) = size(A)
@inline _mul_sizes(pair::Tuple) = (first(_mul_sizes(pair[1])), last(_mul_sizes(pair[2])))
99 changes: 99 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -766,4 +766,103 @@ end
@test Matrix{Int}(undef, 2, 0) * Matrix{Int}(undef, 0, 3) == zeros(Int, 2, 3)
end

# Checks that every 3-argument product whose order is chosen by argument type gives
# the same result as both explicit parenthesisations.
@testset "3-arg *, order by type" begin
# All test data is integer or complex-integer valued, so the different association
# orders can be compared with `==`.
x = [1, 2im]
y = [im, 20, 30+40im]
z = [-1, 200+im, -3]
A = [1 2 3im; 4 5 6+im]
B = [-10 -20; -30 -40]
a = 3 + im * round(Int, 10^6*(pi-3))
b = 123

# row vector * matrix * vector, with plain, adjoint and transposed matrices
@test x'*A*y == (x'*A)*y == x'*(A*y)
@test y'*A'*x == (y'*A')*x == y'*(A'*x)
@test y'*transpose(A)*x == (y'*transpose(A))*x == y'*(transpose(A)*x)

# matrix * matrix * vector
@test B*A*y == (B*A)*y == B*(A*y)

# scalar at either end of matrix * vector
@test a*A*y == (a*A)*y == a*(A*y)
@test A*y*a == (A*y)*a == A*(y*a)

# scalar at either end of matrix * matrix
@test a*B*A == (a*B)*A == a*(B*A)
@test B*A*a == (B*A)*a == B*(A*a)

# scalar at either end of row vector * vector
@test a*y'*z == (a*y')*z == a*(y'*z)
@test y'*z*a == (y'*z)*a == y'*(z*a)

# scalar at either end of vector * row vector
@test a*y*z' == (a*y)*z' == a*(y*z')
@test y*z'*a == (y*z')*a == y*(z'*a)

# scalar at either end of row vector * matrix; the result must stay an Adjoint wrapper
@test a*x'*A == (a*x')*A == a*(x'*A)
@test x'*A*a == (x'*A)*a == x'*(A*a)
@test a*x'*A isa Adjoint{<:Any, <:Vector}

# the same with `transpose`; the result must stay a Transpose wrapper
@test a*transpose(x)*A == (a*transpose(x))*A == a*(transpose(x)*A)
@test transpose(x)*A*a == (transpose(x)*A)*a == transpose(x)*(A*a)
@test a*transpose(x)*A isa Transpose{<:Any, <:Vector}

# row vector * matrix * matrix
@test x'*B*A == (x'*B)*A == x'*(B*A)
@test x'*B*A isa Adjoint{<:Any, <:Vector}

# vector * row vector * matrix, also with the vector reshaped to a 3×1 matrix
@test y*x'*A == (y*x')*A == y*(x'*A)
y31 = reshape(y,3,1)
@test y31*x'*A == (y31*x')*A == y31*(x'*A)

# arrays whose elements are themselves 2×2 integer matrices
vm = [rand(1:9,2,2) for _ in 1:3]
Mm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]

@test vm' * Mm * vm == (vm' * Mm) * vm == vm' * (Mm * vm)
@test Mm * Mm' * vm == (Mm * Mm') * vm == Mm * (Mm' * vm)
@test vm' * Mm * Mm == (vm' * Mm) * Mm == vm' * (Mm * Mm)
@test Mm * Mm' * Mm == (Mm * Mm') * Mm == Mm * (Mm' * Mm)
end

# Checks the size-based ordering of three-matrix products: shapes are chosen so that
# the cheaper association differs between cases. Floating-point data, hence `≈`.
@testset "3-arg *, order by size" begin
    M44 = randn(4,4)
    M24 = randn(2,4)
    M42 = randn(4,2)
    @test M44*M44*M44 ≈ (M44*M44)*M44 ≈ M44*(M44*M44)
    @test M42*M24*M44 ≈ (M42*M24)*M44 ≈ M42*(M24*M44)
    @test M44*M42*M24 ≈ (M44*M42)*M24 ≈ M44*(M42*M24)
end

# Checks the 4-argument products whose grouping is chosen by argument type, each
# against an explicitly parenthesised two-factor product.
@testset "4-arg *, by type" begin
# Integer and complex-integer data, so results can be compared with `==`.
y = [im, 20, 30+40im]
z = [-1, 200+im, -3]
a = 3 + im * round(Int, 10^6*(pi-3))
b = 123
# 3×3 matrices with entries drawn from 1:9 and im, 2im, 3im
M = rand(vcat(1:9, im.*[1,2,3]), 3,3)
N = rand(vcat(1:9, im.*[1,2,3]), 3,3)

# scalar(s) in front
@test a * b * M * y == (a*b) * (M*y)
@test a * b * M * N == (a*b) * (M*N)
@test a * M * N * y == (a*M) * (N*y)
@test a * y' * M * z == (a*y') * (M*z)
@test a * y' * M * N == (a*y') * (M*N)

# scalar(s) at the end; some expected values are written with `a` moved to the front
@test M * y * a * b == (M*y) * (a*b)
@test M * N * a * b == (M*N) * (a*b)
@test M * N * y * a == (a*M) * (N*y)
@test y' * M * z * a == (a*y') * (M*z)
@test y' * M * N * a == (a*y') * (M*N)

# no scalars: a vector or row vector at one or both ends
@test M * N * conj(M) * y == (M*N) * (conj(M)*y)
@test y' * M * N * conj(M) == (y'*M) * (N*conj(M))
@test y' * M * N * z == (y'*M) * (N*z)
end

# Checks the size-based ordering of 4-argument products. The five dimensions 3:7 are
# rotated through every position, so each shift gives a different chain of matrix shapes.
# Floating-point data, hence `≈`.
@testset "4-arg *, by size" begin
    for shift in 1:5
        s1,s2,s3,s4,s5 = circshift(3:7, shift)
        a=randn(s1,s2); b=randn(s2,s3); c=randn(s3,s4); d=randn(s4,s5)

        # _quad_matmul
        @test *(a,b,c,d) ≈ (a*b) * (c*d)

        # _tri_matmul(A,B,C,δ)
        @test *(11.1,b,c,d) ≈ (11.1*b) * (c*d)
        @test *(a,b,c,99.9) ≈ (a*b) * (c*99.9)
    end
end

end # module TestMatmul

0 comments on commit 51f5740

Please sign in to comment.