Skip to content

Commit

Permalink
Rework some constructors (#28051)
Browse files Browse the repository at this point in the history
Aside from enforcing 1-indexing, these allow one to coerce some of the
field types via construction (without requiring that the inputs already
have those types).
  • Loading branch information
timholy authored and andreasnoack committed Jul 12, 2018
1 parent 0f4b15a commit e2de8c3
Show file tree
Hide file tree
Showing 15 changed files with 232 additions and 58 deletions.
25 changes: 15 additions & 10 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@ struct Bidiagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
dv::V # diagonal
ev::V # sub/super diagonal
uplo::Char # upper bidiagonal ('U') or lower ('L')
function Bidiagonal{T}(dv::V, ev::V, uplo::AbstractChar) where {T,V<:AbstractVector{T}}
function Bidiagonal{T,V}(dv, ev, uplo::AbstractChar) where {T,V<:AbstractVector{T}}
@assert !has_offset_axes(dv, ev)
if length(ev) != length(dv)-1
throw(DimensionMismatch("length of diagonal vector is $(length(dv)), length of off-diagonal vector is $(length(ev))"))
end
new{T,V}(dv, ev, uplo)
end
function Bidiagonal(dv::V, ev::V, uplo::AbstractChar) where {T,V<:AbstractVector{T}}
Bidiagonal{T}(dv, ev, uplo)
end
end
function Bidiagonal{T,V}(dv, ev, uplo::Symbol) where {T,V<:AbstractVector{T}}
Bidiagonal{T,V}(dv, ev, char_uplo(uplo))
end
function Bidiagonal{T}(dv::AbstractVector, ev::AbstractVector, uplo::Union{Symbol,AbstractChar}) where {T}
Bidiagonal(convert(AbstractVector{T}, dv)::AbstractVector{T},
convert(AbstractVector{T}, ev)::AbstractVector{T},
uplo)
end

"""
Expand Down Expand Up @@ -57,7 +62,10 @@ julia> Bl = Bidiagonal(dv, ev, :L) # ev is on the first subdiagonal
```
"""
function Bidiagonal(dv::V, ev::V, uplo::Symbol) where {T,V<:AbstractVector{T}}
Bidiagonal{T}(dv, ev, char_uplo(uplo))
Bidiagonal{T,V}(dv, ev, char_uplo(uplo))
end
function Bidiagonal(dv::V, ev::V, uplo::AbstractChar) where {T,V<:AbstractVector{T}}
Bidiagonal{T,V}(dv, ev, uplo)
end

"""
Expand Down Expand Up @@ -95,6 +103,8 @@ function Bidiagonal(A::AbstractMatrix, uplo::Symbol)
end

Bidiagonal(A::Bidiagonal) = A
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo)

function getindex(A::Bidiagonal{T}, i::Integer, j::Integer) where T
if !((1 <= i <= size(A,2)) && (1 <= j <= size(A,2)))
Expand Down Expand Up @@ -165,11 +175,6 @@ promote_rule(::Type{<:Tridiagonal{T}}, ::Type{<:Bidiagonal{S}}) where {T,S} =
@isdefined(T) && @isdefined(S) ? Tridiagonal{promote_type(T,S)} : Tridiagonal
promote_rule(::Type{<:Tridiagonal}, ::Type{<:Bidiagonal}) = Tridiagonal

# No-op for trivial conversion Bidiagonal{T} -> Bidiagonal{T}
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
# Convert Bidiagonal to Bidiagonal{T} by constructing a new instance with converted elements
Bidiagonal{T}(A::Bidiagonal) where {T} =
Bidiagonal(convert(AbstractVector{T}, A.dv), convert(AbstractVector{T}, A.ev), A.uplo)
# When asked to convert Bidiagonal to AbstractMatrix{T}, preserve structure by converting to Bidiagonal{T} <: AbstractMatrix{T}
AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A)

Expand Down
9 changes: 6 additions & 3 deletions stdlib/LinearAlgebra/src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ export lufact!
# also uncomment constructor tests in test/linalg/bidiag.jl
function Bidiagonal(dv::AbstractVector{T}, ev::AbstractVector{S}, uplo::Symbol) where {T,S}
depwarn(string("`Bidiagonal(dv::AbstractVector{T}, ev::AbstractVector{S}, uplo::Symbol) where {T, S}`",
" is deprecated, manually convert both vectors to the same type instead."), :Bidiagonal)
" is deprecated, use `Bidiagonal{R}(dv, ev, uplo)` or `Bidiagonal{R,V}(dv, ev, uplo)` instead,",
" or convert both vectors to the same type manually."), :Bidiagonal)
R = promote_type(T, S)
Bidiagonal(convert(Vector{R}, dv), convert(Vector{R}, ev), uplo)
end
Expand All @@ -73,7 +74,8 @@ end
# also uncomment constructor tests in test/linalg/tridiag.jl
function SymTridiagonal(dv::AbstractVector{T}, ev::AbstractVector{S}) where {T,S}
depwarn(string("`SymTridiagonal(dv::AbstractVector{T}, ev::AbstractVector{S}) ",
"where {T, S}` is deprecated, convert both vectors to the same type instead."), :SymTridiagonal)
"where {T, S}` is deprecated, use `SymTridiagonal{R}(dv, ev)` or `SymTridiagonal{R,V}(dv, ev)` instead,",
" or convert both vectors to the same type manually."), :SymTridiagonal)
R = promote_type(T, S)
SymTridiagonal(convert(Vector{R}, dv), convert(Vector{R}, ev))
end
Expand All @@ -82,7 +84,8 @@ end
# also uncomment constructor tests in test/linalg/tridiag.jl
function Tridiagonal(dl::AbstractVector{Tl}, d::AbstractVector{Td}, du::AbstractVector{Tu}) where {Tl,Td,Tu}
depwarn(string("`Tridiagonal(dl::AbstractVector{Tl}, d::AbstractVector{Td}, du::AbstractVector{Tu}) ",
"where {Tl, Td, Tu}` is deprecated, convert all vectors to the same type instead."), :Tridiagonal)
"where {Tl, Td, Tu}` is deprecated, use `Tridiagonal{T}(dl, d, du)` or `Tridiagonal{T,V}(dl, d, du)` instead,",
" or convert all three vectors to the same type manually."), :Tridiagonal)
Tridiagonal(map(v->convert(Vector{promote_type(Tl,Td,Tu)}, v), (dl, d, du))...)
end

Expand Down
15 changes: 11 additions & 4 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@

struct Diagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
diag::V

function Diagonal{T,V}(diag) where {T,V<:AbstractVector{T}}
@assert !has_offset_axes(diag)
new{T,V}(diag)
end
end
Diagonal(v::AbstractVector{T}) where {T} = Diagonal{T,typeof(v)}(v)
Diagonal{T}(v::AbstractVector) where {T} = Diagonal(convert(AbstractVector{T}, v)::AbstractVector{T})

"""
Diagonal(A::AbstractMatrix)
Expand Down Expand Up @@ -47,11 +55,10 @@ julia> Diagonal(V)
"""
Diagonal(V::AbstractVector)

Diagonal{T}(V::AbstractVector{T}) where {T} = Diagonal{T,typeof(V)}(V)
Diagonal{T}(V::AbstractVector) where {T} = Diagonal{T}(convert(AbstractVector{T}, V))

Diagonal(D::Diagonal) = D
Diagonal{T}(D::Diagonal{T}) where {T} = D
Diagonal{T}(D::Diagonal) where {T} = Diagonal{T}(convert(AbstractVector{T}, D.diag))
Diagonal{T}(D::Diagonal) where {T} = Diagonal{T}(D.diag)

AbstractMatrix{T}(D::Diagonal) where {T} = Diagonal{T}(D)
Matrix(D::Diagonal) = diagm(0 => D.diag)
Array(D::Diagonal) = Matrix(D)
Expand Down
18 changes: 14 additions & 4 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct Hessenberg{T,S<:AbstractMatrix} <: Factorization{T}
struct Hessenberg{T,S<:AbstractMatrix{T}} <: Factorization{T}
factors::S
τ::Vector{T}
Hessenberg{T,S}(factors::AbstractMatrix{T}, τ::Vector{T}) where {T,S<:AbstractMatrix} =
new(factors, τ)

function Hessenberg{T,S}(factors, τ) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(factors, τ)
new{T,S}(factors, τ)
end
end
Hessenberg(factors::AbstractMatrix{T}, τ::Vector{T}) where {T} = Hessenberg{T,typeof(factors)}(factors, τ)
function Hessenberg{T}(factors::AbstractMatrix, τ::AbstractVector) where {T}
Hessenberg(convert(AbstractMatrix{T}, factors), convert(Vector{T}, v))
end

Hessenberg(A::StridedMatrix) = Hessenberg(LAPACK.gehrd!(A)...)

# iteration for destructuring into components
Expand Down Expand Up @@ -63,7 +70,10 @@ hessenberg(A::StridedMatrix{T}) where T =
struct HessenbergQ{T,S<:AbstractMatrix} <: AbstractMatrix{T}
factors::S
τ::Vector{T}
HessenbergQ{T,S}(factors::AbstractMatrix{T}, τ::Vector{T}) where {T,S<:AbstractMatrix} = new(factors, τ)
function HessenbergQ{T,S}(factors, τ) where {T,S<:AbstractMatrix}
@assert !has_offset_axes(factors)
new(factors, τ)
end
end
HessenbergQ(factors::AbstractMatrix{T}, τ::Vector{T}) where {T} = HessenbergQ{T,typeof(factors)}(factors, τ)
HessenbergQ(A::Hessenberg) = HessenbergQ(A.factors, A.τ)
Expand Down
19 changes: 13 additions & 6 deletions stdlib/LinearAlgebra/src/ldlt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct LDLt{T,S<:AbstractMatrix} <: Factorization{T}
struct LDLt{T,S<:AbstractMatrix{T}} <: Factorization{T}
data::S

function LDLt{T,S}(data) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(data)
new{T,S}(data)
end
end
LDLt(data::AbstractMatrix{T}) where {T} = LDLt{T,typeof(data)}(data)
LDLt{T}(data::AbstractMatrix) where {T} = LDLt(convert(AbstractMatrix{T}, data)::AbstractMatrix{T})

size(S::LDLt) = size(S.data)
size(S::LDLt, i::Integer) = size(S.data, i)

LDLt{T,S}(F::LDLt) where {T,S<:AbstractMatrix} = LDLt{T,S}(convert(S, F.data))
# NOTE: the annotation <:AbstractMatrix shouldn't be necessary, it is introduced
# to avoid an ambiguity warning (see issue #6383)
LDLt{T}(F::LDLt{S,U}) where {T,S,U<:AbstractMatrix} = LDLt{T,U}(F)
LDLt{T,S}(F::LDLt{T,S}) where {T,S<:AbstractMatrix{T}} = F
LDLt{T,S}(F::LDLt) where {T,S<:AbstractMatrix{T}} = LDLt{T,S}(convert(S, F.data)::S)
LDLt{T}(F::LDLt{T}) where {T} = F
LDLt{T}(F::LDLt) where {T} = LDLt(convert(AbstractMatrix{T}, F.data)::AbstractMatrix{T})

Factorization{T}(F::LDLt{T}) where {T} = F
Factorization{T}(F::LDLt{S,U}) where {T,S,U} = LDLt{T,U}(F)
Factorization{T}(F::LDLt) where {T} = LDLt{T}(F)

# SymTridiagonal
"""
Expand Down
11 changes: 9 additions & 2 deletions stdlib/LinearAlgebra/src/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@

# LQ Factorizations

struct LQ{T,S<:AbstractMatrix} <: Factorization{T}
struct LQ{T,S<:AbstractMatrix{T}} <: Factorization{T}
factors::S
τ::Vector{T}
LQ{T,S}(factors::AbstractMatrix{T}, τ::Vector{T}) where {T,S<:AbstractMatrix} = new(factors, τ)

function LQ{T,S}(factors, τ) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(factors)
new{T,S}(factors, τ)
end
end
LQ(factors::AbstractMatrix{T}, τ::Vector{T}) where {T} = LQ{T,typeof(factors)}(factors, τ)
function LQ{T}(factors::AbstractMatrix, τ::AbstractVector) where {T}
LQ(convert(AbstractMatrix{T}, factors), convert(Vector{T}, τ))
end

# iteration for destructuring into components
Base.iterate(S::LQ) = (S.L, Val(:Q))
Expand Down
17 changes: 14 additions & 3 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
####################
# LU Factorization #
####################
struct LU{T,S<:AbstractMatrix} <: Factorization{T}
struct LU{T,S<:AbstractMatrix{T}} <: Factorization{T}
factors::S
ipiv::Vector{BlasInt}
info::BlasInt
LU{T,S}(factors::AbstractMatrix{T}, ipiv::Vector{BlasInt}, info::BlasInt) where {T,S} = new(factors, ipiv, info)

function LU{T,S}(factors, ipiv, info) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(factors)
new{T,S}(factors, ipiv, info)
end
end
function LU(factors::AbstractMatrix{T}, ipiv::Vector{BlasInt}, info::BlasInt) where {T}
LU{T,typeof(factors)}(factors, ipiv, info)
end
function LU{T}(factors::AbstractMatrix, ipiv::AbstractVector{<:Integer}, info::Integer) where {T}
LU(convert(AbstractMatrix{T}, factors),
convert(Vector{BlasInt}, ipiv),
BlasInt(info))
end
LU(factors::AbstractMatrix{T}, ipiv::Vector{BlasInt}, info::BlasInt) where {T} = LU{T,typeof(factors)}(factors, ipiv, info)

# iteration for destructuring into components
Base.iterate(S::LU) = (S.L, Val(:U))
Expand Down
60 changes: 48 additions & 12 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,19 @@ The object has two fields:
* `τ` is a vector of length `min(m,n)` containing the coefficients ``\tau_i``.
"""
struct QR{T,S<:AbstractMatrix} <: Factorization{T}
struct QR{T,S<:AbstractMatrix{T}} <: Factorization{T}
factors::S
τ::Vector{T}
QR{T,S}(factors::AbstractMatrix{T}, τ::Vector{T}) where {T,S<:AbstractMatrix} = new(factors, τ)

function QR{T,S}(factors, τ) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(factors)
new{T,S}(factors, τ)
end
end
QR(factors::AbstractMatrix{T}, τ::Vector{T}) where {T} = QR{T,typeof(factors)}(factors, τ)
function QR{T}(factors::AbstractMatrix, τ::AbstractVector) where {T}
QR(convert(AbstractMatrix{T}, factors), convert(Vector{T}, τ))
end

# iteration for destructuring into components
Base.iterate(S::QR) = (S.Q, Val(:R))
Expand Down Expand Up @@ -94,12 +101,19 @@ The object has two fields:
[^Schreiber1989]: R Schreiber and C Van Loan, "A storage-efficient WY representation for products of Householder transformations", SIAM J Sci Stat Comput 10 (1989), 53-57. [doi:10.1137/0910005](https://doi.org/10.1137/0910005)
"""
struct QRCompactWY{S,M<:AbstractMatrix} <: Factorization{S}
struct QRCompactWY{S,M<:AbstractMatrix{S}} <: Factorization{S}
factors::M
T::Matrix{S}
QRCompactWY{S,M}(factors::AbstractMatrix{S}, T::AbstractMatrix{S}) where {S,M<:AbstractMatrix} = new(factors, T)

function QRCompactWY{S,M}(factors, T) where {S,M<:AbstractMatrix{S}}
@assert !has_offset_axes(factors)
new{S,M}(factors, T)
end
end
QRCompactWY(factors::AbstractMatrix{S}, T::Matrix{S}) where {S} = QRCompactWY{S,typeof(factors)}(factors, T)
function QRCompactWY{S}(factors::AbstractMatrix, T::AbstractMatrix) where {S}
QRCompactWY(convert(AbstractMatrix{S}, factors), convert(Matrix{S}, T))
end
QRCompactWY(factors::AbstractMatrix{S}, T::AbstractMatrix{S}) where {S} = QRCompactWY{S,typeof(factors)}(factors, T)

# iteration for destructuring into components
Base.iterate(S::QRCompactWY) = (S.Q, Val(:R))
Expand Down Expand Up @@ -139,15 +153,23 @@ The object has three fields:
* `jpvt` is an integer vector of length `n` corresponding to the permutation ``P``.
"""
struct QRPivoted{T,S<:AbstractMatrix} <: Factorization{T}
struct QRPivoted{T,S<:AbstractMatrix{T}} <: Factorization{T}
factors::S
τ::Vector{T}
jpvt::Vector{BlasInt}
QRPivoted{T,S}(factors::AbstractMatrix{T}, τ::Vector{T}, jpvt::Vector{BlasInt}) where {T,S<:AbstractMatrix} =
new(factors, τ, jpvt)

function QRPivoted{T,S}(factors, τ, jpvt) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(factors, τ, jpvt)
new{T,S}(factors, τ, jpvt)
end
end
QRPivoted(factors::AbstractMatrix{T}, τ::Vector{T}, jpvt::Vector{BlasInt}) where {T} =
QRPivoted{T,typeof(factors)}(factors, τ, jpvt)
function QRPivoted{T}(factors::AbstractMatrix, τ::AbstractVector, jpvt::AbstractVector) where {T}
QRPivoted(convert(AbstractMatrix{T}, factors),
convert(Vector{T}, τ),
convert(Vector{BlasInt}, jpvt))
end

# iteration for destructuring into components
Base.iterate(S::QRPivoted) = (S.Q, Val(:R))
Expand Down Expand Up @@ -435,25 +457,39 @@ abstract type AbstractQ{T} <: AbstractMatrix{T} end
The orthogonal/unitary ``Q`` matrix of a QR factorization stored in [`QR`](@ref) or
[`QRPivoted`](@ref) format.
"""
struct QRPackedQ{T,S<:AbstractMatrix} <: AbstractQ{T}
struct QRPackedQ{T,S<:AbstractMatrix{T}} <: AbstractQ{T}
factors::S
τ::Vector{T}
QRPackedQ{T,S}(factors::AbstractMatrix{T}, τ::Vector{T}) where {T,S<:AbstractMatrix} = new(factors, τ)

function QRPackedQ{T,S}(factors, τ) where {T,S<:AbstractMatrix{T}}
@assert !has_offset_axes(factors)
new{T,S}(factors, τ)
end
end
QRPackedQ(factors::AbstractMatrix{T}, τ::Vector{T}) where {T} = QRPackedQ{T,typeof(factors)}(factors, τ)
function QRPackedQ{T}(factors::AbstractMatrix, τ::AbstractVector) where {T}
QRPackedQ(convert(AbstractMatrix{T}, factors), convert(Vector{T}, τ))
end

"""
QRCompactWYQ <: AbstractMatrix
The orthogonal/unitary ``Q`` matrix of a QR factorization stored in [`QRCompactWY`](@ref)
format.
"""
struct QRCompactWYQ{S, M<:AbstractMatrix} <: AbstractQ{S}
struct QRCompactWYQ{S, M<:AbstractMatrix{S}} <: AbstractQ{S}
factors::M
T::Matrix{S}
QRCompactWYQ{S,M}(factors::AbstractMatrix{S}, T::Matrix{S}) where {S,M<:AbstractMatrix} = new(factors, T)

function QRCompactWYQ{S,M}(factors, T) where {S,M<:AbstractMatrix{S}}
@assert !has_offset_axes(factors)
new{S,M}(factors, T)
end
end
QRCompactWYQ(factors::AbstractMatrix{S}, T::Matrix{S}) where {S} = QRCompactWYQ{S,typeof(factors)}(factors, T)
function QRCompactWYQ{S}(factors::AbstractMatrix, T::AbstractMatrix) where {S}
QRCompactWYQ(convert(AbstractMatrix{S}, factors), convert(Matrix{S}, T))
end

QRPackedQ{T}(Q::QRPackedQ) where {T} = QRPackedQ(convert(AbstractMatrix{T}, Q.factors), convert(Vector{T}, Q.τ))
AbstractMatrix{T}(Q::QRPackedQ{T}) where {T} = Q
Expand Down
13 changes: 10 additions & 3 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Singular Value Decomposition
struct SVD{T,Tr,M<:AbstractArray} <: Factorization{T}
struct SVD{T,Tr,M<:AbstractArray{T}} <: Factorization{T}
U::M
S::Vector{Tr}
Vt::M
SVD{T,Tr,M}(U::AbstractArray{T}, S::Vector{Tr}, Vt::AbstractArray{T}) where {T,Tr,M} =
new(U, S, Vt)
function SVD{T,Tr,M}(U, S, Vt) where {T,Tr,M<:AbstractArray{T}}
@assert !has_offset_axes(U, S, Vt)
new{T,Tr,M}(U, S, Vt)
end
end
SVD(U::AbstractArray{T}, S::Vector{Tr}, Vt::AbstractArray{T}) where {T,Tr} = SVD{T,Tr,typeof(U)}(U, S, Vt)
function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) where {T,Tr}
SVD(convert(AbstractArray{T}, U),
convert(Vector{Tr}, S),
convert(AbstractArray{T}, Vt))
end

# iteration for destructuring into components
Base.iterate(S::SVD) = (S.U, Val(:S))
Expand Down
Loading

0 comments on commit e2de8c3

Please sign in to comment.