Skip to content

Commit

Permalink
improve copyto!(::AbstractMatrix,::SpecialMatrix)`
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Mar 15, 2022
1 parent 8702ee7 commit 8ec2be0
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 41 deletions.
39 changes: 22 additions & 17 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,29 +165,34 @@ function Base.replace_in_print_matrix(A::Bidiagonal,i::Integer,j::Integer,s::Abs
end

#Converting from Bidiagonal to dense Matrix
function Matrix{T}(A::Bidiagonal) where T
n = size(A, 1)
B = zeros(T, n, n)
if n == 0
return B
end
for i = 1:n - 1
B[i,i] = A.dv[i]
if A.uplo == 'U'
B[i, i + 1] = A.ev[i]
else
B[i + 1, i] = A.ev[i]
end
end
B[n,n] = A.dv[n]
return B
end
Matrix{T}(A::Bidiagonal) where {T} = copyto!(Matrix{T}(undef, size(A)), A)
Matrix(A::Bidiagonal{T}) where {T} = Matrix{T}(A)
Array(A::Bidiagonal) = Matrix(A)
promote_rule(::Type{Matrix{T}}, ::Type{<:Bidiagonal{S}}) where {T,S} =
@isdefined(T) && @isdefined(S) ? Matrix{promote_type(T,S)} : Matrix
promote_rule(::Type{Matrix}, ::Type{<:Bidiagonal}) = Matrix

function copyto!(A::AbstractMatrix{T}, B::Bidiagonal) where {T}
require_one_based_indexing(A)
n = size(B, 1)
n == 0 && return A
if size(A) == (n, n)
fill!(A, zero(T))
@inbounds for i in 1:n - 1
A[i,i] = B.dv[i]
if B.uplo == 'U'
A[i, i + 1] = B.ev[i]
else
A[i + 1, i] = B.ev[i]
end
end
A[n,n] = B.dv[n]
return A
else
return @invoke copyto!(A::AbstractMatrix, B::AbstractMatrix)
end
end

#Converting from Bidiagonal to Tridiagonal
function Tridiagonal{T}(A::Bidiagonal) where T
dv = convert(AbstractVector{T}, A.dv)
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ similar(D::Diagonal, ::Type{T}) where {T} = Diagonal(similar(D.diag, T))
similar(::Diagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...)

copyto!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1)
function copyto!(A::AbstractMatrix{T}, D::Diagonal) where {T}
require_one_based_indexing(A)
n = length(D.diag)
n == 0 && return A
if size(A) == (n, n)
fill!(A, zero(T))
@inbounds for i in 1:n
A[i, i] = D.diag[i]
end
return A
else
return @invoke copyto!(A::AbstractMatrix, D::AbstractMatrix)
end
end

size(D::Diagonal) = (n = length(D.diag); (n,n))

Expand Down
8 changes: 4 additions & 4 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,9 @@ lmul!(Q::Adjoint{<:Any,<:QRPackedQ}, B::AbstractTriangular) = lmul!(Q, full!(B))
function _qlmul(Q::AbstractQ, B)
TQB = promote_type(eltype(Q), eltype(B))
if size(Q.factors, 1) == size(B, 1)
Bnew = Matrix{TQB}(B)
Bnew = copy_similar(B, TQB)
elseif size(Q.factors, 2) == size(B, 1)
Bnew = [Matrix{TQB}(B); zeros(TQB, size(Q.factors, 1) - size(B,1), size(B, 2))]
Bnew = [copy_similar(B, TQB); zeros(TQB, size(Q.factors, 1) - size(B,1), size(B, 2))]
else
throw(DimensionMismatch("first dimension of matrix must have size either $(size(Q.factors, 1)) or $(size(Q.factors, 2))"))
end
Expand Down Expand Up @@ -331,9 +331,9 @@ function _qrmul(A, adjQ::Adjoint{<:Any,<:AbstractQ})
Q = adjQ.parent
TAQ = promote_type(eltype(A), eltype(Q))
if size(A,2) == size(Q.factors, 1)
Anew = Matrix{TAQ}(A)
Anew = copy_similar(A, TAQ)
elseif size(A,2) == size(Q.factors,2)
Anew = [Matrix{TAQ}(A) zeros(TAQ, size(A, 1), size(Q.factors, 1) - size(Q.factors, 2))]
Anew = [copy_similar(A, TAQ) zeros(TAQ, size(A, 1), size(Q.factors, 1) - size(Q.factors, 2))]
else
throw(DimensionMismatch("matrix A has dimensions $(size(A)) but matrix B has dimensions $(size(Q))"))
end
Expand Down
50 changes: 32 additions & 18 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,23 @@ SymTridiagonal(S::SymTridiagonal) = S
AbstractMatrix{T}(S::SymTridiagonal) where {T} =
SymTridiagonal(convert(AbstractVector{T}, S.dv)::AbstractVector{T},
convert(AbstractVector{T}, S.ev)::AbstractVector{T})
function Matrix{T}(M::SymTridiagonal) where T
Matrix{T}(M::SymTridiagonal) where {T} = copyto!(Matrix{T}(undef, size(M)), M)
function copyto!(A::AbstractMatrix{T}, M::SymTridiagonal) where {T}
require_one_based_indexing(A)
n = size(M, 1)
Mf = zeros(T, n, n)
n == 0 && return Mf
@inbounds for i = 1:n-1
Mf[i,i] = symmetric(M.dv[i], :U)
Mf[i+1,i] = transpose(M.ev[i])
Mf[i,i+1] = M.ev[i]
n == 0 && return A
if size(A) == (n, n)
fill!(A, zero(T))
@inbounds for i in 1:n-1
A[i,i] = symmetric(M.dv[i], :U)
A[i+1,i] = transpose(M.ev[i])
A[i,i+1] = M.ev[i]
end
A[n,n] = symmetric(M.dv[n], :U)
return A
else
return @invoke copyto!(A::AbstractMatrix, M::AbstractMatrix)
end
Mf[n,n] = symmetric(M.dv[n], :U)
return Mf
end
Matrix(M::SymTridiagonal{T}) where {T} = Matrix{T}(M)
Array(M::SymTridiagonal) = Matrix(M)
Expand Down Expand Up @@ -571,16 +577,24 @@ function size(M::Tridiagonal, d::Integer)
end
end

function Matrix{T}(M::Tridiagonal{T}) where T
A = zeros(T, size(M))
for i = 1:length(M.d)
A[i,i] = M.d[i]
end
for i = 1:length(M.d)-1
A[i+1,i] = M.dl[i]
A[i,i+1] = M.du[i]
Matrix{T}(M::Tridiagonal) where {T} = copyto!(Matrix{T}(undef, size(M)), M)

function copyto!(A::AbstractMatrix{T}, M::Tridiagonal) where {T}
require_one_based_indexing(A)
n = size(M, 1)
n == 0 && return A
if size(A) == (n, n)
fill!(A, zero(T))
@inbounds for i = 1:n-1
A[i,i] = M.d[i]
A[i+1,i] = M.dl[i]
A[i,i+1] = M.du[i]
end
A[n,n] = M.d[n]
return A
else
@invoke copyto!(A::AbstractMatrix, M::AbstractMatrix)
end
A
end
Matrix(M::Tridiagonal{T}) where {T} = Matrix{T}(M)
Array(M::Tridiagonal) = Matrix(M)
Expand Down
5 changes: 3 additions & 2 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,13 @@ end
dl = [1, 1, 1]
d = [1, 1, 1, 1]
D = Diagonal(d)
Bi = Bidiagonal(d, dl, :L)
Bl = Bidiagonal(d, dl, :L)
Bu = Bidiagonal(d, dl, :U)
Tri = Tridiagonal(dl, d, dl)
Sym = SymTridiagonal(d, dl)
F = qr(ones(4, 1))
A = F.Q'
for A in (F.Q, F.Q'), B in (D, Bi, Tri, Sym)
for A in (F.Q, F.Q'), B in (D, Bl, Bu, Tri, Sym)
@test B*A Matrix(B)*A
@test A*B A*Matrix(B)
end
Expand Down

0 comments on commit 8ec2be0

Please sign in to comment.