Skip to content

Commit

Permalink
replace fill!(..., 0) with fill!(..., zero(T)) for structured matrices (
Browse files Browse the repository at this point in the history
  • Loading branch information
mcognetta committed Oct 18, 2021
1 parent b45b9bb commit 806b128
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 50 deletions.
32 changes: 16 additions & 16 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,53 +298,53 @@ function istril(M::Bidiagonal, k::Integer=0)
end
isdiag(M::Bidiagonal) = iszero(M.ev)

function tril!(M::Bidiagonal, k::Integer=0)
function tril!(M::Bidiagonal{T}, k::Integer=0) where T
n = length(M.dv)
if !(-n - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n - 1) and at most $(n - 1) in an $n-by-$n matrix")))
elseif M.uplo == 'U' && k < 0
fill!(M.dv,0)
fill!(M.ev,0)
fill!(M.dv, zero(T))
fill!(M.ev, zero(T))
elseif k < -1
fill!(M.dv,0)
fill!(M.ev,0)
fill!(M.dv, zero(T))
fill!(M.ev, zero(T))
elseif M.uplo == 'U' && k == 0
fill!(M.ev,0)
fill!(M.ev, zero(T))
elseif M.uplo == 'L' && k == -1
fill!(M.dv,0)
fill!(M.dv, zero(T))
end
return M
end

function triu!(M::Bidiagonal, k::Integer=0)
function triu!(M::Bidiagonal{T}, k::Integer=0) where T
n = length(M.dv)
if !(-n + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least",
"$(-n + 1) and at most $(n + 1) in an $n-by-$n matrix")))
elseif M.uplo == 'L' && k > 0
fill!(M.dv,0)
fill!(M.ev,0)
fill!(M.dv, zero(T))
fill!(M.ev, zero(T))
elseif k > 1
fill!(M.dv,0)
fill!(M.ev,0)
fill!(M.dv, zero(T))
fill!(M.ev, zero(T))
elseif M.uplo == 'L' && k == 0
fill!(M.ev,0)
fill!(M.ev, zero(T))
elseif M.uplo == 'U' && k == 1
fill!(M.dv,0)
fill!(M.dv, zero(T))
end
return M
end

function diag(M::Bidiagonal, n::Integer=0)
function diag(M::Bidiagonal{T}, n::Integer=0) where T
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
if n == 0
return copyto!(similar(M.dv, length(M.dv)), M.dv)
elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L')
return copyto!(similar(M.ev, length(M.ev)), M.ev)
elseif -size(M,1) <= n <= size(M,1)
return fill!(similar(M.dv, size(M,1)-abs(n)), 0)
return fill!(similar(M.dv, size(M,1)-abs(n)), zero(T))
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down
16 changes: 8 additions & 8 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,24 +145,24 @@ isdiag(D::Diagonal) = all(isdiag, D.diag)
isdiag(D::Diagonal{<:Number}) = true
istriu(D::Diagonal, k::Integer=0) = k <= 0 || iszero(D.diag) ? true : false
istril(D::Diagonal, k::Integer=0) = k >= 0 || iszero(D.diag) ? true : false
function triu!(D::Diagonal,k::Integer=0)
function triu!(D::Diagonal{T}, k::Integer=0) where T
n = size(D,1)
if !(-n + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n + 1) and at most $(n + 1) in an $n-by-$n matrix")))
elseif k > 0
fill!(D.diag,0)
fill!(D.diag, zero(T))
end
return D
end

function tril!(D::Diagonal,k::Integer=0)
function tril!(D::Diagonal{T}, k::Integer=0) where T
n = size(D,1)
if !(-n - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n - 1) and at most $(n - 1) in an $n-by-$n matrix")))
elseif k < 0
fill!(D.diag,0)
fill!(D.diag, zero(T))
end
return D
end
Expand Down Expand Up @@ -488,13 +488,13 @@ adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))
Base.permutedims(D::Diagonal) = D
Base.permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D)

function diag(D::Diagonal, k::Integer=0)
function diag(D::Diagonal{T}, k::Integer=0) where T
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of k
if k == 0
return copyto!(similar(D.diag, length(D.diag)), D.diag)
elseif -size(D,1) <= k <= size(D,1)
return fill!(similar(D.diag, size(D,1)-abs(k)), 0)
return fill!(similar(D.diag, size(D,1)-abs(k)), zero(T))
else
throw(ArgumentError(string("requested diagonal, $k, must be at least $(-size(D, 1)) ",
"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix")))
Expand Down Expand Up @@ -586,12 +586,12 @@ end
#Singular system
svdvals(D::Diagonal{<:Number}) = sort!(abs.(D.diag), rev = true)
svdvals(D::Diagonal) = [svdvals(v) for v in D.diag]
function svd(D::Diagonal{<:Number})
function svd(D::Diagonal{T}) where T<:Number
S = abs.(D.diag)
piv = sortperm(S, rev = true)
U = Diagonal(D.diag ./ S)
Up = hcat([U[:,i] for i = 1:length(D.diag)][piv]...)
V = Diagonal(fill!(similar(D.diag), one(eltype(D.diag))))
V = Diagonal(fill!(similar(D.diag), one(T)))
Vp = hcat([V[:,i] for i = 1:length(D.diag)][piv]...)
return SVD(Up, S[piv], copy(Vp'))
end
Expand Down
52 changes: 26 additions & 26 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(ad
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
issymmetric(S::SymTridiagonal) = true

function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
function diag(M::SymTridiagonal{T}, n::Integer=0) where T<:Number
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
absn = abs(n)
Expand All @@ -184,7 +184,7 @@ function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
elseif absn == 1
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
elseif absn <= size(M,1)
return fill!(similar(M.dv, size(M,1)-absn), 0)
return fill!(similar(M.dv, size(M,1)-absn), zero(T))
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down Expand Up @@ -376,17 +376,17 @@ isone(M::SymTridiagonal) = iszero(_evview(M)) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = iszero(_evview(M))


function tril!(M::SymTridiagonal, k::Integer=0)
function tril!(M::SymTridiagonal{T}, k::Integer=0) where T
n = length(M.dv)
if !(-n - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n - 1) and at most $(n - 1) in an $n-by-$n matrix")))
elseif k < -1
fill!(M.ev,0)
fill!(M.dv,0)
fill!(M.ev, zero(T))
fill!(M.dv, zero(T))
return Tridiagonal(M.ev,M.dv,copy(M.ev))
elseif k == -1
fill!(M.dv,0)
fill!(M.dv, zero(T))
return Tridiagonal(M.ev,M.dv,zero(M.ev))
elseif k == 0
return Tridiagonal(M.ev,M.dv,zero(M.ev))
Expand All @@ -395,17 +395,17 @@ function tril!(M::SymTridiagonal, k::Integer=0)
end
end

function triu!(M::SymTridiagonal, k::Integer=0)
function triu!(M::SymTridiagonal{T}, k::Integer=0) where T
n = length(M.dv)
if !(-n + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n + 1) and at most $(n + 1) in an $n-by-$n matrix")))
elseif k > 1
fill!(M.ev,0)
fill!(M.dv,0)
fill!(M.ev, zero(T))
fill!(M.dv, zero(T))
return Tridiagonal(M.ev,M.dv,copy(M.ev))
elseif k == 1
fill!(M.dv,0)
fill!(M.dv, zero(T))
return Tridiagonal(zero(M.ev),M.dv,M.ev)
elseif k == 0
return Tridiagonal(zero(M.ev),M.dv,M.ev)
Expand Down Expand Up @@ -617,7 +617,7 @@ issymmetric(S::Tridiagonal) = S.du == S.dl

\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:StridedVecOrMat}) = copy(A) \ B

function diag(M::Tridiagonal, n::Integer=0)
function diag(M::Tridiagonal{T}, n::Integer=0) where T
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
if n == 0
Expand All @@ -627,7 +627,7 @@ function diag(M::Tridiagonal, n::Integer=0)
elseif n == 1
return copyto!(similar(M.du, length(M.du)), M.du)
elseif abs(n) <= size(M,1)
return fill!(similar(M.d, size(M,1)-abs(n)), 0)
return fill!(similar(M.d, size(M,1)-abs(n)), zero(T))
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down Expand Up @@ -696,38 +696,38 @@ function istril(M::Tridiagonal, k::Integer=0)
end
isdiag(M::Tridiagonal) = iszero(M.dl) && iszero(M.du)

function tril!(M::Tridiagonal, k::Integer=0)
function tril!(M::Tridiagonal{T}, k::Integer=0) where T
n = length(M.d)
if !(-n - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n - 1) and at most $(n - 1) in an $n-by-$n matrix")))
elseif k < -1
fill!(M.dl,0)
fill!(M.d,0)
fill!(M.du,0)
fill!(M.dl, zero(T))
fill!(M.d, zero(T))
fill!(M.du, zero(T))
elseif k == -1
fill!(M.d,0)
fill!(M.du,0)
fill!(M.d, zero(T))
fill!(M.du, zero(T))
elseif k == 0
fill!(M.du,0)
fill!(M.du, zero(T))
end
return M
end

function triu!(M::Tridiagonal, k::Integer=0)
function triu!(M::Tridiagonal{T}, k::Integer=0) where T
n = length(M.d)
if !(-n + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n + 1) and at most $(n + 1) in an $n-by-$n matrix")))
elseif k > 1
fill!(M.dl,0)
fill!(M.d,0)
fill!(M.du,0)
fill!(M.dl, zero(T))
fill!(M.d, zero(T))
fill!(M.du, zero(T))
elseif k == 1
fill!(M.dl,0)
fill!(M.d,0)
fill!(M.dl, zero(T))
fill!(M.d, zero(T))
elseif k == 0
fill!(M.dl,0)
fill!(M.dl, zero(T))
end
return M
end
Expand Down

0 comments on commit 806b128

Please sign in to comment.