Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 3-arg * methods #37898

Merged
merged 32 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2ed51d6
add 3-arg star
Oct 6, 2020
14fa53e
also handle 3 matrices
Oct 6, 2020
b0c33f4
a few more cases, and tests
Oct 6, 2020
1981525
two scalars + one array
Oct 6, 2020
71dfbb0
avoid dot dispatch
Oct 6, 2020
be5eee3
add some array-of-array tests
Oct 6, 2020
8e9bcd5
fix order of multiplication
Oct 6, 2020
cfa28cb
add a docstring
Oct 6, 2020
93d03b8
preserve order in fallbacks
Oct 6, 2020
b04b3a5
opt-in, take 1
Oct 6, 2020
1a29b74
remove two-scalar methods
Oct 6, 2020
6c3c9c0
fewer adjoint = less chance of spurious dimensions
Oct 6, 2020
a1ddb42
Update stdlib/LinearAlgebra/src/matmul.jl
mcabbott Oct 7, 2020
ea98220
four-argument *, why not
Oct 8, 2020
d6954e0
one more case
Oct 8, 2020
ffbb4b8
rm ambiguity
Oct 8, 2020
f0b3648
use 3-arg dot only when length(x)<64
Oct 17, 2020
1c92294
more explicit name for _SafeMatrix
Oct 17, 2020
0a71d33
don't use 3-arg dot at all
Oct 18, 2020
fb3b7da
use RealOrComplex for eltypes, too
Oct 27, 2020
f005fb5
better fallback for mat_vec_scalar
Nov 17, 2020
eac95e7
remove a now-duplicate definition of StridedMaybeAdjOrTransMat
Nov 20, 2020
2f39e33
Apply suggestions from code review
mcabbott Dec 11, 2020
d9b456e
constrain adjoint eltypes
Dec 11, 2020
766368f
two tiny fixes
mcabbott May 14, 2021
15f9f21
optimise dot(x,A,y)-like cases based on whether A is transposed
mcabbott May 16, 2021
5ef2922
two tests
mcabbott May 28, 2021
f778a00
Update stdlib/LinearAlgebra/test/matmul.jl
mcabbott May 28, 2021
eacc11b
change zero(γ) -> false as suggested
mcabbott Jun 7, 2021
ebb52fb
better docstring
mcabbott Jun 7, 2021
e7aa4ad
wording, plus examples
mcabbott Jun 7, 2021
468d19b
docstring
mcabbott Jun 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1081,3 +1081,141 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
end # inbounds
C
end

const RealOrComplex = Union{Real,Complex}

# Three-argument *
"""
*(A, B::AbstractMatrix, C)
A * B * C * D

Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
based on the sizes of the arrays. That is, the number of scalar multiplications needed
for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
to choose which of these to execute.

If the last factor is a vector, or the first a transposed vector, then it is efficient
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
allocates an intermediate array.

If the first or last factor is a number, this will be fused with the matrix
multiplication, using 5-arg [`mul!`](@ref).

See also [`muladd`](@ref), [`dot`](@ref).

!!! compat "Julia 1.7"
These optimisations require at least Julia 1.7.
"""
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector) = A * (B*x)

*(tu::AdjOrTransAbsVec, B::AbstractMatrix, v::AbstractVector) = (tu*B) * v
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
*(tu::AdjOrTransAbsVec, B::AdjOrTransAbsMat, v::AbstractVector) = tu * (B*v)

*(A::AbstractMatrix, x::AbstractVector, γ::Number) = mat_vec_scalar(A,x,γ)
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number) = mat_mat_scalar(A,B,γ)
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractVector{<:RealOrComplex}) =
mat_vec_scalar(B,C,α)
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}) =
mat_mat_scalar(B,C,α)

*(α::Number, u::AbstractVector, tv::AdjOrTransAbsVec) = broadcast(*, α, u, tv)
*(u::AbstractVector, tv::AdjOrTransAbsVec, γ::Number) = broadcast(*, u, tv, γ)
*(u::AbstractVector, tv::AdjOrTransAbsVec, C::AbstractMatrix) = u * (tv*C)

*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) = _tri_matmul(A,B,C)
*(tv::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix) = (tv*B) * C
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

function _tri_matmul(A,B,C,δ=nothing)
n,m = size(A)
# m,k == size(B)
k,l = size(C)
costAB_C = n*m*k + n*k*l # multiplications, allocations n*k + n*l
costA_BC = m*k*l + n*m*l # m*l + n*l
if costA_BC < costAB_C
isnothing(δ) ? A * (B*C) : A * mat_mat_scalar(B,C,δ)
else
isnothing(δ) ? (A*B) * C : mat_mat_scalar(A*B, C, δ)
end
end

# Fast path for two arrays * one scalar is opt-in, via mat_vec_scalar and mat_mat_scalar.
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

mat_vec_scalar(A, x, γ) = A * (x .* γ) # fallback
mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ) = _mat_vec_scalar(A, x, γ)
mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ) = (A * x) * γ

function _mat_vec_scalar(A, x, γ)
T = promote_type(eltype(A), eltype(x), typeof(γ))
C = similar(A, T, axes(A,1))
mul!(C, A, x, γ, false)
end

mat_mat_scalar(A, B, γ) = (A*B) .* γ # fallback
mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ) =
_mat_mat_scalar(A, B, γ)

function _mat_mat_scalar(A, B, γ)
T = promote_type(eltype(A), eltype(B), typeof(γ))
C = similar(A, T, axes(A,1), axes(B,2))
mul!(C, A, B, γ, false)
end

mat_mat_scalar(A::AdjointAbsVec, B, γ) = (γ' .* (A * B)')' # preserving order, adjoint reverses
mat_mat_scalar(A::AdjointAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
mat_vec_scalar(B', A', γ')'

mat_mat_scalar(A::TransposeAbsVec, B, γ) = transpose(γ .* transpose(A * B))
mat_mat_scalar(A::TransposeAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
transpose(mat_vec_scalar(transpose(B), transpose(A), γ))


# Four-argument *, by type
*(α::Number, β::Number, C::AbstractMatrix, x::AbstractVector) = (α*β) * C * x
*(α::Number, β::Number, C::AbstractMatrix, D::AbstractMatrix) = (α*β) * C * D
*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)
*(α::RealOrComplex, vt::AdjOrTransAbsVec{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
(α*vt*C) * D # solves an ambiguity

*(A::AbstractMatrix, x::AbstractVector, γ::Number, δ::Number) = A * x * (γ*δ)
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number, δ::Number) = A * B * (γ*δ)
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector, δ::Number, ) = A * (B*x*δ)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, x::AbstractVector, δ::Number) = (vt*B*x) * δ
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = (vt*B) * C * δ

*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = A * B * (C*x)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) = (vt*B) * C * D
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = vt * B * (C*x)

# Four-argument *, by size
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = _tri_matmul(A,B,C,δ)
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
_tri_matmul(B,C,D,α)
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) =
_quad_matmul(A,B,C,D)

function _quad_matmul(A,B,C,D)
c1 = _mul_cost((A,B),(C,D))
c2 = _mul_cost(((A,B),C),D)
c3 = _mul_cost(A,(B,(C,D)))
c4 = _mul_cost((A,(B,C)),D)
c5 = _mul_cost(A,((B,C),D))
cmin = min(c1,c2,c3,c4,c5)
if c1 == cmin
(A*B) * (C*D)
elseif c2 == cmin
((A*B) * C) * D
elseif c3 == cmin
A * (B * (C*D))
elseif c4 == cmin
(A * (B*C)) * D
else
A * ((B*C) * D)
end
end
@inline _mul_cost(A::AbstractMatrix) = 0
@inline _mul_cost((A,B)::Tuple) = _mul_cost(A,B)
@inline _mul_cost(A,B) = _mul_cost(A) + _mul_cost(B) + *(_mul_sizes(A)..., last(_mul_sizes(B)))
@inline _mul_sizes(A::AbstractMatrix) = size(A)
@inline _mul_sizes((A,B)::Tuple) = first(_mul_sizes(A)), last(_mul_sizes(B))
99 changes: 99 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -766,4 +766,103 @@ end
@test Matrix{Int}(undef, 2, 0) * Matrix{Int}(undef, 0, 3) == zeros(Int, 2, 3)
end

@testset "3-arg *, order by type" begin
x = [1, 2im]
y = [im, 20, 30+40im]
z = [-1, 200+im, -3]
A = [1 2 3im; 4 5 6+im]
B = [-10 -20; -30 -40]
a = 3 + im * round(Int, 10^6*(pi-3))
b = 123

@test x'*A*y == (x'*A)*y == x'*(A*y)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
@test y'*A'*x == (y'*A')*x == y'*(A'*x)
@test y'*transpose(A)*x == (y'*transpose(A))*x == y'*(transpose(A)*x)

@test B*A*y == (B*A)*y == B*(A*y)

@test a*A*y == (a*A)*y == a*(A*y)
@test A*y*a == (A*y)*a == A*(y*a)

@test a*B*A == (a*B)*A == a*(B*A)
@test B*A*a == (B*A)*a == B*(A*a)

@test a*y'*z == (a*y')*z == a*(y'*z)
@test y'*z*a == (y'*z)*a == y'*(z*a)

@test a*y*z' == (a*y)*z' == a*(y*z')
@test y*z'*a == (y*z')*a == y*(z'*a)

@test a*x'*A == (a*x')*A == a*(x'*A)
@test x'*A*a == (x'*A)*a == x'*(A*a)
@test a*x'*A isa Adjoint{<:Any, <:Vector}

@test a*transpose(x)*A == (a*transpose(x))*A == a*(transpose(x)*A)
@test transpose(x)*A*a == (transpose(x)*A)*a == transpose(x)*(A*a)
@test a*transpose(x)*A isa Transpose{<:Any, <:Vector}

@test x'*B*A == (x'*B)*A == x'*(B*A)
@test x'*B*A isa Adjoint{<:Any, <:Vector}

@test y*x'*A == (y*x')*A == y*(x'*A)
y31 = reshape(y,3,1)
@test y31*x'*A == (y31*x')*A == y31*(x'*A)

vm = [rand(1:9,2,2) for _ in 1:3]
Mm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]

@test vm' * Mm * vm == (vm' * Mm) * vm == vm' * (Mm * vm)
@test Mm * Mm' * vm == (Mm * Mm') * vm == Mm * (Mm' * vm)
@test vm' * Mm * Mm == (vm' * Mm) * Mm == vm' * (Mm * Mm)
@test Mm * Mm' * Mm == (Mm * Mm') * Mm == Mm * (Mm' * Mm)
end

@testset "3-arg *, order by size" begin
M44 = randn(4,4)
M24 = randn(2,4)
M42 = randn(4,2)
@test M44*M44*M44 ≈ (M44*M44)*M44 ≈ M44*(M44*M44)
@test M42*M24*M44 ≈ (M42*M24)*M44 ≈ M42*(M24*M44)
@test M44*M42*M24 ≈ (M44*M42)*M24 ≈ M44*(M42*M24)
end

@testset "4-arg *, by type" begin
y = [im, 20, 30+40im]
z = [-1, 200+im, -3]
a = 3 + im * round(Int, 10^6*(pi-3))
b = 123
M = rand(vcat(1:9, im.*[1,2,3]), 3,3)
N = rand(vcat(1:9, im.*[1,2,3]), 3,3)

@test a * b * M * y == (a*b) * (M*y)
@test a * b * M * N == (a*b) * (M*N)
@test a * M * N * y == (a*M) * (N*y)
@test a * y' * M * z == (a*y') * (M*z)
@test a * y' * M * N == (a*y') * (M*N)

@test M * y * a * b == (M*y) * (a*b)
@test M * N * a * b == (M*N) * (a*b)
@test M * N * y * a == (a*M) * (N*y)
@test y' * M * z * a == (a*y') * (M*z)
@test y' * M * N * a == (a*y') * (M*N)

@test M * N * conj(M) * y == (M*N) * (conj(M)*y)
@test y' * M * N * conj(M) == (y'*M) * (N*conj(M))
@test y' * M * N * z == (y'*M) * (N*z)
end

@testset "4-arg *, by size" begin
for shift in 1:5
s1,s2,s3,s4,s5 = circshift(3:7, shift)
a=randn(s1,s2); b=randn(s2,s3); c=randn(s3,s4); d=randn(s4,s5)

# _quad_matmul
@test *(a,b,c,d) ≈ (a*b) * (c*d)

# _tri_matmul(A,B,B,δ)
@test *(11.1,b,c,d) ≈ (11.1*b) * (c*d)
@test *(a,b,c,99.9) ≈ (a*b) * (c*99.9)
end
end

end # module TestMatmul