Skip to content

Commit

Permalink
Make *Triangular handle units (#43972)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed May 11, 2023
1 parent 528949f commit b21f100
Show file tree
Hide file tree
Showing 12 changed files with 638 additions and 585 deletions.
102 changes: 40 additions & 62 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
Expand Down Expand Up @@ -747,39 +748,27 @@ ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMa
\(xA::AdjOrTrans{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(xA) \ B

### Triangular specializations
function \(B::Bidiagonal, U::UpperTriangular)
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function \(B::Bidiagonal, U::UnitUpperTriangular)
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function \(B::Bidiagonal, L::LowerTriangular)
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
for tri in (:UpperTriangular, :UnitUpperTriangular)
@eval function \(B::Bidiagonal, U::$tri)
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
@eval function \(U::$tri, B::Bidiagonal)
A = ldiv!(_initarray(\, eltype(U), eltype(B), U), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
end
function \(B::Bidiagonal, L::UnitLowerTriangular)
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
for tri in (:LowerTriangular, :UnitLowerTriangular)
@eval function \(B::Bidiagonal, L::$tri)
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
@eval function \(L::$tri, B::Bidiagonal)
A = ldiv!(_initarray(\, eltype(L), eltype(B), L), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
end

function \(U::UpperTriangular, B::Bidiagonal)
A = ldiv!(U, copy_similar(B, _init_eltype(\, eltype(U), eltype(B))))
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function \(U::UnitUpperTriangular, B::Bidiagonal)
A = ldiv!(U, copy_similar(B, _init_eltype(\, eltype(U), eltype(B))))
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function \(L::LowerTriangular, B::Bidiagonal)
A = ldiv!(L, copy_similar(B, _init_eltype(\, eltype(L), eltype(B))))
return B.uplo == 'L' ? LowerTriangular(A) : A
end
function \(L::UnitLowerTriangular, B::Bidiagonal)
A = ldiv!(L, copy_similar(B, _init_eltype(\, eltype(L), eltype(B))))
return B.uplo == 'L' ? LowerTriangular(A) : A
end
### Diagonal specialization
function \(B::Bidiagonal, D::Diagonal)
A = ldiv!(_initarray(\, eltype(B), eltype(D), D), B, D)
Expand Down Expand Up @@ -835,38 +824,27 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal})
/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)

### Triangular specializations
function /(U::UpperTriangular, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function /(U::UnitUpperTriangular, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function /(L::LowerTriangular, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
function /(L::UnitLowerTriangular, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
function /(B::Bidiagonal, U::UpperTriangular)
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(U))), U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function /(B::Bidiagonal, U::UnitUpperTriangular)
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(U))), U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
function /(B::Bidiagonal, L::LowerTriangular)
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(L))), L)
return B.uplo == 'L' ? LowerTriangular(A) : A
for tri in (:UpperTriangular, :UnitUpperTriangular)
@eval function /(U::$tri, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
@eval function /(B::Bidiagonal, U::$tri)
A = _rdiv!(_initarray(/, eltype(B), eltype(U), U), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
end
function /(B::Bidiagonal, L::UnitLowerTriangular)
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(L))), L)
return B.uplo == 'L' ? LowerTriangular(A) : A
for tri in (:LowerTriangular, :UnitLowerTriangular)
@eval function /(L::$tri, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
@eval function /(B::Bidiagonal, L::$tri)
A = _rdiv!(_initarray(/, eltype(B), eltype(L), L), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
end

### Diagonal specialization
function /(D::Diagonal, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(D), eltype(B), D), D, B)
Expand All @@ -886,8 +864,8 @@ end
factorize(A::Bidiagonal) = A
function inv(B::Bidiagonal{T}) where T
n = size(B, 1)
dest = zeros(typeof(oneunit(T)\one(T)), (n, n))
ldiv!(dest, B, Diagonal{typeof(one(T)\one(T))}(I, n))
dest = zeros(typeof(inv(oneunit(T))), (n, n))
ldiv!(dest, B, Diagonal{typeof(one(T)/one(T))}(I, n))
return B.uplo == 'U' ? UpperTriangular(dest) : LowerTriangular(dest)
end

Expand Down
12 changes: 5 additions & 7 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,14 +907,12 @@ sqrt(A::TransposeAbsMat) = transpose(sqrt(parent(A)))

function inv(A::StridedMatrix{T}) where T
checksquare(A)
S = typeof((oneunit(T)*zero(T) + oneunit(T)*zero(T))/oneunit(T))
AA = convert(AbstractArray{S}, A)
if istriu(AA)
Ai = triu!(parent(inv(UpperTriangular(AA))))
elseif istril(AA)
Ai = tril!(parent(inv(LowerTriangular(AA))))
if istriu(A)
Ai = triu!(parent(inv(UpperTriangular(A))))
elseif istril(A)
Ai = tril!(parent(inv(LowerTriangular(A))))
else
Ai = inv!(lu(AA))
Ai = inv!(lu(A))
Ai = convert(typeof(parent(Ai)), Ai)
end
return Ai
Expand Down
25 changes: 18 additions & 7 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,23 @@ function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0})
return out
end

function _mul!(out, A, B, _add)
function _mul_diag!(out, A, B, _add)
_muldiag_size_check(out, A, B)
__muldiag!(out, A, B, _add)
return out
end

_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) =
_mul_diag!(out, D, V, _add)
_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) =
_mul_diag!(out, D, B, _add)
_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) =
_mul_diag!(out, A, D, _add)
_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) =
_mul_diag!(C, Da, Db, _add)
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
_mul_diag!(C, Da, Db, _add)

function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
_muldiag_size_check(Da, A)
_muldiag_size_check(A, Db)
Expand All @@ -395,6 +406,7 @@ end

/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D))), A, D)
/(A::HermOrSym, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D)), size(A)), A, D)

rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D)
# avoid copy when possible via internal 3-arg backend
function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
Expand Down Expand Up @@ -557,22 +569,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
# 3-arg ldiv!
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
# 3-arg mul!: invoke 5-arg mul! rather than lmul!
@eval mul!(C::$Tri, A::Union{$Tri,$UTri}, D::Diagonal) = mul!(C, A, D, true, false)
# 3-arg mul! is disambiguated in special.jl
# 5-arg mul!
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add)
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = iszero(β) ? nothing : diag(C)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add)
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = iszero(β) ? nothing : diag(C)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
Expand Down
24 changes: 6 additions & 18 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,41 +129,29 @@ for T = (:Number, :UniformScaling, :Diagonal)
end

function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
T = typeof(oneunit(eltype(H))*oneunit(eltype(U)))
HH = copy_similar(H, T)
rmul!(HH, U)
HH = _mulmattri!(_initarray(*, eltype(H), eltype(U), H), H, U)
UpperHessenberg(HH)
end
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
T = typeof(oneunit(eltype(H))*oneunit(eltype(U)))
HH = copy_similar(H, T)
lmul!(U, HH)
HH = _multrimat!(_initarray(*, eltype(U), eltype(H), H), U, H)
UpperHessenberg(HH)
end

function /(H::UpperHessenberg, U::UpperTriangular)
T = typeof(oneunit(eltype(H))/oneunit(eltype(U)))
HH = copy_similar(H, T)
rdiv!(HH, U)
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
UpperHessenberg(HH)
end
function /(H::UpperHessenberg, U::UnitUpperTriangular)
T = typeof(oneunit(eltype(H))/oneunit(eltype(U)))
HH = copy_similar(H, T)
rdiv!(HH, U)
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
UpperHessenberg(HH)
end

function \(U::UpperTriangular, H::UpperHessenberg)
T = typeof(oneunit(eltype(U))\oneunit(eltype(H)))
HH = copy_similar(H, T)
ldiv!(U, HH)
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
UpperHessenberg(HH)
end
function \(U::UnitUpperTriangular, H::UpperHessenberg)
T = typeof(oneunit(eltype(U))\oneunit(eltype(H)))
HH = copy_similar(H, T)
ldiv!(U, HH)
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
UpperHessenberg(HH)
end

Expand Down
5 changes: 2 additions & 3 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,14 @@ julia> C
730.0 740.0
```
"""
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha::Number, beta::Number) =
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
generic_matmatmul!(
C,
adj_or_trans_char(A),
adj_or_trans_char(B),
_parent(A),
_parent(B),
MulAddMul(alpha, beta)
MulAddMul(α, β)
)

"""
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ for op in (:+, :-)
end
end

# disambiguation between triangular and banded matrices, banded ones "dominate"
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix) = _mul!(C, A, B, MulAddMul())
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular) = _mul!(C, A, B, MulAddMul())
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
_mul!(C, A, B, MulAddMul(alpha, beta))
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
_mul!(C, A, B, MulAddMul(alpha, beta))

function *(H::UpperHessenberg, B::Bidiagonal)
T = promote_op(matprod, eltype(H), eltype(B))
A = mul!(similar(H, T, size(H)), H, B)
Expand Down
Loading

0 comments on commit b21f100

Please sign in to comment.