Skip to content

Commit

Permalink
Add missing optimization for *,/ between Diagonal and Triangular (J…
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Oct 30, 2021
1 parent 4a12d1e commit 66d05d5
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 99 deletions.
181 changes: 91 additions & 90 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,6 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

(*)(A::AbstractTriangular, D::Diagonal) =
rmul!(copy_oftype(A, promote_op(*, eltype(A), eltype(D.diag))), D)
(*)(D::Diagonal, B::AbstractTriangular) =
lmul!(D, copy_oftype(B, promote_op(*, eltype(B), eltype(D.diag))))

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
Expand All @@ -245,37 +240,7 @@ end
rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B)

rmul!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal) = typeof(A)(rmul!(A.data, D))
function rmul!(A::UnitLowerTriangular, D::Diagonal)
rmul!(A.data, D)
for i = 1:size(A, 1)
A.data[i,i] = D.diag[i]
end
LowerTriangular(A.data)
end
function rmul!(A::UnitUpperTriangular, D::Diagonal)
rmul!(A.data, D)
for i = 1:size(A, 1)
A.data[i,i] = D.diag[i]
end
UpperTriangular(A.data)
end

function lmul!(D::Diagonal, B::UnitLowerTriangular)
lmul!(D, B.data)
for i = 1:size(B, 1)
B.data[i,i] = D.diag[i]
end
LowerTriangular(B.data)
end
function lmul!(D::Diagonal, B::UnitUpperTriangular)
lmul!(D, B.data)
for i = 1:size(B, 1)
B.data[i,i] = D.diag[i]
end
UpperTriangular(B.data)
end

#TODO: It seems better to call (D' * adjA')' directly?
function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, D::Diagonal)
A = adjA.parent
Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1)))
Expand Down Expand Up @@ -382,42 +347,97 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
return C
end

(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)

ldiv!(x::AbstractArray, A::Diagonal, b::AbstractArray) = (x .= A.diag .\ b)
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D)), size(A)), A, D)

function ldiv!(D::Diagonal, A::Union{LowerTriangular,UpperTriangular})
broadcast!(\, parent(A), D.diag, parent(A))
A
end

function rdiv!(A::AbstractMatrix, D::Diagonal)
rdiv!(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(A, A, D)
# avoid copy when possible via internal 3-arg backend
function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
require_one_based_indexing(A)
dd = D.diag
m, n = size(A)
m, n = size(A, 1), size(A, 2)
if (k = length(dd)) n
throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
end
@inbounds for j in 1:n
ddj = dd[j]
if iszero(ddj)
throw(SingularException(j))
end
iszero(ddj) && throw(SingularException(j))
for i in 1:m
A[i, j] /= ddj
B[i, j] = A[i, j] / ddj
end
end
A
end

function rdiv!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal)
broadcast!(/, parent(A), parent(A), permutedims(D.diag))
A
B
end

\(D::Diagonal, B::AbstractVecOrMat) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B)), size(B)), D, B)

ldiv!(D::Diagonal, B::AbstractVecOrMat) = ldiv!(B, D, B)
# 3-arg in-place left division by a Diagonal: writes `D \ A` into `B` and returns `B`.
# `B` and `A` may be the same array (the 2-arg `ldiv!(D, B)` forwards here that way).
# Throws `DimensionMismatch` when the row count of `A` differs from the order of `D`
# or when `B` does not have the same size as `A`, and `SingularException(j)` where `j`
# is the index of the first zero entry on the diagonal of `D`.
function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
    require_one_based_indexing(A, B)
    dd = D.diag
    d = length(dd)
    # size(x, 2) is 1 for vectors, so vectors and matrices share one code path
    m, n = size(A, 1), size(A, 2)
    m′, n′ = size(B, 1), size(B, 2)
    if m != d
        throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
    end
    if (m, n) != (m′, n′)
        throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
    end
    # check for singularity up front so B is not modified when D is singular
    j = findfirst(iszero, dd)
    if !isnothing(j)
        throw(SingularException(j))
    end
    B .= dd .\ A
    return B
end

# Optimizations for \, / between Diagonals
\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B)
/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D)
# Internal 3-arg backend for `Db / Da` between Diagonals: writes the element-wise
# quotient `Db.diag ./ Da.diag` into `Dc.diag` and returns `Dc`.
# `Dc` may alias `Db` (this is how an in-place right division reuses the backend).
# Throws `DimensionMismatch` when the orders of `Db`, `Da` and `Dc` disagree, and
# `SingularException(j)` where `j` is the index of the first zero entry of `Da.diag`.
function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
    # n is the order of the left-hand side, k the order of the divisor; they must be
    # read from different operands, otherwise a length-1 divisor would silently
    # broadcast over a larger left-hand side.
    n, k = length(Db.diag), length(Da.diag)
    n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
    # the destination must match too: broadcasting would otherwise expand a 1×1 result
    l = length(Dc.diag)
    l == n || throw(DimensionMismatch("expect output to be $n by $n, but got $l by $l"))
    # check for singularity before writing so Dc is untouched on failure
    j = findfirst(iszero, Da.diag)
    isnothing(j) || throw(SingularException(j))
    Dc.diag .= Db.diag ./ Da.diag
    return Dc
end
ldiv!(Dc::Diagonal, Da::Diagonal, Db::Diagonal) = Diagonal(ldiv!(Dc.diag, Da, Db.diag))

# Optimizations for [l/r]mul!, [l/r]div!, *, / and \ between Triangular and Diagonal.
# These functions are generally more efficient if we calculate the whole data field.
# The following code implements them in a unified pattern to avoid missing any combination.
# Overwrite the main diagonal of `data` in place and return `data`.
# With only `diag` given, entry (i,i) becomes `f(diag[i])`; when `diag′` is also
# supplied, entry (i,i) becomes `f(diag[i], diag′[i])`.
# Indexing is unchecked (`@inbounds`): callers must pass a `data` with at least
# `length(diag)` rows and columns.
@inline function _setdiag!(data, f, diag, diag′ = nothing)
    if isnothing(diag′)
        @inbounds for k in 1:length(diag)
            data[k,k] = f(diag[k])
        end
    else
        @inbounds for k in 1:length(diag)
            data[k,k] = f(diag[k], diag′[k])
        end
    end
    return data
end
# Generate the Diagonal/Triangular methods for both triangle orientations.
# `Tri` is the non-unit wrapper symbol and `UTri` the matching `Unit...` wrapper symbol.
# Every generated method works on the full `.data` field and re-wraps the result in `Tri`;
# for `UTri` inputs the diagonal of the result is then rewritten by `_setdiag!`, because
# the value stored on the diagonal of a unit-triangular `.data` is not used as-is
# (NOTE(review): relies on Unit* types having an implicit unit diagonal — see the
# LinearAlgebra documentation of UnitUpperTriangular/UnitLowerTriangular).
for Tri in (:UpperTriangular, :LowerTriangular)
UTri = Symbol(:Unit, Tri)
# 2 args
# Diagonal on the right: `f` is what is applied to `D.diag[i]` to obtain the result's
# diagonal for unit-triangular input (`identity` for products, `inv` for divisions).
for (fun, f) in zip((:*, :rmul!, :rdiv!, :/), (:identity, :identity, :inv, :inv))
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
end
# Diagonal on the left: same pattern with the left-handed functions.
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
end
# 3-arg ldiv!
# Writes into `C.data`; for unit-triangular `A` the diagonal is set to `inv.(D.diag)`.
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
# 3-arg mul!: invoke 5-arg mul! rather than lmul!
@eval mul!(C::$Tri, A::Union{$Tri,$UTri}, D::Diagonal) = mul!(C, A, D, true, false)
# 5-arg mul!
# For unit-triangular `A` the diagonal of the result is recomputed as
# `MulAddMul(α, β)` applied to `D.diag[i]` (and to the previous diagonal of `C` when β is nonzero).
@eval @inline mul!(C::$Tri, D::Diagonal, A::$Tri, α::Number, β::Number) = $Tri(mul!(C.data, D, A.data, α, β))
@eval @inline function mul!(C::$Tri, D::Diagonal, A::$UTri, α::Number, β::Number)
# α == 0: no product term is needed, only `C` scaled by β (delegated to `_rmul_or_fill!`)
iszero(α) && return _rmul_or_fill!(C, β)
# the old diagonal of C must be captured before mul! overwrites C.data
diag′ = iszero(β) ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′))
end
@eval @inline mul!(C::$Tri, A::$Tri, D::Diagonal, α::Number, β::Number) = $Tri(mul!(C.data, A.data, D, α, β))
@eval @inline function mul!(C::$Tri, A::$UTri, D::Diagonal, α::Number, β::Number)
# α == 0: no product term is needed, only `C` scaled by β (delegated to `_rmul_or_fill!`)
iszero(α) && return _rmul_or_fill!(C, β)
# the old diagonal of C must be captured before mul! overwrites C.data
diag′ = iszero(β) ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′))
end
end

(/)(A::Union{StridedMatrix, AbstractTriangular}, D::Diagonal) =
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
Expand All @@ -435,7 +455,7 @@ end
kron(A::Diagonal{<:Number}, B::Diagonal{<:Number}) = Diagonal(kron(A.diag, B.diag))

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
Base.require_one_based_indexing(B)
require_one_based_indexing(B)
(mA, nA) = size(A)
(mB, nB) = size(B)
(mC, nC) = size(C)
Expand Down Expand Up @@ -516,30 +536,6 @@ for f in (:exp, :cis, :log, :sqrt,
@eval $f(D::Diagonal) = Diagonal($f.(D.diag))
end

(\)(D::Diagonal, A::AbstractMatrix) =
ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A))

(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)

function ldiv!(D::Diagonal, B::AbstractVecOrMat)
m, n = size(B, 1), size(B, 2)
if m != length(D.diag)
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows"))
end
(m == 0 || n == 0) && return B
for j = 1:n
for i = 1:m
di = D.diag[i]
if di == 0
throw(SingularException(i))
end
B[i,j] = di \ B[i,j]
end
end
return B
end

function inv(D::Diagonal{T}) where T
Di = similar(D.diag, typeof(inv(zero(T))))
for i = 1:length(D.diag)
Expand Down Expand Up @@ -596,11 +592,17 @@ function svd(D::Diagonal{T}) where T<:Number
return SVD(Up, S[piv], copy(Vp'))
end

# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec
*(x::AdjointAbsVec, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::TransposeAbsVec, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
/(u::AdjointAbsVec, D::Diagonal) = adjoint(adjoint(D) \ u.parent)
/(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) \ u.parent)
# disambiguation methods: Call unoptimized version for user defined AbstractTriangular.
*(A::AbstractTriangular, D::Diagonal) = Base.@invoke *(A::AbstractMatrix, D::Diagonal)
*(D::Diagonal, A::AbstractTriangular) = Base.@invoke *(D::Diagonal, A::AbstractMatrix)

dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y)

dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag)
Expand All @@ -619,7 +621,6 @@ function _mapreduce_prod(f, x, D::Diagonal, y)
end
end


function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
info = 0
for (i, di) in enumerate(A.diag)
Expand Down
32 changes: 32 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,4 +902,36 @@ end
@test oneunit(D3) isa typeof(D3)
end

# Checks that every Diagonal/Triangular operation returns the non-unit wrapper type `Tri`
# and agrees with the same operation performed on the dense `Matrix` equivalent.
# Runs once per orientation: (UpperTriangular, UnitUpperTriangular) and
# (LowerTriangular, UnitLowerTriangular).
@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
A = randn(4, 4)
TriA = Tri(A)
UTriA = UTri(A)
D = Diagonal(1.0:4.0)
DM = Matrix(D)
# NOTE(review): DMF is not referenced anywhere else in this testset
DMF = factorize(DM)
# preallocated destinations for the in-place (3- and 5-arg) methods
outTri = similar(TriA)
out = similar(A)
# 2 args
# `copy` protects TriA/UTriA from the mutating variants (rmul!, rdiv!, lmul!, ldiv!)
for fun in (*, rmul!, rdiv!, /)
@test fun(copy(TriA), D)::Tri == fun(Matrix(TriA), D)
@test fun(copy(UTriA), D)::Tri == fun(Matrix(UTriA), D)
end
for fun in (*, lmul!, ldiv!, \)
@test fun(D, copy(TriA))::Tri == fun(D, Matrix(TriA))
@test fun(D, copy(UTriA))::Tri == fun(D, Matrix(UTriA))
end
# 3 args
# chained comparison: the call must return `outTri` itself, be of type `Tri`,
# and equal the dense result
@test outTri === ldiv!(outTri, D, TriA)::Tri == ldiv!(out, D, Matrix(TriA))
@test outTri === ldiv!(outTri, D, UTriA)::Tri == ldiv!(out, D, Matrix(UTriA))
@test outTri === mul!(outTri, D, TriA)::Tri == mul!(out, D, Matrix(TriA))
@test outTri === mul!(outTri, D, UTriA)::Tri == mul!(out, D, Matrix(UTriA))
@test outTri === mul!(outTri, TriA, D)::Tri == mul!(out, Matrix(TriA), D)
@test outTri === mul!(outTri, UTriA, D)::Tri == mul!(out, Matrix(UTriA), D)
# 5 args
# α = 2, β = 1: the previous contents of outTri / out take part in the result
@test outTri === mul!(outTri, D, TriA, 2, 1)::Tri == mul!(out, D, Matrix(TriA), 2, 1)
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)
end

end # module TestDiagonal
14 changes: 6 additions & 8 deletions stdlib/LinearAlgebra/test/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ let n = 10
@testset "Multiplication/division" begin
for x = (5, 5I, Diagonal(d), Bidiagonal(d,dl,:U),
UpperTriangular(A), UnitUpperTriangular(A))
@test H*x == Array(H)*x broken = eltype(H) <: Furlong && x isa Bidiagonal
@test x*H == x*Array(H) broken = eltype(H) <: Furlong && x isa Bidiagonal
@test H/x == Array(H)/x broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, Diagonal, UpperTriangular}
@test x\H == x\Array(H) broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, Diagonal, UpperTriangular}
@test H*x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
@test x*H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
@test H/x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, Diagonal}
@test x\H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, Diagonal}
@test (H*x)::UpperHessenberg == Array(H)*x broken = eltype(H) <: Furlong && x isa Bidiagonal
@test (x*H)::UpperHessenberg == x*Array(H) broken = eltype(H) <: Furlong && x isa Bidiagonal
@test H/x == Array(H)/x broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, UpperTriangular}
@test x\H == x\Array(H) broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, UpperTriangular}
@test H/x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
@test x\H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal
end
x = Bidiagonal(d, dl, :L)
@test H*x == Array(H)*x
Expand Down
3 changes: 2 additions & 1 deletion stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,8 @@ end
@test 2 .* ((1:5) .+ A) == 2:2:10
@test 2 .* (A .+ (1:5)) == 2:2:10

@test Diagonal(spzeros(5)) \ view(rand(10), 1:5) == [Inf,Inf,Inf,Inf,Inf]
# lu(zeros(5,5)) throws SingularException, see #42343
@test_throws SingularException Diagonal(spzeros(5)) \ view(rand(10), 1:5)
end

@testset "Issue #27836" begin
Expand Down

0 comments on commit 66d05d5

Please sign in to comment.