Skip to content

Commit

Permalink
Provides rudimentary interconversions between special matrix types
Browse files Browse the repository at this point in the history
convert() now interconverts Diagonal, Bidiagonal, SymTridiagonal,
Triangular and Matrix using naïve methods.

Moves convert() out of diagonal.jl
  • Loading branch information
jiahao committed Jan 10, 2014
1 parent a17879e commit 5e3f074
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 4 deletions.
4 changes: 0 additions & 4 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ Diagonal(A::Matrix) = Diagonal(diag(A))
size(D::Diagonal) = (length(D.diag),length(D.diag))
size(D::Diagonal,d::Integer) = d<1 ? error("dimension out of range") : (d<=2 ? length(D.diag) : 1)

convert{T}(::Type{Matrix{T}}, D::Diagonal{T}) = diagm(D.diag)
convert{T}(::Type{SymTridiagonal{T}}, D::Diagonal{T}) = SymTridiagonal(D.diag,zeros(T,length(D.diag)-1))
convert{T}(::Type{Tridiagonal{T}}, D::Diagonal{T}) = Tridiagonal(zeros(T,length(D.diag)-1),D.diag,zeros(T,length(D.diag)-1))

full(D::Diagonal) = diagm(D.diag)
getindex(D::Diagonal, i::Integer, j::Integer) = i == j ? D.diag[i] : zero(eltype(D.diag))

Expand Down
69 changes: 69 additions & 0 deletions base/linalg/special.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,74 @@
#Methods operating on different special matrix types

#Interconversion between special matrix types
import Base.convert
convert{T}(::Type{Bidiagonal}, A::Diagonal{T})=Bidiagonal(A.diag, zeros(T, size(A.diag,1)-1), true)
convert{T}(::Type{SymTridiagonal}, A::Diagonal{T})=SymTridiagonal(A.diag, zeros(T, size(A.diag,1)-1))
convert{T}(::Type{Tridiagonal}, A::Diagonal{T})=Tridiagonal(zeros(T, size(A.diag,1)-1), A.diag, zeros(T, size(A.diag,1)-1))
convert(::Type{Triangular}, A::Union(Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal))=Triangular(full(A))
convert(::Type{Matrix}, D::Diagonal) = diagm(D.diag)

function convert(::Type{Diagonal}, A::Union(Bidiagonal, SymTridiagonal))
all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as Diagonal"))
Diagonal(A.dv)
end

function convert(::Type{SymTridiagonal}, A::Bidiagonal)
all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as SymTridiagonal"))
SymTridiagonal(A.dv, A.ev)
end

convert{T}(::Type{Tridiagonal}, A::Bidiagonal{T})=Tridiagonal(A.isupper?zeros(T, size(A.dv,1)-1):A.ev, A.dv, A.isupper?A.ev:zeros(T, size(A.dv,1)-1))

function convert(::Type{Bidiagonal}, A::SymTridiagonal)
all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as Bidiagonal"))
Bidiagonal(A.dv, A.ev, true)
end

function convert(::Type{Diagonal}, A::Tridiagonal)
all(A.dl .== 0) && all(A.du .== 0) || throw(ArgumentError("Matrix cannot be represented as Diagonal"))
Diagonal(A.d)
end

function convert(::Type{Bidiagonal}, A::Tridiagonal)
if all(A.dl .== 0) return Bidiagonal(A.d, A.du, true)
elseif all(A.du .== 0) return Bidiagonal(A.d, A.dl, true)
else throw(ArgumentError("Matrix cannot be represented as Bidiagonal"))
end
end

function convert(::Type{SymTridiagonal}, A::Tridiagonal)
all(A.dl .== A.du) || throw(ArgumentError("Matrix cannot be represented as SymTridiagonal"))
SymTridiagonal(A.d, A.dl)
end

function convert(::Type{Diagonal}, A::Triangular)
full(A) == diagm(diag(A)) || throw(ArgumentError("Matrix cannot be represented as Diagonal"))
Diagonal(diag(A))
end

function convert(::Type{Bidiagonal}, A::Triangular)
fA = full(A)
if fA == diagm(diag(A)) + diagm(diag(fA, 1), 1)
return Bidiagonal(diag(A), diag(fA,1), true)
elseif fA == diagm(diag(A)) + diagm(diag(fA, -1), -1)
return Bidiagonal(diag(A), diag(fA,-1), true)
else
throw(ArgumentError("Matrix cannot be represented as Bidiagonal"))
end
end

convert(::Type{SymTridiagonal}, A::Triangular) = convert(SymTridiagonal, convert(Tridiagonal, A))

function convert(::Type{Tridiagonal}, A::Triangular)
fA = full(A)
if fA == diagm(diag(A)) + diagm(diag(fA, 1), 1) + diagm(diag(fA, -1), -1)
return Tridiagonal(diag(fA, -1), diag(A), diag(fA,1))
else
throw(ArgumentError("Matrix cannot be represented as Tridiagonal"))
end
end

#Constructs two method definitions taking into account (assumed) commutativity
# e.g. @commutative f{S,T}(x::S, y::T) = x+y is the same is defining
# f{S,T}(x::S, y::T) = x+y
Expand Down
35 changes: 35 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,41 @@ for relty in (Float16, Float32, Float64, BigFloat), elty in (relty, Complex{relt
end
end

#Test interconversion between special matrix types
using Base.Test

N=12
A=Diagonal([1:N]*1.0)
for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal, Triangular, Matrix]
@test full(convert(newtype, A)) == full(A)
end

for isupper in (true, false)
A=Bidiagonal([1:N]*1.0, [1:N-1]*1.0, isupper)
for newtype in [Bidiagonal, Tridiagonal, Triangular, Matrix]
@test full(convert(newtype, A)) == full(A)
end
A=Bidiagonal([1:N]*1.0, [1:N-1]*0.0, isupper) #morally Diagonal
for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal, Triangular, Matrix]
@test full(convert(newtype, A)) == full(A)
end
end

A=SymTridiagonal([1:N]*1.0, [1:N-1]*1.0)
for newtype in [Tridiagonal, Matrix]
@test full(convert(newtype, A)) == full(A)
end

A=Tridiagonal([1:N-1]*0.0, [1:N]*1.0, [1:N-1]*0.0) #morally Diagonal
for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Triangular, Matrix]
@test full(convert(newtype, A)) == full(A)
end

A=Triangular(full(Diagonal([1:N]*1.0))) #morally Diagonal
for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Triangular, Matrix]
@test full(convert(newtype, A)) == full(A)
end

# Test gglse
for elty in (Float32, Float64, Complex64, Complex128)
A = convert(Array{elty, 2}, [1 1 1 1; 1 3 1 1; 1 -1 3 1; 1 1 1 3; 1 1 1 -1])
Expand Down

0 comments on commit 5e3f074

Please sign in to comment.