forked from JuliaArrays/StaticArrays.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SHermitianCompact.jl
262 lines (227 loc) · 9.16 KB
/
SHermitianCompact.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""
SHermitianCompact{N, T, L} <: StaticMatrix{N, N, T}
A [`StaticArray`](@ref) subtype that represents a Hermitian matrix. Unlike
`LinearAlgebra.Hermitian`, `SHermitianCompact` stores only the lower triangle
of the matrix (as an `SVector`). The lower triangle is stored in column-major order.
For example, for an `SHermitianCompact{3}`, the indices of the stored elements
can be visualized as follows:
```
┌ 1 ⋅ ⋅ ┐
| 2 4 ⋅ |
└ 3 5 6 ┘
```
Type parameters:
* `N`: matrix dimension;
* `T`: element type for lower triangle;
* `L`: length of the `SVector` storing the lower triangular elements.
Note that `L` is always the `N`th [triangular number](https://en.wikipedia.org/wiki/Triangular_number).
An `SHermitianCompact` may be constructed either:
* from an `AbstractVector` containing the lower triangular elements; or
* from a `Tuple` containing both upper and lower triangular elements in column major order; or
* from another `StaticMatrix`.
For the latter two cases, only the lower triangular elements are used; the upper triangular
elements are ignored.
"""
struct SHermitianCompact{N, T, L} <: StaticMatrix{N, N, T}
lowertriangle::SVector{L, T}
@inline function SHermitianCompact{N, T, L}(lowertriangle::SVector{L}) where {N, T, L}
_check_hermitian_parameters(Val(N), Val(L))
new{N, T, L}(lowertriangle)
end
end
@inline function _check_hermitian_parameters(::Val{N}, ::Val{L}) where {N, L}
if 2 * L !== N * (N + 1)
throw(ArgumentError("Size mismatch in SHermitianCompact parameters. Got dimension $N and length $L."))
end
end
triangularnumber(N::Int) = div(N * (N + 1), 2)
@generated function triangularroot(::Val{L}) where {L}
return div(isqrt(8 * L + 1) - 1, 2) # from quadratic formula
end
lowertriangletype(::Type{SHermitianCompact{N, T, L}}) where {N, T, L} = SVector{L, T}
lowertriangletype(::Type{SHermitianCompact{N, T}}) where {N, T} = SVector{triangularnumber(N), T}
lowertriangletype(::Type{SHermitianCompact{N}}) where {N} = SVector{triangularnumber(N)}
@inline SHermitianCompact{N, T}(lowertriangle::SVector{L}) where {N, T, L} = SHermitianCompact{N, T, L}(lowertriangle)
@inline SHermitianCompact{N}(lowertriangle::SVector{L, T}) where {N, T, L} = SHermitianCompact{N, T, L}(lowertriangle)
@inline function SHermitianCompact(lowertriangle::SVector{L, T}) where {T, L}
N = triangularroot(Val(L))
SHermitianCompact{N, T, L}(lowertriangle)
end
@generated function SHermitianCompact{N, T, L}(a::Tuple) where {N, T, L}
_check_hermitian_parameters(Val(N), Val(L))
expr = Vector{Expr}(undef, L)
i = 0
for col = 1 : N, row = col : N
index = N * (col - 1) + row
expr[i += 1] = :(a[$index])
end
quote
@_inline_meta
@inbounds return SHermitianCompact{N, T, L}(SVector{L, T}(tuple($(expr...))))
end
end
@inline function SHermitianCompact{N, T}(a::Tuple) where {N, T}
L = triangularnumber(N)
SHermitianCompact{N, T, L}(a)
end
@inline (::Type{SSC})(a::SHermitianCompact) where {SSC <: SHermitianCompact} = SSC(a.lowertriangle)
@inline (::Type{SSC})(a::AbstractVector) where {SSC <: SHermitianCompact} = SSC(convert(lowertriangletype(SSC), a))
# disambiguation
@inline (::Type{SSC})(a::StaticArray{<:Tuple,<:Any,1}) where {SSC <: SHermitianCompact} = SSC(convert(SVector, a))
@generated function _hermitian_compact_indices(::Val{N}) where N
# Returns a Tuple{Pair{Int, Bool}} I such that for linear index i,
# * I[i][1] is the index into the lowertriangle field of an SHermitianCompact{N};
# * I[i][2] is true iff i is an index into the lower triangle of an N × N matrix.
indexmat = Matrix{Pair{Int, Bool}}(undef, N, N)
i = 0
for col = 1 : N, row = 1 : N
indexmat[row, col] = if row >= col
(i += 1) => true
else
indexmat[col, row][1] => false
end
end
quote
@_inline_meta
return $(tuple(indexmat...))
end
end
Base.@propagate_inbounds function Base.getindex(a::SHermitianCompact{N}, i::Int) where {N}
I = _hermitian_compact_indices(Val(N))
j, lower = I[i]
@inbounds value = a.lowertriangle[j]
return lower ? value : adjoint(value)
end
Base.@propagate_inbounds function Base.setindex(a::SHermitianCompact{N, T, L}, x, i::Int) where {N, T, L}
I = _hermitian_compact_indices(Val(N))
j, lower = I[i]
value = lower ? x : adjoint(x)
return SHermitianCompact{N}(setindex(a.lowertriangle, value, j))
end
# needed because it is used in convert.jl and the generic fallback is slow
@generated function Base.Tuple(a::SHermitianCompact{N}) where N
exprs = [:(a[$i]) for i = 1 : N^2]
quote
@_inline_meta
tuple($(exprs...))
end
end
LinearAlgebra.ishermitian(a::SHermitianCompact) = true
LinearAlgebra.issymmetric(a::SHermitianCompact) = true
# TODO: factorize?
@inline Base.:(==)(a::SHermitianCompact, b::SHermitianCompact) = a.lowertriangle == b.lowertriangle
@generated function _map(f, a::SHermitianCompact...)
S = Size(a[1])
N = S[1]
L = triangularnumber(N)
exprs = Vector{Expr}(undef, L)
for i ∈ 1:L
tmp = [:(a[$j].lowertriangle[$i]) for j ∈ 1:length(a)]
exprs[i] = :(f($(tmp...)))
end
return quote
@_inline_meta
same_size(a...)
@inbounds return SHermitianCompact(SVector(tuple($(exprs...))))
end
end
@inline Base.:*(a::Number, b::SHermitianCompact) = SHermitianCompact(a * b.lowertriangle)
@inline Base.:*(a::SHermitianCompact, b::Number) = SHermitianCompact(a.lowertriangle * b)
@inline Base.:/(a::SHermitianCompact, b::Number) = SHermitianCompact(a.lowertriangle / b)
@inline Base.:\(a::Number, b::SHermitianCompact) = SHermitianCompact(a \ b.lowertriangle)
@generated function _plus_uniform(::Size{S}, a::SHermitianCompact{N, T, L}, λ) where {S, N, T, L}
@assert S[1] == N
@assert S[2] == N
exprs = Vector{Expr}(undef, L)
i = 0
for col = 1 : N, row = col : N
i += 1
exprs[i] = row == col ? :(a.lowertriangle[$i] + λ) : :(a.lowertriangle[$i])
end
return quote
@_inline_meta
R = promote_type(eltype(a), typeof(λ))
SHermitianCompact{N, R, L}(SVector{L, R}(tuple($(exprs...))))
end
end
@generated function LinearAlgebra.transpose(a::SHermitianCompact{N, T, L}) where {N, T, L}
# To conform with LinearAlgebra, the transpose should be recursive.
# For this compact representation of a Hermitian matrix, that means that
# we should recursively transpose of the diagonal elements, but only
# conjugate the off-diagonal elements:
# [A Bᴴ]ᵀ = [Aᵀ Bᵀ] = [Aᵀ Bᵀ]
# [B C ] [Bᴴᵀ Cᵀ] [conj(B) Cᵀ]
exprs = Vector{Expr}(undef, L)
i = 0
for col = 1 : N, row = col : N
i += 1
exprs[i] = row == col ? :(transpose(a.lowertriangle[$i])) : :(conj(a.lowertriangle[$i]))
end
return quote
@_inline_meta
SHermitianCompact{N}(SVector{L}(tuple($(exprs...))))
end
end
@generated function LinearAlgebra.adjoint(a::SHermitianCompact{N, T, L}) where {N, T, L}
# To conform with LinearAlgebra, the adjoint should be recursive.
# Taking the adjoint of a Hermitian matrix is the identity, but
# with this compact representation of a Hermitian matrix, care
# should be taken that only the diagonal elements should be
# recursively conjugate-transposed; the off-diagonal elements should
# not:
# [A Bᴴ]ᴴ = [Aᴴ Bᴴ] = [Aᴴ Bᴴ]
# [B C ] [Bᴴᴴ Cᴴ] [B Cᴴ]
exprs = Vector{Expr}(undef, L)
i = 0
for col = 1 : N, row = col : N
i += 1
exprs[i] = row == col ? :(adjoint(a.lowertriangle[$i])) : :(a.lowertriangle[$i])
end
return quote
@_inline_meta
SHermitianCompact{N}(SVector{L}(tuple($(exprs...))))
end
end
@generated function _one(::Size{S}, ::Type{SSC}) where {S, SSC <: SHermitianCompact}
N = S[1]
L = triangularnumber(N)
T = eltype(SSC)
if T == Any
T = Float64
end
exprs = Vector{Expr}(undef, L)
i = 0
for col = 1 : N, row = col : N
exprs[i += 1] = row == col ? :(one($T)) : :(zero($T))
end
quote
@_inline_meta
return SHermitianCompact(SVector(tuple($(exprs...))))
end
end
@inline _scalar_matrix(s::Size{S}, t::Type{SSC}) where {S, SSC <: SHermitianCompact} = _one(s, t)
# _fill covers fill, zeros, and ones:
@generated function _fill(val, ::Size{s}, ::Type{SSC}) where {s, SSC <: SHermitianCompact}
N = s[1]
L = triangularnumber(N)
v = [:val for i = 1:L]
return quote
@_inline_meta
$SSC(SVector(tuple($(v...))))
end
end
@generated function _rand(randfun, rng::AbstractRNG, ::Type{SSC}) where {N, SSC <: SHermitianCompact{N}}
T = eltype(SSC)
if T == Any
T = Float64
end
L = triangularnumber(N)
v = [:(randfun(rng, $T)) for i = 1:L]
return quote
@_inline_meta
$SSC(SVector(tuple($(v...))))
end
end
@inline Random.rand(rng::AbstractRNG, ::Type{SSC}) where {SSC <: SHermitianCompact} = _rand(rand, rng, SSC)
@inline Random.randn(rng::AbstractRNG, ::Type{SSC}) where {SSC <: SHermitianCompact} = _rand(randn, rng, SSC)
@inline Random.randexp(rng::AbstractRNG, ::Type{SSC}) where {SSC <: SHermitianCompact} = _rand(randexp, rng, SSC)