diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl
index e755cc88c6572..10e6e7722f414 100644
--- a/stdlib/LinearAlgebra/src/matmul.jl
+++ b/stdlib/LinearAlgebra/src/matmul.jl
@@ -1081,3 +1081,141 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
     end # inbounds
     C
 end
+
+const RealOrComplex = Union{Real,Complex}
+
+# Three-argument *
+"""
+    *(A, B::AbstractMatrix, C)
+    A * B * C * D
+
+Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
+based on the sizes of the arrays. That is, the number of scalar multiplications needed
+for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
+to choose which of these to execute.
+
+If the last factor is a vector, or the first a transposed vector, then it is efficient
+to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
+for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
+allocates an intermediate array.
+
+If the first or last factor is a number, this will be fused with the matrix
+multiplication, using 5-arg [`mul!`](@ref).
+
+See also [`muladd`](@ref), [`dot`](@ref).
+
+!!! compat "Julia 1.7"
+    These optimisations require at least Julia 1.7.
+"""
+*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector) = A * (B*x)
+
+*(tu::AdjOrTransAbsVec, B::AbstractMatrix, v::AbstractVector) = (tu*B) * v
+*(tu::AdjOrTransAbsVec, B::AdjOrTransAbsMat, v::AbstractVector) = tu * (B*v)
+
+*(A::AbstractMatrix, x::AbstractVector, γ::Number) = mat_vec_scalar(A,x,γ)
+*(A::AbstractMatrix, B::AbstractMatrix, γ::Number) = mat_mat_scalar(A,B,γ)
+*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractVector{<:RealOrComplex}) =
+    mat_vec_scalar(B,C,α)
+*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}) =
+    mat_mat_scalar(B,C,α)
+
+*(α::Number, u::AbstractVector, tv::AdjOrTransAbsVec) = broadcast(*, α, u, tv)
+*(u::AbstractVector, tv::AdjOrTransAbsVec, γ::Number) = broadcast(*, u, tv, γ)
+*(u::AbstractVector, tv::AdjOrTransAbsVec, C::AbstractMatrix) = u * (tv*C)
+
+*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) = _tri_matmul(A,B,C)
+*(tv::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix) = (tv*B) * C
+
+function _tri_matmul(A,B,C,δ=nothing)
+    n,m = size(A)
+    # m,k == size(B)
+    k,l = size(C)
+    costAB_C = n*m*k + n*k*l # multiplications for (A*B)*C; it allocates n*k + n*l elements
+    costA_BC = m*k*l + n*m*l # multiplications for A*(B*C); it allocates m*l + n*l elements
+    if costA_BC < costAB_C
+        isnothing(δ) ? A * (B*C) : A * mat_mat_scalar(B,C,δ)
+    else
+        isnothing(δ) ? (A*B) * C : mat_mat_scalar(A*B, C, δ)
+    end
+end
+
+# Fast path for two arrays * one scalar is opt-in, via mat_vec_scalar and mat_mat_scalar.
+
+mat_vec_scalar(A, x, γ) = A * (x .* γ) # fallback
+mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ) = _mat_vec_scalar(A, x, γ)
+mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ) = (A * x) * γ
+
+function _mat_vec_scalar(A, x, γ)
+    T = promote_type(eltype(A), eltype(x), typeof(γ))
+    C = similar(A, T, axes(A,1))
+    mul!(C, A, x, γ, false) # C = A*x*γ + C*false, so the prior contents of C are discarded
+end
+
+mat_mat_scalar(A, B, γ) = (A*B) .* γ # fallback
+mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ) =
+    _mat_mat_scalar(A, B, γ)
+
+function _mat_mat_scalar(A, B, γ)
+    T = promote_type(eltype(A), eltype(B), typeof(γ))
+    C = similar(A, T, axes(A,1), axes(B,2))
+    mul!(C, A, B, γ, false) # C = A*B*γ + C*false, so the prior contents of C are discarded
+end
+
+mat_mat_scalar(A::AdjointAbsVec, B, γ) = (γ' .* (A * B)')' # preserving order, adjoint reverses
+mat_mat_scalar(A::AdjointAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
+    mat_vec_scalar(B', A', γ')'
+
+mat_mat_scalar(A::TransposeAbsVec, B, γ) = transpose(γ .* transpose(A * B))
+mat_mat_scalar(A::TransposeAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
+    transpose(mat_vec_scalar(transpose(B), transpose(A), γ))
+
+
+# Four-argument *, by type
+*(α::Number, β::Number, C::AbstractMatrix, x::AbstractVector) = (α*β) * C * x
+*(α::Number, β::Number, C::AbstractMatrix, D::AbstractMatrix) = (α*β) * C * D
+*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
+*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)
+*(α::RealOrComplex, vt::AdjOrTransAbsVec{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
+    (α*vt*C) * D # solves an ambiguity
+
+*(A::AbstractMatrix, x::AbstractVector, γ::Number, δ::Number) = A * x * (γ*δ)
+*(A::AbstractMatrix, B::AbstractMatrix, γ::Number, δ::Number) = A * B * (γ*δ)
+*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector, δ::Number, ) = A * (B*x*δ)
+*(vt::AdjOrTransAbsVec, B::AbstractMatrix, x::AbstractVector, δ::Number) = (vt*B*x) * δ
+*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = (vt*B) * C * δ
+
+*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = A * B * (C*x)
+*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) = (vt*B) * C * D
+*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = vt * B * (C*x)
+
+# Four-argument *, by size
+*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = _tri_matmul(A,B,C,δ)
+*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
+    _tri_matmul(B,C,D,α)
+*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) =
+    _quad_matmul(A,B,C,D)
+
+function _quad_matmul(A,B,C,D) # multiply four matrices using the cheapest of the 5 groupings
+    c1 = _mul_cost((A,B),(C,D))
+    c2 = _mul_cost(((A,B),C),D)
+    c3 = _mul_cost(A,(B,(C,D)))
+    c4 = _mul_cost((A,(B,C)),D)
+    c5 = _mul_cost(A,((B,C),D))
+    cmin = min(c1,c2,c3,c4,c5) # on ties, the first matching branch below is taken
+    if c1 == cmin
+        (A*B) * (C*D)
+    elseif c2 == cmin
+        ((A*B) * C) * D
+    elseif c3 == cmin
+        A * (B * (C*D))
+    elseif c4 == cmin
+        (A * (B*C)) * D
+    else
+        A * ((B*C) * D)
+    end
+end
+@inline _mul_cost(A::AbstractMatrix) = 0 # a single matrix costs no multiplications
+@inline _mul_cost((A,B)::Tuple) = _mul_cost(A,B) # a tuple stands for the product of its two parts
+@inline _mul_cost(A,B) = _mul_cost(A) + _mul_cost(B) + *(_mul_sizes(A)..., last(_mul_sizes(B)))
+@inline _mul_sizes(A::AbstractMatrix) = size(A)
+@inline _mul_sizes((A,B)::Tuple) = first(_mul_sizes(A)), last(_mul_sizes(B)) # size of the product
diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl
index 6eed61f901aed..1febdfe49fb3b 100644
--- a/stdlib/LinearAlgebra/test/matmul.jl
+++ b/stdlib/LinearAlgebra/test/matmul.jl
@@ -766,4 +766,103 @@ end
     @test Matrix{Int}(undef, 2, 0) * Matrix{Int}(undef, 0, 3) == zeros(Int, 2, 3)
 end
 
+@testset "3-arg *, order by type" begin
+    x = [1, 2im]
+    y = [im, 20, 30+40im]
+    z = [-1, 200+im, -3]
+    A = [1 2 3im; 4 5 6+im]
+    B = [-10 -20; -30 -40]
+    a = 3 + im * round(Int, 10^6*(pi-3))
+    b = 123
+
+    @test x'*A*y == (x'*A)*y == x'*(A*y)
+    @test y'*A'*x == (y'*A')*x == y'*(A'*x)
+    @test y'*transpose(A)*x == (y'*transpose(A))*x == y'*(transpose(A)*x)
+
+    @test B*A*y == (B*A)*y == B*(A*y)
+
+    @test a*A*y == (a*A)*y == a*(A*y)
+    @test A*y*a == (A*y)*a == A*(y*a)
+
+    @test a*B*A == (a*B)*A == a*(B*A)
+    @test B*A*a == (B*A)*a == B*(A*a)
+
+    @test a*y'*z == (a*y')*z == a*(y'*z)
+    @test y'*z*a == (y'*z)*a == y'*(z*a)
+
+    @test a*y*z' == (a*y)*z' == a*(y*z')
+    @test y*z'*a == (y*z')*a == y*(z'*a)
+
+    @test a*x'*A == (a*x')*A == a*(x'*A)
+    @test x'*A*a == (x'*A)*a == x'*(A*a)
+    @test a*x'*A isa Adjoint{<:Any, <:Vector}
+
+    @test a*transpose(x)*A == (a*transpose(x))*A == a*(transpose(x)*A)
+    @test transpose(x)*A*a == (transpose(x)*A)*a == transpose(x)*(A*a)
+    @test a*transpose(x)*A isa Transpose{<:Any, <:Vector}
+
+    @test x'*B*A == (x'*B)*A == x'*(B*A)
+    @test x'*B*A isa Adjoint{<:Any, <:Vector}
+
+    @test y*x'*A == (y*x')*A == y*(x'*A)
+    y31 = reshape(y,3,1)
+    @test y31*x'*A == (y31*x')*A == y31*(x'*A)
+
+    vm = [rand(1:9,2,2) for _ in 1:3]
+    Mm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]
+
+    @test vm' * Mm * vm == (vm' * Mm) * vm == vm' * (Mm * vm)
+    @test Mm * Mm' * vm == (Mm * Mm') * vm == Mm * (Mm' * vm)
+    @test vm' * Mm * Mm == (vm' * Mm) * Mm == vm' * (Mm * Mm)
+    @test Mm * Mm' * Mm == (Mm * Mm') * Mm == Mm * (Mm' * Mm)
+end
+
+@testset "3-arg *, order by size" begin
+    M44 = randn(4,4)
+    M24 = randn(2,4)
+    M42 = randn(4,2)
+    @test M44*M44*M44 ≈ (M44*M44)*M44 ≈ M44*(M44*M44)
+    @test M42*M24*M44 ≈ (M42*M24)*M44 ≈ M42*(M24*M44)
+    @test M44*M42*M24 ≈ (M44*M42)*M24 ≈ M44*(M42*M24)
+end
+
+@testset "4-arg *, by type" begin
+    y = [im, 20, 30+40im]
+    z = [-1, 200+im, -3]
+    a = 3 + im * round(Int, 10^6*(pi-3))
+    b = 123
+    M = rand(vcat(1:9, im.*[1,2,3]), 3,3)
+    N = rand(vcat(1:9, im.*[1,2,3]), 3,3)
+
+    @test a * b * M * y == (a*b) * (M*y)
+    @test a * b * M * N == (a*b) * (M*N)
+    @test a * M * N * y == (a*M) * (N*y)
+    @test a * y' * M * z == (a*y') * (M*z)
+    @test a * y' * M * N == (a*y') * (M*N)
+
+    @test M * y * a * b == (M*y) * (a*b)
+    @test M * N * a * b == (M*N) * (a*b)
+    @test M * N * y * a == (a*M) * (N*y)
+    @test y' * M * z * a == (a*y') * (M*z)
+    @test y' * M * N * a == (a*y') * (M*N)
+
+    @test M * N * conj(M) * y == (M*N) * (conj(M)*y)
+    @test y' * M * N * conj(M) == (y'*M) * (N*conj(M))
+    @test y' * M * N * z == (y'*M) * (N*z)
+end
+
+@testset "4-arg *, by size" begin
+    for shift in 1:5
+        s1,s2,s3,s4,s5 = circshift(3:7, shift)
+        a=randn(s1,s2); b=randn(s2,s3); c=randn(s3,s4); d=randn(s4,s5)
+
+        # _quad_matmul
+        @test *(a,b,c,d) ≈ (a*b) * (c*d)
+
+        # _tri_matmul(A,B,C,δ)
+        @test *(11.1,b,c,d) ≈ (11.1*b) * (c*d)
+        @test *(a,b,c,99.9) ≈ (a*b) * (c*99.9)
+    end
+end
+
 end # module TestMatmul