Add 3-arg * methods #37898

Merged: 32 commits merged into JuliaLang:master on Jun 7, 2021

Conversation

@mcabbott (Contributor) commented Oct 6, 2020

This addresses the simplest part of #12065 by adding some methods for * with 3 arguments, for cases where this can be done more efficiently than working left-to-right:

using BenchmarkTools

# Matrix-matrix-vector
@btime A*B*x  setup=(N=100; A=rand(N,N); B=rand(N,N); x=rand(N));
@btime A*(B*x)  setup=(N=100; A=rand(N,N); B=rand(N,N); x=rand(N)); # 10x faster

# Scalar-matrix-vector (with function from PR)
@btime a*B*x  setup=(N=100; a=rand(); B=rand(N,N); x=rand(N));
@btime mat_vec_scalar(a,B,x)  setup=(N=100; a=rand(); B=rand(N,N); x=rand(N)); # 5x faster

# 3-arg dot
@btime x'*A*y  setup=(N=100; x=rand(N); A=rand(N,N); y=rand(N)); 
@btime dot(x,A,y)  setup=(N=100; x=rand(N); A=rand(N,N); y=rand(N)); # slightly faster, zero alloc.

# Three matrices
@btime A*B*C  setup=(N=100; A=rand(N,10); B=rand(10,N); C=rand(N,N));
@btime A*(B*C)  setup=(N=100; A=rand(N,10); B=rand(10,N); C=rand(N,N)); # 3x faster

I think it's careful about Adjoint & Transpose vectors, but might need more thought about Diagonal and other special matrix types. (Edit -- testing says (zeros(0))' * Diagonal(zeros(0)) * zeros(0) is ambiguous.)

See also https://github.com/AustinPrivett/MatrixChainMultiply.jl, and the discussion at https://discourse.julialang.org/t/why-is-multiplication-a-b-c-left-associative-foldl-not-right-associative-foldr/17552.
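
For orientation, the heart of the three-matrix case can be sketched in a few lines. This is illustrative only, not the PR's verbatim code (the helper name tri_mul and this exact cost comparison are the sketch's own): pick whichever association needs fewer scalar multiplications.

function tri_mul(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix)
    m, k = size(A)                  # A is m×k
    n, p = size(B, 2), size(C, 2)   # B is k×n, C is n×p
    # (A*B)*C costs m*n*(k+p) multiplications; A*(B*C) costs k*p*(m+n)
    m * n * (k + p) <= k * p * (m + n) ? (A * B) * C : A * (B * C)
end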

@dkarrasch added labels: domain:linear algebra, performance (Oct 6, 2020)
@fredrikekre (Member) commented:

Will this not run into the same problems as #24343 (comment)? That PR ended up doing it for Array only.

@mcabbott (Contributor, Author) commented Oct 6, 2020

I guess the functions which fuse the scalar like mat_vec_scalar(A, x, γ) won't respect an overload of 2-arg * for your matrix type. Their fast path via similar & mul! could be made opt-in. Right now:

julia> sprand(13, 11, 0.3) * sprand(11, 0.3) * 17                            
13-element SparseVector{Float64,Int64} with 3 stored entries:
  [1 ]  =  1.40409
  [2 ]  =  5.72894
  [7 ]  =  5.95465

julia> SA[1 2; 3 4] * SA[5,6] * 7
2-element MArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 119
 273

julia> 5 * Diagonal(1:2) * Diagonal(3:4)
2×2 Adjoint{Int64,SparseMatrixCSC{Int64,Int64}}:
 15   0
  0  40

The rest dispatches to either 2-arg *, or to broadcast (when * does) so probably should be fine.

I haven't thought through all the weird special matrices in LinearAlgebra.

Edit -- the above examples (and all special matrices) now take the fallback path, and no longer give unexpected types:

julia> SA[1 2; 3 4] * SA[5,6] * 7
2-element SArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 119
 273

julia> 5 * Diagonal(1:2) * Diagonal(3:4)
2×2 Diagonal{Int64,Array{Int64,1}}:
 15   ⋅
  ⋅  40
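
For context, a simplified sketch of the scalar-fusing fast path under discussion; the PR's real definitions restrict the fast method to strided, BLAS-compatible types and differ in detail:

using LinearAlgebra

mat_vec_scalar(A, x, γ) = A * (x * γ)   # generic fallback: only uses 2-arg *

function mat_vec_scalar(A::StridedMatrix, x::StridedVector, γ::Number)
    T = promote_type(eltype(A), eltype(x), typeof(γ))
    y = similar(x, T, size(A, 1))
    mul!(y, A, x, γ, false)             # y = γ*(A*x) in one fused call, no temporary
end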

@dkarrasch (Member) left a review:

This looks good to me. Pretty verbose, but that's perhaps inherent to the subject matter. One "beautification" suggestion, and then I'd shout out for another review.

@mcabbott (Contributor, Author) commented Oct 7, 2020

While it doesn't fit the title, should we also add some 4-arg cases? There are quite a few which are trivial to route to 2-arg and 3-arg methods, without needing any new functions:

# Four-argument *
*(α::Number, β::Number, γ::Number, D::AbstractArray) = (α*β*γ) * D
*(α::Number, β::Number, C::AbstractMatrix, D::AbstractVecOrMat) = (α*β) * C * D
*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, D::AbstractMatrix) = (α*vt*C) * D
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)

*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = A * B * (C*x)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) = (vt*B) * C * D
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = vt * B * (C*x)

I guess there are 5 possible orders (the Catalan number C₃) for the case of 4 matrices, which it also wouldn't be crazy to just write out:

function _quad_mul(A,B,C,D)
    c1 = _cost((A,B),(C,D))
    c2 = _cost(((A,B),C),D)
    c3 = _cost(A,(B,(C,D)))
    c4 = _cost((A,(B,C)),D)
    c5 = _cost(A,((B,C),D))
    cmin = min(c1,c2,c3,c4,c5)
    if c1 == cmin
        (A*B) * (C*D)
    elseif c2 == cmin
        ((A*B) * C) * D
    elseif c3 == cmin
        A * (B * (C*D))
    elseif c4 == cmin
        (A * (B*C)) * D
    else
        A * ((B*C) * D)
    end
end
@inline _cost(A::AbstractMatrix) = 0
@inline _cost((A,B)::Tuple) = _cost(A,B)
@inline _cost(A,B) = _cost(A) + _cost(B) + *(_sizes(A)..., _sizes(B)[end])
@inline _sizes(A::AbstractMatrix) = size(A)
@inline _sizes((A,B)::Tuple) = _sizes(A)[begin], _sizes(B)[end]

using Random, Test, BenchmarkTools
s1,s2,s3,s4,s5 = shuffle([5,10,20,100,200])

a=rand(s1,s2); b=rand(s2,s3); c=rand(s3,s4); d=rand(s4,s5);
@test *(a,b,c,d) ≈ _quad_mul(a,b,c,d)
@btime *($a,$b,$c,$d);
@btime _quad_mul($a,$b,$c,$d);

s1,s2,s3,s4,s5 = fill(30,5) # 0.2% overhead at size 30
s1,s2,s3,s4,s5 = fill(3,5)  # 4% overhead at size 3 (7ns)

using StaticArrays
s1,s2,s3,s4,s5 = shuffle([2,3,5,7,11])
a=@SMatrix rand(s1,s2); b=@SMatrix rand(s2,s3); c=@SMatrix rand(s3,s4); d=@SMatrix rand(s4,s5);
@test *(a,b,c,d) ≈ _quad_mul(a,b,c,d)
@btime *($(Ref(a))[],$(Ref(b))[],$(Ref(c))[],$(Ref(d))[]);
@btime _quad_mul($(Ref(a))[],$(Ref(b))[],$(Ref(c))[],$(Ref(d))[]);

s1,s2,s3,s4,s5 = fill(3,5) # 28% overhead for 3x3 SMatrix (6ns)

It should be easy to make StaticArrays avoid this overhead: mcabbott/StaticArrays.jl@a1aa074

@mcabbott (Contributor, Author) commented Oct 16, 2020

Using 3-arg dot isn't always faster. Would it be too strange to check the size before deciding what method to call? dot is always zero-allocation, but right now x'*A*y isn't, so calling dot only when it's sure to be faster seems like an upgrade.

julia> N = 10;

julia> @btime x'*A*y  setup=(x=rand($N); A=rand($N,$N); y=rand($N));
  138.139 ns (1 allocation: 160 bytes)

julia> @btime dot(x,A,y)  setup=(x=rand($N); A=rand($N,$N); y=rand($N));
  83.566 ns (0 allocations: 0 bytes)

julia> N = 100;

julia> @btime x'*A*y  setup=(x=rand($N); A=rand($N,$N); y=rand($N));
  1.151 μs (1 allocation: 896 bytes)

julia> @btime dot(x,A,y)  setup=(x=rand($N); A=rand($N,$N); y=rand($N));
  1.526 μs (0 allocations: 0 bytes)

julia> N = 10_000;

julia> @btime x'*A*y  setup=(x=rand($N); A=rand($N,$N); y=rand($N));
  33.634 ms (2 allocations: 78.20 KiB)

julia> @btime dot(x,A,y)  setup=(x=rand($N); A=rand($N,$N); y=rand($N));
  62.127 ms (0 allocations: 0 bytes)

julia> BLAS.vendor()
:mkl

Edit -- maybe it's best just never to call dot. It's a little faster only in a few small cases. I worry that this will be a source of surprising bugs, if some package's tests don't cover all sizes.
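
A size-gated dispatch along those lines would have looked roughly like this. The helper name and cutoff are hypothetical, and the PR ultimately settled on never calling dot from *:

using LinearAlgebra

function _dot_or_mul(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
    if length(x) <= 16 && length(y) <= 16  # illustrative cutoff, would need tuning
        dot(x, A, y)       # zero allocations, faster for small sizes
    else
        (x' * A) * y       # BLAS gemv path, faster for large sizes
    end
end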

@mcabbott (Contributor, Author) commented:

Bump?

Am happy to remove 4-arg cases, if they are thought too exotic / too complicated for Base LinearAlgebra to support.

@dkarrasch (Member) left a review:

This LGTM, but I think we should have somebody else take a look.

@dkarrasch (Member) left a review:

Just a few minor comments, and then we should really push this over the finish line. This is great stuff.

@dkarrasch added the backport 1.6 label (change should be backported to release-1.6) on Dec 9, 2020
@KristofferC (Sponsor Member) commented:

Why is this marked with the backport label, @dkarrasch?

@dkarrasch (Member) commented:

I thought we had a "soft feature freeze", and that we would include/backport PRs that had been developed for a while in the v1.6 cycle. This one, specifically, just got a bit forgotten, and my last comments are rather cosmetic. Please feel free to remove the backport label if I misunderstood the intention of the soft feature freeze.

@KristofferC removed the backport 1.6 label on Dec 14, 2020
@KristofferC (Sponsor Member) commented:

> I thought we had a "soft feature freeze", and that we would include/backport PRs that had been developed for a while in the v1.6 cycle.

True, but now it is getting a bit late I think. I'll remove the label for now.

@mcabbott (Contributor, Author) commented:

This just missed 1.6, but perhaps it should be in 1.7?

Looking over it again briefly, I'm slightly dismayed that it needs 100 lines of code for what seems a pretty simple optimisation; half of them are for the 4-arg case. However, there are very few existing methods to bump into (filter(m -> m.nargs == 4, methods(*)) turns up a few dot cases and some BigInt & BigFloat methods), so we can hope that the complexity will seldom trip anyone up.
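
Spelled out in the REPL (note that Method.nargs counts the function itself as an argument, so nargs == 4 picks out the 3-argument methods; the exact hits vary by Julia version):

julia> using LinearAlgebra

julia> filter(m -> m.nargs == 4, methods(*))  # lists the existing 3-arg * methods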

@dkarrasch (Member) left a review:

I have two minor suggestions. Could you also add a comment as to what is optimized? I think it's the number of operations, and not memory allocation for intermediate results, right?

@mcabbott (Contributor, Author) commented Jun 7, 2021

Good point, I have re-written the docstring to be more explicit, see if you like it.

Review thread on stdlib/LinearAlgebra/src/matmul.jl, quoting the docstring:

If the last factor is a vector, or the first a transposed vector, then it is efficient
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
for an ordinary colum-major `B::Matrix`. This is often equivalent to `dot(x, B, y)`,
(Member) commented: colum -> column

(Member) commented: Hm, why not be explicit: "For scalar eltypes, this is equivalent to dot(x, B, y)", or something in that direction.

@mcabbott (Contributor, Author) replied: Maybe "Unlike dot(), ..."? I don't want to dwell on pinning down exactly what types they agree or disagree on.

@mcabbott (Contributor, Author) commented Jun 7, 2021:

As an aside, trying to puzzle out what recursive dot means in the 3-arg case, is this the intended behaviour?

julia> x = [rand(2,2) for _ in 1:3]; A = [rand(2,2) for _ in 1:3, _ in 1:3]; y = [rand(2,2) for _ in 1:3];

julia> dot(x,A,y)
7.411848453027886

julia> @tullio _ := x[i][b,a] * A[i,j][b,c] * y[j][c,a]
7.4118484530278845

julia> @tullio _ := tr(x[i]' * A[i,j] * y[j])
7.411848453027887

The action on the component matrices looks like a trace of a matrix product.

@dkarrasch (Member) commented:

This is ready to go, I think.

@dkarrasch added the status:merge me label (PR is reviewed; merge when all tests are passing) on Jun 7, 2021
@vtjnash merged commit 51f5740 into JuliaLang:master on Jun 7, 2021
@mcabbott deleted the star branch on June 7, 2021, 22:56
shirodkara pushed a commit to shirodkara/julia that referenced this pull request Jun 9, 2021
This addresses the simplest part of JuliaLang#12065 (optimizing * for optimal matrix order), by adding some methods for * with 3 arguments, where this can be done more efficiently than working left-to-right.

Co-authored-by: Daniel Karrasch <[email protected]>
@DilumAluthge removed the status:merge me label on Jun 18, 2021
johanmon pushed a commit to johanmon/julia that referenced this pull request Jul 5, 2021
This addresses the simplest part of JuliaLang#12065 (optimizing * for optimal matrix order), by adding some methods for * with 3 arguments, where this can be done more efficiently than working left-to-right.

Co-authored-by: Daniel Karrasch <[email protected]>
@marius311 (Contributor) commented Jul 11, 2021

The fallback introduced here:

mat_mat_scalar(A, B, γ) = (A*B) .* γ # fallback

breaks some of my custom array code (and maybe others') and I'm wondering if it could be changed to just (A*B) * γ? (and maybe in time for 1.7?)

For Base Arrays these should be identical in terms of performance, since it will just call broadcasting one method deeper. But for custom arrays which had only defined *(::CustomArray, ::Number) and *(::Number, ::CustomArray) and made no attempt to implement the more complex broadcasting API (e.g., that's what my custom arrays did), this change breaks that code, since it skips straight to broadcasting.
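
A minimal stand-in for such a type, to illustrate the failure mode (MyMat is made up for this example):

struct MyMat{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(M::MyMat) = size(M.data)
Base.getindex(M::MyMat, i::Int, j::Int) = M.data[i, j]
Base.:*(A::MyMat, B::MyMat) = MyMat(A.data * B.data)
Base.:*(M::MyMat, γ::Number) = MyMat(M.data * γ)
Base.:*(γ::Number, M::MyMat) = MyMat(γ * M.data)

A = MyMat(rand(2, 2)); B = MyMat(rand(2, 2));
(A * B) * 2   # stays a MyMat, via the * overloads above
(A * B) .* 2  # plain 2×2 Matrix: default broadcasting bypasses the overloads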

@dkarrasch (Member) commented:

For such a generic fallback, I think it is very reasonable to make it (A * B) * γ, indeed.

@mcabbott (Contributor, Author) commented:

Sorry about breaking your code! I agree there's no good reason for .* there, I just didn't think to worry about the distinction. Have made a PR removing these.

@marius311 (Contributor) commented:

Awesome, thanks for the quick reply! And no worries, it does also improve some other things, as intended!

vtjnash pushed a commit that referenced this pull request Feb 12, 2024
PR #37898 added methods to `*` for chained matrix multiplication. They
have a descriptive docstring but I don't think this is mentioned in the
manual.