diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 10cc9a2f3459a..a80c8d0600a78 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -523,6 +523,17 @@ _makevector(x::AbstractVector) = Vector(x) _pushzero(A) = (B = similar(A, length(A)+1); @inbounds B[begin:end-1] .= A; @inbounds B[end] = zero(eltype(B)); B) _droplast!(A) = deleteat!(A, lastindex(A)) +# destination type for matmul +matprod_dest(A::StructuredMatrix, B::StructuredMatrix, TS) = similar(B, TS, size(B)) +matprod_dest(A, B::StructuredMatrix, TS) = similar(A, TS, size(A)) +matprod_dest(A::StructuredMatrix, B, TS) = similar(B, TS, size(B)) +matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = similar(A, TS) +matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = similar(B, TS) +matprod_dest(A::Diagonal, B::Diagonal, TS) = similar(B, TS) +matprod_dest(A::HermOrSym, B::Diagonal, TS) = similar(A, TS, size(A)) +matprod_dest(A::Diagonal, B::HermOrSym, TS) = similar(B, TS, size(B)) + +# TODO: remove once not used anymore in SparseArrays.jl # some trait like this would be cool # onedefined(::Type{T}) where {T} = hasmethod(one, (T,)) # but we are actually asking for oneunit(T), that is, however, defined for generic T as diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 88cc145835df9..2b4b2ac98317c 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -802,34 +802,35 @@ ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMa (t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c) ### Generic promotion methods and fallbacks -\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B) +\(A::Bidiagonal, B::AbstractVecOrMat) = + ldiv!(matprod_dest(A, B, promote_op(\, eltype(A), eltype(B))), A, B) \(xA::AdjOrTrans{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(xA) \ B ### Triangular specializations for tri in (:UpperTriangular, :UnitUpperTriangular) @eval function \(B::Bidiagonal, U::$tri) - A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U) + A = ldiv!(matprod_dest(B, U, promote_op(\, eltype(B), eltype(U))), B, U) return B.uplo == 'U' ? UpperTriangular(A) : A end @eval function \(U::$tri, B::Bidiagonal) - A = ldiv!(_initarray(\, eltype(U), eltype(B), U), U, B) + A = ldiv!(matprod_dest(U, B, promote_op(\, eltype(U), eltype(B))), U, B) return B.uplo == 'U' ? UpperTriangular(A) : A end end for tri in (:LowerTriangular, :UnitLowerTriangular) @eval function \(B::Bidiagonal, L::$tri) - A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L) + A = ldiv!(matprod_dest(B, L, promote_op(\, eltype(B), eltype(L))), B, L) return B.uplo == 'L' ? LowerTriangular(A) : A end @eval function \(L::$tri, B::Bidiagonal) - A = ldiv!(_initarray(\, eltype(L), eltype(B), L), L, B) + A = ldiv!(matprod_dest(L, B, promote_op(\, eltype(L), eltype(B))), L, B) return B.uplo == 'L' ? LowerTriangular(A) : A end end ### Diagonal specialization function \(B::Bidiagonal, D::Diagonal) - A = ldiv!(_initarray(\, eltype(B), eltype(D), D), B, D) + A = ldiv!(similar(D, promote_op(\, eltype(B), eltype(D)), size(D)), B, D) return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A) end @@ -879,33 +880,34 @@ rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = (t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C) -/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B) +/(A::AbstractMatrix, B::Bidiagonal) = + _rdiv!(similar(A, promote_op(/, eltype(A), eltype(B)), size(A)), A, B) ### Triangular specializations for tri in (:UpperTriangular, :UnitUpperTriangular) @eval function /(U::$tri, B::Bidiagonal) - A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B) + A = _rdiv!(matprod_dest(U, B, promote_op(/, eltype(U), eltype(B))), U, B) return B.uplo == 'U' ? UpperTriangular(A) : A end @eval function /(B::Bidiagonal, U::$tri) - A = _rdiv!(_initarray(/, eltype(B), eltype(U), U), B, U) + A = _rdiv!(matprod_dest(B, U, promote_op(/, eltype(B), eltype(U))), B, U) return B.uplo == 'U' ? UpperTriangular(A) : A end end for tri in (:LowerTriangular, :UnitLowerTriangular) @eval function /(L::$tri, B::Bidiagonal) - A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B) + A = _rdiv!(matprod_dest(L, B, promote_op(/, eltype(L), eltype(B))), L, B) return B.uplo == 'L' ? LowerTriangular(A) : A end @eval function /(B::Bidiagonal, L::$tri) - A = _rdiv!(_initarray(/, eltype(B), eltype(L), L), B, L) + A = _rdiv!(matprod_dest(B, L, promote_op(/, eltype(B), eltype(L))), B, L) return B.uplo == 'L' ? LowerTriangular(A) : A end end ### Diagonal specialization function /(D::Diagonal, B::Bidiagonal) - A = _rdiv!(_initarray(/, eltype(D), eltype(B), D), D, B) + A = _rdiv!(similar(D, promote_op(/, eltype(D), eltype(B)), size(D)), D, B) return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A) end diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index e8638d3d7ff12..063dde619bd1a 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -310,15 +310,6 @@ function (*)(D::Diagonal, V::AbstractVector) return D.diag .* V end -(*)(A::AbstractMatrix, D::Diagonal) = - mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D) -(*)(A::HermOrSym, D::Diagonal) = - mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D) -(*)(D::Diagonal, A::AbstractMatrix) = - mul!(similar(A, promote_op(*, eltype(D.diag), eltype(A))), D, A) -(*)(D::Diagonal, A::HermOrSym) = - mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A) - rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) @@ -431,8 +422,8 @@ function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal) return Diagonal(Da.diag .* Db.diag .* Dc.diag) end -/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D))), A, D) -/(A::HermOrSym, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D)), size(A)), A, D) +/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D) +/(A::HermOrSym, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D) rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D) # avoid copy when possible via internal 3-arg backend @@ -458,8 +449,8 @@ function \(D::Diagonal, B::AbstractVector) isnothing(j) || throw(SingularException(j)) return D.diag .\ B end -\(D::Diagonal, B::AbstractMatrix) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B))), D, B) -\(D::Diagonal, B::HermOrSym) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B)), size(B)), D, B) +\(D::Diagonal, B::AbstractMatrix) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B) +\(D::Diagonal, B::HermOrSym) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B) ldiv!(D::Diagonal, B::AbstractVecOrMat) = @inline ldiv!(B, D, B) function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat) @@ -479,8 +470,8 @@ function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat) end # Optimizations for \, / between Diagonals -\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B) -/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D) +\(D::Diagonal, B::Diagonal) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B) +/(A::Diagonal, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D) function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal) n, k = length(Db.diag), length(Da.diag) n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k")) @@ -543,7 +534,7 @@ function (/)(S::SymTridiagonal, D::Diagonal) dl = similar(S.ev, T, max(length(S.dv)-1, 0)) _rdiv!(Tridiagonal(dl, d, du), S, D) end -(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(similar(T, promote_op(/, eltype(T), eltype(D))), T, D) +(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(matprod_dest(T, D, promote_op(/, eltype(T), eltype(D))), T, D) function _rdiv!(T::Tridiagonal, S::Union{SymTridiagonal,Tridiagonal}, D::Diagonal) n = size(S, 2) dd = D.diag @@ -876,9 +867,6 @@ function svd(D::Diagonal{T}) where {T<:Number} return SVD(U, S, Vt) end -# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec -*(u::AdjointAbsVec, D::Diagonal) = (D'u')' -*(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) * transpose(u)) *(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) *(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) /(u::AdjointAbsVec, D::Diagonal) = (D' \ u')' diff --git a/stdlib/LinearAlgebra/src/hessenberg.jl b/stdlib/LinearAlgebra/src/hessenberg.jl index aed8b3c3747a1..3be41baf24b24 100644 --- a/stdlib/LinearAlgebra/src/hessenberg.jl +++ b/stdlib/LinearAlgebra/src/hessenberg.jl @@ -133,29 +133,29 @@ for T = (:Number, :UniformScaling, :Diagonal) end function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular) - HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U) + HH = mul!(matprod_dest(H, U, promote_op(matprod, eltype(H), eltype(U))), H, U) UpperHessenberg(HH) end function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg) - HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H) + HH = mul!(matprod_dest(U, H, promote_op(matprod, eltype(U), eltype(H))), U, H) UpperHessenberg(HH) end function /(H::UpperHessenberg, U::UpperTriangular) - HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U) + HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U) UpperHessenberg(HH) end function /(H::UpperHessenberg, U::UnitUpperTriangular) - HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U) + HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U) UpperHessenberg(HH) end function \(U::UpperTriangular, H::UpperHessenberg) - HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H) + HH = ldiv!(matprod_dest(U, H, promote_op(\, eltype(U), eltype(H))), U, H) UpperHessenberg(HH) end function \(U::UnitUpperTriangular, H::UpperHessenberg) - HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H) + HH = ldiv!(matprod_dest(U, H, promote_op(\, eltype(U), eltype(H))), U, H) UpperHessenberg(HH) end diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index cc1c954258a88..afc646c149eaf 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -106,8 +106,11 @@ julia> [1 1; 0 1] * [1 0; 1 1] """ function (*)(A::AbstractMatrix, B::AbstractMatrix) TS = promote_op(matprod, eltype(A), eltype(B)) - mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B) + mul!(matprod_dest(A, B, TS), A, B) end + +matprod_dest(A, B, TS) = similar(B, TS, (size(A, 1), size(B, 2))) + # optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64}) # but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal}) # which is better handled by reinterpreting rather than promotion diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 64e1b320f19dd..781e323a0ca43 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -1545,22 +1545,11 @@ rmul!(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(rmul!(tril!( ## necessary in the general triangular solve problem. _inner_type_promotion(op, ::Type{TA}, ::Type{TB}) where {TA<:Integer,TB<:Integer} = - _init_eltype(*, TA, TB) + promote_op(matprod, TA, TB) _inner_type_promotion(op, ::Type{TA}, ::Type{TB}) where {TA,TB} = - _init_eltype(op, TA, TB) + promote_op(op, TA, TB) ## The general promotion methods -function *(A::AbstractTriangular, B::AbstractTriangular) - TAB = _init_eltype(*, eltype(A), eltype(B)) - mul!(similar(B, TAB, size(B)), A, B) -end - for mat in (:AbstractVector, :AbstractMatrix) - ### Multiplication with triangle to the left and hence rhs cannot be transposed. - @eval function *(A::AbstractTriangular, B::$mat) - require_one_based_indexing(B) - TAB = _init_eltype(*, eltype(A), eltype(B)) - mul!(similar(B, TAB, size(B)), A, B) - end ### Left division with triangle to the left hence rhs cannot be transposed. No quotients. @eval function \(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat) require_one_based_indexing(B) @@ -1570,7 +1559,7 @@ for mat in (:AbstractVector, :AbstractMatrix) ### Left division with triangle to the left hence rhs cannot be transposed. Quotients. @eval function \(A::Union{UpperTriangular,LowerTriangular}, B::$mat) require_one_based_indexing(B) - TAB = _init_eltype(\, eltype(A), eltype(B)) + TAB = promote_op(\, eltype(A), eltype(B)) ldiv!(similar(B, TAB, size(B)), A, B) end ### Right division with triangle to the right hence lhs cannot be transposed. No quotients. @@ -1582,20 +1571,10 @@ for mat in (:AbstractVector, :AbstractMatrix) ### Right division with triangle to the right hence lhs cannot be transposed. Quotients. @eval function /(A::$mat, B::Union{UpperTriangular,LowerTriangular}) require_one_based_indexing(A) - TAB = _init_eltype(/, eltype(A), eltype(B)) + TAB = promote_op(/, eltype(A), eltype(B)) _rdiv!(similar(A, TAB, size(A)), A, B) end end -### Multiplication with triangle to the right and hence lhs cannot be transposed. -# Only for AbstractMatrix, hence outside the above loop. -function *(A::AbstractMatrix, B::AbstractTriangular) - require_one_based_indexing(A) - TAB = _init_eltype(*, eltype(A), eltype(B)) - mul!(similar(A, TAB, size(A)), A, B) -end -# ambiguity resolution with definitions in matmul.jl -*(v::AdjointAbsVec, A::AbstractTriangular) = adjoint(adjoint(A) * v.parent) -*(v::TransposeAbsVec, A::AbstractTriangular) = transpose(transpose(A) * v.parent) ## Some Triangular-Triangular cases. We might want to write tailored methods ## for these cases, but I'm not sure it is worth it. diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 9db7d8071931c..dab8f45ef66c7 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1242,6 +1242,14 @@ end end end +@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin + A = [i+j for i in 1:2, j in 1:2] + S = SizedArrays.SizedArray{(2,2)}(A) + D = Diagonal([1:2;]) + @test S * D == A * D + @test D * S == D * A +end + @testset "copy" begin @test copy(Diagonal(1:5)) === Diagonal(1:5) end diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index b60efba1d941a..e245240335434 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -896,6 +896,14 @@ end end end +@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin + A = [i+j for i in 1:2, j in 1:2] + S = SizedArrays.SizedArray{(2,2)}(A) + U = UpperTriangular(ones(2,2)) + @test S * U == A * U + @test U * S == U * A +end + @testset "custom axes" begin SZA = SizedArrays.SizedArray{(2,2)}([1 2; 3 4]) for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index a4810e297e0f9..67f5059b36efd 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -56,4 +56,9 @@ function *(S1::SizedArrayLike, S2::SizedArrayLike) SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2)) SizedArray{SZ}(data) end + +# deliberately wide method definition to ensure that this doesn't lead to ambiguities with +# structured matrices +*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M + end