diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index f8cc3ceadcfad..78c79b6fcefac 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -426,6 +426,7 @@ end const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or BiDiTriSym const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal} +const TriSym = Union{Tridiagonal,SymTridiagonal} const BiTri = Union{Bidiagonal,Tridiagonal} @inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) @inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) @@ -433,6 +434,9 @@ const BiTri = Union{Bidiagonal,Tridiagonal} @inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) @inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul()) +rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul()) + function check_A_mul_B!_sizes(C, A, B) mA, nA = size(A) mB, nB = size(B) @@ -460,7 +464,11 @@ function _diag(A::Bidiagonal, k) end end -function _mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, _add::MulAddMul = MulAddMul()) +_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul = MulAddMul()) = + _bibimul!(C, A, B, _add) +_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul = MulAddMul()) = + _bibimul!(C, A, B, _add) +function _bibimul!(C, A, B, _add) check_A_mul_B!_sizes(C, A, B) n = size(A,1) n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) @@ -583,7 +591,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA C end -function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMul = MulAddMul()) +function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A) check_A_mul_B!_sizes(C, A, B) iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) @@ -618,7 +626,37 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMu C end -function _mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, _add::MulAddMul = MulAddMul()) +function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul = MulAddMul()) + require_one_based_indexing(C, A) + check_A_mul_B!_sizes(C, A, B) + iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) + if size(A, 1) <= 3 || size(B, 2) <= 1 + return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + end + m, n = size(A) + @inbounds if B.uplo == 'U' + for i in 1:m + for j in n:-1:2 + _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) + end + _modify!(_add, A[i,1] * B.dv[1], C, (i, 1)) + end + else # uplo == 'L' + for i in 1:m + for j in 1:n-1 + _modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j)) + end + _modify!(_add, A[i,n] * B.dv[n], C, (i, n)) + end + end + C +end + +_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul = MulAddMul()) = + _dibimul!(C, A, B, _add) +_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul = MulAddMul()) = + _dibimul!(C, A, B, _add) +function _dibimul!(C, A, B, _add) require_one_based_indexing(C) check_A_mul_B!_sizes(C, A, B) n = size(A,1) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index a3e5a2f437e93..2fce781e30ab1 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -439,6 +439,9 @@ Random.seed!(1) for op in (+, -, *) @test Array(op(T, T2)) ≈ op(Tfull, Tfull2) end + A = kron(T.dv, T.dv') + @test T * A ≈ lmul!(T, copy(A)) + @test A * T ≈ rmul!(copy(A), T) end # test pass-through of mul! for SymTridiagonal*Bidiagonal TriSym = SymTridiagonal(T.dv, T.ev) @@ -446,7 +449,8 @@ Random.seed!(1) # test pass-through of mul! for AbstractTriangular*Bidiagonal Tri = UpperTriangular(diagm(1 => T.ev)) Dia = Diagonal(T.dv) - @test Array(Tri*T) ≈ Array(Tri)*Array(T) + @test Array(Tri*T) ≈ Array(Tri)*Array(T) ≈ rmul!(copy(Tri), T) + @test Array(T*Tri) ≈ Array(T)*Array(Tri) ≈ lmul!(T, copy(Tri)) # test mul! itself for these types for AA in (Tri, Dia) for f in (identity, transpose, adjoint) @@ -459,8 +463,10 @@ Random.seed!(1) for f in (identity, transpose, adjoint) C = relty == Int ? rand(float(elty), n, n) : rand(elty, n, n) B = rand(elty, n, n) - D = copy(C) + 2.0 * Array(T*f(B)) - mul!(C, T, f(B), 2.0, 1.0) ≈ D + D = C + 2.0 * Array(T*f(B)) + @test mul!(C, T, f(B), 2.0, 1.0) ≈ D + @test lmul!(T, copy(f(B))) ≈ T * f(B) + @test rmul!(copy(f(B)), T) ≈ f(B) * T end # Issue #31870