Skip to content

Commit

Permalink
Adding inplace multiplication for (unit)triangular matrices (JuliaLan…
Browse files Browse the repository at this point in the history
…g#36972)

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
mcognetta and dkarrasch committed Nov 17, 2020
1 parent f0046a0 commit ef1b6d3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
23 changes: 23 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,29 @@ mul!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = lmul!(A,
mul!(C::AbstractVector, A::AbstractTriangular{<:Any,<:Adjoint}, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))
mul!(C::AbstractVector, A::AbstractTriangular{<:Any,<:Transpose}, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))

# preserve triangular structure in in-place multiplication
for (cty, aty, bty) in ((:UpperTriangular, :UpperTriangular, :UpperTriangular),
(:UpperTriangular, :UpperTriangular, :UnitUpperTriangular),
(:UpperTriangular, :UnitUpperTriangular, :UpperTriangular),
(:UnitUpperTriangular, :UnitUpperTriangular, :UnitUpperTriangular),
(:LowerTriangular, :LowerTriangular, :LowerTriangular),
(:LowerTriangular, :LowerTriangular, :UnitLowerTriangular),
(:LowerTriangular, :UnitLowerTriangular, :LowerTriangular),
(:UnitLowerTriangular, :UnitLowerTriangular, :UnitLowerTriangular))
@eval function mul!(C::$cty, A::$aty, B::$bty)
lmul!(A, copyto!(parent(C), B))
return C
end

@eval @inline function mul!(C::$cty, A::$aty, B::$bty, alpha::Number, beta::Number)
if isone(alpha) && iszero(beta)
return mul!(C, A, B)
else
return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
end
end
end

# direct multiplication/division
for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
(:UnitLowerTriangular, 'L', 'U'),
Expand Down
41 changes: 41 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,47 @@ end
end
end

@testset "inplace mul of appropriate types should preserve triagular structure" begin
for elty1 in (Float64, ComplexF32), elty2 in (Float64, ComplexF32)
T = promote_type(elty1, elty2)
M1 = rand(elty1, 5, 5)
M2 = rand(elty2, 5, 5)
A = UpperTriangular(M1)
A2 = UpperTriangular(M2)
Au = UnitUpperTriangular(M1)
Au2 = UnitUpperTriangular(M2)
B = LowerTriangular(M1)
B2 = LowerTriangular(M2)
Bu = UnitLowerTriangular(M1)
Bu2 = UnitLowerTriangular(M2)

@test mul!(similar(A), A, A)::typeof(A) == A*A
@test mul!(similar(A, T), A, A2) A*A2
@test mul!(similar(A, T), A2, A) A2*A
@test mul!(typeof(similar(A, T))(A), A, A2, 2.0, 3.0) 2.0*A*A2 + 3.0*A
@test mul!(typeof(similar(A2, T))(A2), A2, A, 2.0, 3.0) 2.0*A2*A + 3.0*A2

@test mul!(similar(A), A, Au)::typeof(A) == A*Au
@test mul!(similar(A), Au, A)::typeof(A) == Au*A
@test mul!(similar(Au), Au, Au)::typeof(Au) == Au*Au
@test mul!(similar(A, T), A, Au2) A*Au2
@test mul!(similar(A, T), Au2, A) Au2*A
@test mul!(similar(Au2), Au2, Au2) == Au2*Au2

@test mul!(similar(B), B, B)::typeof(B) == B*B
@test mul!(similar(B, T), B, B2) B*B2
@test mul!(similar(B, T), B2, B) B2*B
@test mul!(typeof(similar(B, T))(B), B, B2, 2.0, 3.0) 2.0*B*B2 + 3.0*B
@test mul!(typeof(similar(B2, T))(B2), B2, B, 2.0, 3.0) 2.0*B2*B + 3.0*B2

@test mul!(similar(B), B, Bu)::typeof(B) == B*Bu
@test mul!(similar(B), Bu, B)::typeof(B) == Bu*B
@test mul!(similar(Bu), Bu, Bu)::typeof(Bu) == Bu*Bu
@test mul!(similar(B, T), B, Bu2) B*Bu2
@test mul!(similar(B, T), Bu2, B) Bu2*B
end
end

@testset "special printing of Lower/UpperTriangular" begin
@test occursin(r"3×3 (LinearAlgebra\.)?LowerTriangular{Int64, Matrix{Int64}}:\n 2 ⋅ ⋅\n 2 2 ⋅\n 2 2 2",
sprint(show, MIME"text/plain"(), LowerTriangular(2ones(Int64,3,3))))
Expand Down

0 comments on commit ef1b6d3

Please sign in to comment.