Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disambiguate structured and abstract matrix multiplication #52464

Merged
merged 9 commits into from
Jan 6, 2024
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,14 @@ _init_eltype(op, ::Type{TA}, ::Type{TB}) where {TA,TB} =
_initarray(op, ::Type{TA}, ::Type{TB}, C) where {TA,TB} =
similar(C, _init_eltype(op, TA, TB), size(C))

# destination type for matmul
matprod_dest(A::StructuredMatrix, B::StructuredMatrix, TS) = similar(B, TS, size(B))
matprod_dest(A, B::StructuredMatrix, TS) = similar(A, TS, size(A))
matprod_dest(A::StructuredMatrix, B, TS) = similar(B, TS, size(B))
matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = similar(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = similar(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = similar(B, TS)

# General fallback definition for handling under- and overdetermined system as well as square problems
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
# which is required by LAPACK but not SuiteSparse which allows real-complex solves in some cases. Hence,
Expand Down
9 changes: 0 additions & 9 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,6 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D)
(*)(A::HermOrSym, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(D.diag), eltype(A))), D, A)
(*)(D::Diagonal, A::HermOrSym) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)

Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ julia> [1 1; 0 1] * [1 0; 1 1]
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B)
mul!(matprod_dest(A, B, TS), A, B)
end

matprod_dest(A, B, TS) = similar(B, TS, (size(A, 1), size(B, 2)))

# optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
# but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
# which is better handled by reinterpreting rather than promotion
Expand Down
13 changes: 0 additions & 13 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1471,12 +1471,6 @@ function *(A::AbstractTriangular, B::AbstractTriangular)
end

for mat in (:AbstractVector, :AbstractMatrix)
### Multiplication with triangle to the left and hence rhs cannot be transposed.
@eval function *(A::AbstractTriangular, B::$mat)
require_one_based_indexing(B)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(B, TAB, size(B)), A, B)
end
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
@eval function \(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
require_one_based_indexing(B)
Expand All @@ -1502,13 +1496,6 @@ for mat in (:AbstractVector, :AbstractMatrix)
_rdiv!(similar(A, TAB, size(A)), A, B)
end
end
### Multiplication with triangle to the right and hence lhs cannot be transposed.
# Only for AbstractMatrix, hence outside the above loop.
function *(A::AbstractMatrix, B::AbstractTriangular)
require_one_based_indexing(A)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(A, TAB, size(A)), A, B)
end
# ambiguity resolution with definitions in matmul.jl
*(v::AdjointAbsVec, A::AbstractTriangular) = adjoint(adjoint(A) * v.parent)
*(v::TransposeAbsVec, A::AbstractTriangular) = transpose(transpose(A) * v.parent)
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1235,4 +1235,12 @@ end
end
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
D = Diagonal([1:2;])
@test S * D == A * D
@test D * S == D * A
end

end # module TestDiagonal
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using LinearAlgebra: BlasFloat, errorbounds, full!, transpose!,
UnitUpperTriangular, UnitLowerTriangular,
mul!, rdiv!, rmul!, lmul!

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays

debug && println("Triangular matrices")

n = 9
Expand Down Expand Up @@ -866,4 +870,12 @@ end
end
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
U = UpperTriangular(ones(2,2))
@test S * U == A * U
@test U * S == U * A
end

end # module TestTriangular
5 changes: 5 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ function *(S1::SizedArrayLike, S2::SizedArrayLike)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end

# deliberately wide method definition to ensure that this doesn't lead to ambiguities with
# structured matrices
*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M

end