Skip to content

Commit

Permalink
Fix edge cases in SymTridiagonal when ev has an extra element (`+…
Browse files Browse the repository at this point in the history
…`, `-`, `iszero`, `isone`, etc.) (JuliaLang#42472)
  • Loading branch information
mcognetta committed Oct 8, 2021
1 parent f2080d5 commit 146de38
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 29 deletions.
7 changes: 6 additions & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ end

@noinline throw_uplo() = throw(ArgumentError("uplo argument must be either :U (upper) or :L (lower)"))


"""
ldiv!(Y, A, B) -> Y
Expand Down Expand Up @@ -453,6 +452,12 @@ export ⋅, ×
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X

# SymTridiagonal ev can be the same length as dv, but the last element is
# ignored. However, some methods can fail if they read the entired ev
# rather than just the meaningful elements. This is a helper function
# for getting only the meaningful elements of ev. See #41089
_evview(S::SymTridiagonal) = @view S.ev[begin:length(S.dv) - 1]

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))
Expand Down
31 changes: 17 additions & 14 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Tridiagonal(A::Bidiagonal) =

# conversions from SymTridiagonal to other special matrix types
Diagonal(A::SymTridiagonal) = Diagonal(A.dv)

# These can fail when ev has the same length as dv
# TODO: Revisit when a good solution for #42477 is found
Bidiagonal(A::SymTridiagonal) =
iszero(A.ev) ? Bidiagonal(A.dv, A.ev, :U) :
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
Expand Down Expand Up @@ -154,10 +157,10 @@ end

# this set doesn't have the aforementioned problem

+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev)
-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.ev, A.d-B.dv, A.du-B.ev)
+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev+B.dl, A.dv+B.d, A.ev+B.du)
-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev-B.dl, A.dv-B.d, A.ev-B.du)
+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+_evview(B), A.d+B.dv, A.du+_evview(B))
-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-_evview(B), A.d-B.dv, A.du-_evview(B))
+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)+B.dl, A.dv+B.d, _evview(A)+B.du)
-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)-B.dl, A.dv-B.d, _evview(A)-B.du)


function (+)(A::Diagonal, B::Tridiagonal)
Expand Down Expand Up @@ -202,22 +205,22 @@ end

function (+)(A::Bidiagonal, B::SymTridiagonal)
newdv = A.dv+B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), A.dv+B.dv, A.ev+B.ev) : (A.ev+B.ev, A.dv+B.dv, typeof(newdv)(B.ev)))...)
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(_evview(B)), A.dv+B.dv, A.ev+_evview(B)) : (A.ev+_evview(B), A.dv+B.dv, typeof(newdv)(_evview(B))))...)
end

function (-)(A::Bidiagonal, B::SymTridiagonal)
newdv = A.dv-B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(-B.ev)))...)
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-_evview(B)), newdv, A.ev-_evview(B)) : (A.ev-_evview(B), newdv, typeof(newdv)(-_evview(B))))...)
end

function (+)(A::SymTridiagonal, B::Bidiagonal)
newdv = A.dv+B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev+B.ev) : (A.ev+B.ev, newdv, typeof(newdv)(A.ev)))...)
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)+B.ev) : (_evview(A)+B.ev, newdv, typeof(newdv)(_evview(A))))...)
end

function (-)(A::SymTridiagonal, B::Bidiagonal)
newdv = A.dv-B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(A.ev)))...)
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)-B.ev) : (_evview(A)-B.ev, newdv, typeof(newdv)(_evview(A))))...)
end

# fixing uniform scaling problems from #28994
Expand Down Expand Up @@ -312,16 +315,16 @@ one(D::Diagonal) = Diagonal(one.(D.diag))
one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo)
one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T))))
one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))))
# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

zero(D::Diagonal) = Diagonal(zero.(D.diag))
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))

# SymTridiagonal and Bidiagonal have the same field names
==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv
==(B::Bidiagonal, A::Diagonal) = A == B
# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

==(A::Diagonal, B::Bidiagonal) = iszero(B.ev) && A.diag == B.dv
==(A::Diagonal, B::SymTridiagonal) = iszero(_evview(B)) && A.diag == B.dv
==(B::Bidiagonal, A::Diagonal) = A == B
==(A::Diagonal, B::Tridiagonal) = iszero(B.dl) && iszero(B.du) && A.diag == B.d
==(B::Tridiagonal, A::Diagonal) = A == B

Expand All @@ -334,5 +337,5 @@ function ==(A::Bidiagonal, B::Tridiagonal)
end
==(B::Tridiagonal, A::Bidiagonal) = A == B

==(A::Bidiagonal, B::SymTridiagonal) = iszero(B.ev) && iszero(A.ev) && A.dv == B.dv
==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(B)) && iszero(A.ev) && A.dv == B.dv
==(B::SymTridiagonal, A::Bidiagonal) = A == B
29 changes: 15 additions & 14 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T
# similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...)

copyto!(dest::SymTridiagonal, src::SymTridiagonal) =
(copyto!(dest.dv, src.dv); copyto!(dest.ev, src.ev); dest)
(copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest)

#Elementary operations
for func in (:conj, :copy, :real, :imag)
Expand All @@ -172,7 +172,7 @@ adjoint(S::SymTridiagonal{<:Real}) = S
adjoint(S::SymTridiagonal) = Adjoint(S)
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)

ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(@view S.ev[begin:length(S.dv) - 1])
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
issymmetric(S::SymTridiagonal) = true

function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
Expand All @@ -182,7 +182,7 @@ function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
if absn == 0
return copyto!(similar(M.dv, length(M.dv)), M.dv)
elseif absn == 1
return copyto!(similar(M.ev, length(M.ev)), M.ev)
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
elseif absn <= size(M,1)
return fill!(similar(M.dv, size(M,1)-absn), 0)
else
Expand All @@ -196,9 +196,9 @@ function diag(M::SymTridiagonal, n::Integer=0)
if n == 0
return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U))
elseif n == 1
return copyto!(similar(M.ev, length(M.ev)), M.ev)
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
elseif n == -1
return copyto!(similar(M.ev, length(M.ev)), transpose.(M.ev))
return copyto!(similar(M.ev, length(M.dv)-1), transpose.(_evview(M)))
elseif n <= size(M,1)
throw(ArgumentError("requested diagonal contains undefined zeros of an array type"))
else
Expand All @@ -207,14 +207,14 @@ function diag(M::SymTridiagonal, n::Integer=0)
end
end

+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, A.ev+B.ev)
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev)
+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, _evview(A)+_evview(B))
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, _evview(A)-_evview(B))
-(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev)
*(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B)
*(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev)
/(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B)
\(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev)
==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (A.ev==B.ev)
==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (_evview(A)==_evview(B))

@inline mul!(A::StridedVecOrMat, B::SymTridiagonal, C::StridedVecOrMat,
alpha::Number, beta::Number) =
Expand Down Expand Up @@ -359,21 +359,22 @@ function svdvals!(A::SymTridiagonal)
return sort!(map!(abs, vals, vals); rev=true)
end

#tril and triu
# tril and triu

function istriu(M::SymTridiagonal, k::Integer=0)
if k <= -1
return true
elseif k == 0
return iszero(M.ev)
return iszero(_evview(M))
else # k >= 1
return iszero(M.ev) && iszero(M.dv)
return iszero(_evview(M)) && iszero(M.dv)
end
end
istril(M::SymTridiagonal, k::Integer) = istriu(M, -k)
iszero(M::SymTridiagonal) = iszero(M.ev) && iszero(M.dv)
isone(M::SymTridiagonal) = iszero(M.ev) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = iszero(M.ev)
iszero(M::SymTridiagonal) = iszero(_evview(M)) && iszero(M.dv)
isone(M::SymTridiagonal) = iszero(_evview(M)) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = iszero(_evview(M))


function tril!(M::SymTridiagonal, k::Integer=0)
n = length(M.dv)
Expand Down
23 changes: 23 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,4 +450,27 @@ end
@test A*Sym A*Matrix(Sym)
end

@testset "Ops on SymTridiagonal ev has the same length as dv" begin
x = rand(3)
y = rand(3)
z = rand(2)

S = SymTridiagonal(x, y)
T = Tridiagonal(z, x, z)
Bu = Bidiagonal(x, z, :U)
Bl = Bidiagonal(x, z, :L)

Ms = Matrix(S)
Mt = Matrix(T)
Mbu = Matrix(Bu)
Mbl = Matrix(Bl)

@test S + T Ms + Mt
@test T + S Mt + Ms
@test S + Bu Ms + Mbu
@test Bu + S Mbu + Ms
@test S + Bl Ms + Mbl
@test Bl + S Mbl + Ms
end

end # module TestSpecial
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ end
@test !isdiag(Tridiagonal(dl,d,zerosdu))
@test !isdiag(Tridiagonal(zerosdl,d,du))
@test !isdiag(Tridiagonal(dl,d,du))

# Test methods that could fail due to dv and ev having the same length
# see #41089

badev = zero(d)
badev[end] = 1
S = SymTridiagonal(d, badev)

@test istriu(S, -2)
@test istriu(S, 0)
@test !istriu(S, 2)

@test isdiag(S)
end

@testset "iszero and isone" begin
Expand All @@ -190,6 +203,12 @@ end
@test isone(Sone)
@test !iszero(Smix)
@test !isone(Smix)

badev = zeros(elty, 3)
badev[end] = 1

@test isone(SymTridiagonal(ones(elty, 3), badev))
@test iszero(SymTridiagonal(zeros(elty, 3), badev))
end

@testset for mat_type in (Tridiagonal, SymTridiagonal)
Expand Down

0 comments on commit 146de38

Please sign in to comment.