Skip to content

Commit

Permalink
Disambiguate structured and abstract matrix multiplication (JuliaLang/julia#52464)
Browse files Browse the repository at this point in the history

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
jishnub and dkarrasch committed Jan 6, 2024
1 parent 0cb5a0e commit ea085ea
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 63 deletions.
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,17 @@ _makevector(x::AbstractVector) = Vector(x)
_pushzero(A) = (B = similar(A, length(A)+1); @inbounds B[begin:end-1] .= A; @inbounds B[end] = zero(eltype(B)); B)
_droplast!(A) = deleteat!(A, lastindex(A))

# destination type for matmul
# `matprod_dest(A, B, TS)` builds, via `similar`, the destination array that the
# in-place kernels (`mul!`, `ldiv!`, `_rdiv!`) write into for the operands `A` and `B`,
# using `TS` as the element type. Each method picks which operand the destination is
# modeled on, and whether an explicit size is forwarded to `similar`.
# Both operands structured: model on `B`, forwarding `size(B)`.
matprod_dest(A::StructuredMatrix, B::StructuredMatrix, TS) = similar(B, TS, size(B))
# Only `B` structured: model on `A`, forwarding `size(A)`.
matprod_dest(A, B::StructuredMatrix, TS) = similar(A, TS, size(A))
# Only `A` structured: model on `B`, forwarding `size(B)`.
matprod_dest(A::StructuredMatrix, B, TS) = similar(B, TS, size(B))
# One operand is a `Diagonal`: model on the other (structured) operand. No size is
# forwarded here — NOTE(review): presumably so that `similar` can keep the structured
# type of that operand; confirm against the relevant `similar` methods.
matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = similar(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = similar(B, TS)
# Both operands `Diagonal`: model on `B`, again without a size argument.
matprod_dest(A::Diagonal, B::Diagonal, TS) = similar(B, TS)
# `HermOrSym` paired with a `Diagonal`: model on the `HermOrSym` operand, but forward its
# size explicitly — NOTE(review): presumably to avoid getting a Hermitian/Symmetric
# wrapper back as the destination; confirm.
matprod_dest(A::HermOrSym, B::Diagonal, TS) = similar(A, TS, size(A))
matprod_dest(A::Diagonal, B::HermOrSym, TS) = similar(B, TS, size(B))

# TODO: remove once not used anymore in SparseArrays.jl
# some trait like this would be cool
# onedefined(::Type{T}) where {T} = hasmethod(one, (T,))
# but we are actually asking for oneunit(T), that is, however, defined for generic T as
Expand Down
26 changes: 14 additions & 12 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -802,34 +802,35 @@ ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMa
(t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c)

### Generic promotion methods and fallbacks
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
\(A::Bidiagonal, B::AbstractVecOrMat) =
ldiv!(matprod_dest(A, B, promote_op(\, eltype(A), eltype(B))), A, B)
\(xA::AdjOrTrans{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(xA) \ B

### Triangular specializations
for tri in (:UpperTriangular, :UnitUpperTriangular)
@eval function \(B::Bidiagonal, U::$tri)
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
A = ldiv!(matprod_dest(B, U, promote_op(\, eltype(B), eltype(U))), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
@eval function \(U::$tri, B::Bidiagonal)
A = ldiv!(_initarray(\, eltype(U), eltype(B), U), U, B)
A = ldiv!(matprod_dest(U, B, promote_op(\, eltype(U), eltype(B))), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
end
for tri in (:LowerTriangular, :UnitLowerTriangular)
@eval function \(B::Bidiagonal, L::$tri)
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
A = ldiv!(matprod_dest(B, L, promote_op(\, eltype(B), eltype(L))), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
@eval function \(L::$tri, B::Bidiagonal)
A = ldiv!(_initarray(\, eltype(L), eltype(B), L), L, B)
A = ldiv!(matprod_dest(L, B, promote_op(\, eltype(L), eltype(B))), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
end

### Diagonal specialization
function \(B::Bidiagonal, D::Diagonal)
A = ldiv!(_initarray(\, eltype(B), eltype(D), D), B, D)
A = ldiv!(similar(D, promote_op(\, eltype(B), eltype(D)), size(D)), B, D)
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
end

Expand Down Expand Up @@ -879,33 +880,34 @@ rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A,
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) =
(t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C)

/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)
/(A::AbstractMatrix, B::Bidiagonal) =
_rdiv!(similar(A, promote_op(/, eltype(A), eltype(B)), size(A)), A, B)

### Triangular specializations
for tri in (:UpperTriangular, :UnitUpperTriangular)
@eval function /(U::$tri, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
A = _rdiv!(matprod_dest(U, B, promote_op(/, eltype(U), eltype(B))), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
@eval function /(B::Bidiagonal, U::$tri)
A = _rdiv!(_initarray(/, eltype(B), eltype(U), U), B, U)
A = _rdiv!(matprod_dest(B, U, promote_op(/, eltype(B), eltype(U))), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
end
for tri in (:LowerTriangular, :UnitLowerTriangular)
@eval function /(L::$tri, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
A = _rdiv!(matprod_dest(L, B, promote_op(/, eltype(L), eltype(B))), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
@eval function /(B::Bidiagonal, L::$tri)
A = _rdiv!(_initarray(/, eltype(B), eltype(L), L), B, L)
A = _rdiv!(matprod_dest(B, L, promote_op(/, eltype(B), eltype(L))), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
end

### Diagonal specialization
function /(D::Diagonal, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(D), eltype(B), D), D, B)
A = _rdiv!(similar(D, promote_op(/, eltype(D), eltype(B)), size(D)), D, B)
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
end

Expand Down
26 changes: 7 additions & 19 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,6 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D)
(*)(A::HermOrSym, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(D.diag), eltype(A))), D, A)
(*)(D::Diagonal, A::HermOrSym) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)

Expand Down Expand Up @@ -431,8 +422,8 @@ function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
end

/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D))), A, D)
/(A::HermOrSym, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D)), size(A)), A, D)
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D)
/(A::HermOrSym, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D)

rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D)
# avoid copy when possible via internal 3-arg backend
Expand All @@ -458,8 +449,8 @@ function \(D::Diagonal, B::AbstractVector)
isnothing(j) || throw(SingularException(j))
return D.diag .\ B
end
\(D::Diagonal, B::AbstractMatrix) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B))), D, B)
\(D::Diagonal, B::HermOrSym) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B)), size(B)), D, B)
\(D::Diagonal, B::AbstractMatrix) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B)
\(D::Diagonal, B::HermOrSym) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B)

ldiv!(D::Diagonal, B::AbstractVecOrMat) = @inline ldiv!(B, D, B)
function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
Expand All @@ -479,8 +470,8 @@ function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
end

# Optimizations for \, / between Diagonals
\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B)
/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D)
\(D::Diagonal, B::Diagonal) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B)
/(A::Diagonal, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D)
function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
n, k = length(Db.diag), length(Da.diag)
n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
Expand Down Expand Up @@ -543,7 +534,7 @@ function (/)(S::SymTridiagonal, D::Diagonal)
dl = similar(S.ev, T, max(length(S.dv)-1, 0))
_rdiv!(Tridiagonal(dl, d, du), S, D)
end
(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(similar(T, promote_op(/, eltype(T), eltype(D))), T, D)
(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(matprod_dest(T, D, promote_op(/, eltype(T), eltype(D))), T, D)
function _rdiv!(T::Tridiagonal, S::Union{SymTridiagonal,Tridiagonal}, D::Diagonal)
n = size(S, 2)
dd = D.diag
Expand Down Expand Up @@ -876,9 +867,6 @@ function svd(D::Diagonal{T}) where {T<:Number}
return SVD(U, S, Vt)
end

# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec
*(u::AdjointAbsVec, D::Diagonal) = (D'u')'
*(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) * transpose(u))
*(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
/(u::AdjointAbsVec, D::Diagonal) = (D' \ u')'
Expand Down
12 changes: 6 additions & 6 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,29 @@ for T = (:Number, :UniformScaling, :Diagonal)
end

function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U)
HH = mul!(matprod_dest(H, U, promote_op(matprod, eltype(H), eltype(U))), H, U)
UpperHessenberg(HH)
end
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H)
HH = mul!(matprod_dest(U, H, promote_op(matprod, eltype(U), eltype(H))), U, H)
UpperHessenberg(HH)
end

function /(H::UpperHessenberg, U::UpperTriangular)
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U)
UpperHessenberg(HH)
end
function /(H::UpperHessenberg, U::UnitUpperTriangular)
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U)
UpperHessenberg(HH)
end

function \(U::UpperTriangular, H::UpperHessenberg)
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
HH = ldiv!(matprod_dest(U, H, promote_op(\, eltype(U), eltype(H))), U, H)
UpperHessenberg(HH)
end
function \(U::UnitUpperTriangular, H::UpperHessenberg)
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
HH = ldiv!(matprod_dest(U, H, promote_op(\, eltype(U), eltype(H))), U, H)
UpperHessenberg(HH)
end

Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ julia> [1 1; 0 1] * [1 0; 1 1]
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B)
mul!(matprod_dest(A, B, TS), A, B)
end

# Untyped fallback for the matmul destination: an array modeled on `B` with element
# type `TS`, `size(A, 1)` rows and `size(B, 2)` columns. Used by the generic
# `*(::AbstractMatrix, ::AbstractMatrix)` method above.
matprod_dest(A, B, TS) = similar(B, TS, (size(A, 1), size(B, 2)))

# optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
# but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
# which is better handled by reinterpreting rather than promotion
Expand Down
29 changes: 4 additions & 25 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1545,22 +1545,11 @@ rmul!(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(rmul!(tril!(
## necessary in the general triangular solve problem.

_inner_type_promotion(op, ::Type{TA}, ::Type{TB}) where {TA<:Integer,TB<:Integer} =
_init_eltype(*, TA, TB)
promote_op(matprod, TA, TB)
_inner_type_promotion(op, ::Type{TA}, ::Type{TB}) where {TA,TB} =
_init_eltype(op, TA, TB)
promote_op(op, TA, TB)
## The general promotion methods
function *(A::AbstractTriangular, B::AbstractTriangular)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(B, TAB, size(B)), A, B)
end

for mat in (:AbstractVector, :AbstractMatrix)
### Multiplication with triangle to the left and hence rhs cannot be transposed.
@eval function *(A::AbstractTriangular, B::$mat)
require_one_based_indexing(B)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(B, TAB, size(B)), A, B)
end
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
@eval function \(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
require_one_based_indexing(B)
Expand All @@ -1570,7 +1559,7 @@ for mat in (:AbstractVector, :AbstractMatrix)
### Left division with triangle to the left hence rhs cannot be transposed. Quotients.
@eval function \(A::Union{UpperTriangular,LowerTriangular}, B::$mat)
require_one_based_indexing(B)
TAB = _init_eltype(\, eltype(A), eltype(B))
TAB = promote_op(\, eltype(A), eltype(B))
ldiv!(similar(B, TAB, size(B)), A, B)
end
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
Expand All @@ -1582,20 +1571,10 @@ for mat in (:AbstractVector, :AbstractMatrix)
### Right division with triangle to the right hence lhs cannot be transposed. Quotients.
@eval function /(A::$mat, B::Union{UpperTriangular,LowerTriangular})
require_one_based_indexing(A)
TAB = _init_eltype(/, eltype(A), eltype(B))
TAB = promote_op(/, eltype(A), eltype(B))
_rdiv!(similar(A, TAB, size(A)), A, B)
end
end
### Multiplication with triangle to the right and hence lhs cannot be transposed.
# Only for AbstractMatrix, hence outside the above loop.
function *(A::AbstractMatrix, B::AbstractTriangular)
require_one_based_indexing(A)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(A, TAB, size(A)), A, B)
end
# ambiguity resolution with definitions in matmul.jl
*(v::AdjointAbsVec, A::AbstractTriangular) = adjoint(adjoint(A) * v.parent)
*(v::TransposeAbsVec, A::AbstractTriangular) = transpose(transpose(A) * v.parent)

## Some Triangular-Triangular cases. We might want to write tailored methods
## for these cases, but I'm not sure it is worth it.
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,14 @@ end
end
end

# Checks that multiplying the `SizedArray` test-helper type with a `Diagonal`, in either
# order, resolves to a method and gives the same result as using the plain array `A`.
@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
# 2x2 integer matrix with entries i+j
A = [i+j for i in 1:2, j in 1:2]
# the same data wrapped in the test-helper array type
S = SizedArrays.SizedArray{(2,2)}(A)
D = Diagonal([1:2;])
# wrapped operand on the left, then on the right
@test S * D == A * D
@test D * S == D * A
end

@testset "copy" begin
@test copy(Diagonal(1:5)) === Diagonal(1:5)
end
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,14 @@ end
end
end

# Checks that multiplying the `SizedArray` test-helper type with an `UpperTriangular`,
# in either order, resolves to a method and gives the same result as using the plain
# array `A`.
@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
# 2x2 integer matrix with entries i+j
A = [i+j for i in 1:2, j in 1:2]
# the same data wrapped in the test-helper array type
S = SizedArrays.SizedArray{(2,2)}(A)
U = UpperTriangular(ones(2,2))
# wrapped operand on the left, then on the right
@test S * U == A * U
@test U * S == U * A
end

@testset "custom axes" begin
SZA = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
Expand Down
5 changes: 5 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,9 @@ function *(S1::SizedArrayLike, S2::SizedArrayLike)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end

# deliberately wide method definition to ensure that this doesn't lead to ambiguities with
# structured matrices
# Forwards to `*` on `_data(S1)` and `M` — NOTE(review): `_data` presumably returns the
# wrapped parent array; its definition is outside this hunk, confirm there.
*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M

end

0 comments on commit ea085ea

Please sign in to comment.