Skip to content

Commit

Permalink
Add lmul! and rmul! for Bidiagonal (#51777)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Nov 19, 2023
1 parent 19ca07d commit ec3911c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
44 changes: 41 additions & 3 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,17 @@ end

const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or BiDiTriSym
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const TriSym = Union{Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())

function check_A_mul_B!_sizes(C, A, B)
mA, nA = size(A)
mB, nB = size(B)
Expand Down Expand Up @@ -460,7 +464,11 @@ function _diag(A::Bidiagonal, k)
end
end

function _mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, _add::MulAddMul = MulAddMul())
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul = MulAddMul()) =
_bibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
_bibimul!(C, A, B, _add)
function _bibimul!(C, A, B, _add)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
Expand Down Expand Up @@ -583,7 +591,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
C
end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand Down Expand Up @@ -618,7 +626,37 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMu
C
end

function _mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
if size(A, 1) <= 3 || size(B, 2) <= 1
return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
end
m, n = size(A)
@inbounds if B.uplo == 'U'
for i in 1:m
for j in n:-1:2
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
end
_modify!(_add, A[i,1] * B.dv[1], C, (i, 1))
end
else # uplo == 'L'
for i in 1:m
for j in 1:n-1
_modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j))
end
_modify!(_add, A[i,n] * B.dv[n], C, (i, n))
end
end
C
end

_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
_dibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul = MulAddMul()) =
_dibimul!(C, A, B, _add)
function _dibimul!(C, A, B, _add)
require_one_based_indexing(C)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
Expand Down
12 changes: 9 additions & 3 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,14 +439,18 @@ Random.seed!(1)
for op in (+, -, *)
@test Array(op(T, T2)) op(Tfull, Tfull2)
end
A = kron(T.dv, T.dv')
@test T * A lmul!(T, copy(A))
@test A * T rmul!(copy(A), T)
end
# test pass-through of mul! for SymTridiagonal*Bidiagonal
TriSym = SymTridiagonal(T.dv, T.ev)
@test Array(TriSym*T) Array(TriSym)*Array(T)
# test pass-through of mul! for AbstractTriangular*Bidiagonal
Tri = UpperTriangular(diagm(1 => T.ev))
Dia = Diagonal(T.dv)
@test Array(Tri*T) Array(Tri)*Array(T)
@test Array(Tri*T) Array(Tri)*Array(T) rmul!(copy(Tri), T)
@test Array(T*Tri) Array(T)*Array(Tri) lmul!(T, copy(Tri))
# test mul! itself for these types
for AA in (Tri, Dia)
for f in (identity, transpose, adjoint)
Expand All @@ -459,8 +463,10 @@ Random.seed!(1)
for f in (identity, transpose, adjoint)
C = relty == Int ? rand(float(elty), n, n) : rand(elty, n, n)
B = rand(elty, n, n)
D = copy(C) + 2.0 * Array(T*f(B))
mul!(C, T, f(B), 2.0, 1.0) D
D = C + 2.0 * Array(T*f(B))
@test mul!(C, T, f(B), 2.0, 1.0) D
@test lmul!(T, copy(f(B))) T * f(B)
@test rmul!(copy(f(B)), T) f(B) * T
end

# Issue #31870
Expand Down

0 comments on commit ec3911c

Please sign in to comment.