-
Notifications
You must be signed in to change notification settings - Fork 249
/
16_gp_fast_f.stan
87 lines (81 loc) · 1.98 KB
/
16_gp_fast_f.stan
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
functions {
  /**
   * Draw latent GP function values at new inputs x2, conditioned on the
   * observations y1 at inputs x1, under a squared-exponential kernel.
   *
   * @param x2    prediction inputs (length N2)
   * @param y1    observed outputs at x1 (length N1)
   * @param x1    training inputs (length N1)
   * @param alpha kernel magnitude passed to gp_exp_quad_cov
   * @param rho   kernel length-scale passed to gp_exp_quad_cov
   * @param sigma observation noise sd; sigma^2 is added to the diagonal of
   *              the training covariance
   * @param delta jitter added to the diagonal of the predictive covariance
   * @return one multivariate-normal draw of f at x2 (length N2)
   */
  vector gp_pred_rng(array[] real x2,
                     vector y1,
                     array[] real x1,
                     real alpha,
                     real rho,
                     real sigma,
                     real delta) {
    int N1 = rows(y1);
    int N2 = size(x2);

    // Cholesky factor of the noisy training covariance K(x1, x1) + sigma^2 I.
    // gp_exp_quad_cov replaces the removed cov_exp_quad (same arguments).
    matrix[N1, N1] L_K = cholesky_decompose(
        add_diag(gp_exp_quad_cov(x1, alpha, rho), square(sigma)));

    // K^{-1} y1 via two triangular solves: L z = y1, then (z' L^{-1})' = L'^{-1} z.
    vector[N1] K_div_y1 = mdivide_left_tri_low(L_K, y1);
    K_div_y1 = mdivide_right_tri_low(K_div_y1', L_K)';

    // Cross-covariance between training and prediction inputs.
    matrix[N1, N2] k_x1_x2 = gp_exp_quad_cov(x1, x2, alpha, rho);

    // Predictive mean: k(x2, x1) K^{-1} y1.
    vector[N2] f2_mu = k_x1_x2' * K_div_y1;

    // v = L^{-1} k(x1, x2), so crossprod(v) = v' v = k(x2, x1) K^{-1} k(x1, x2).
    matrix[N1, N2] v_pred = mdivide_left_tri_low(L_K, k_x1_x2);
    matrix[N2, N2] cov_f2 = gp_exp_quad_cov(x2, alpha, rho) - crossprod(v_pred);

    // Small diagonal jitter keeps the predictive covariance numerically
    // positive definite for the multivariate-normal draw.
    return multi_normal_rng(f2_mu, add_diag(cov_f2, delta));
  }
}
data {
  int<lower=1> N1;          // number of training points
  array[N1] real x1;        // training inputs
  vector[N1] y1;            // training outputs
  int<lower=1> N2;          // number of prediction points
  array[N2] real x2;        // prediction inputs
  // "Fake pars": GP hyperparameters supplied as fixed data instead of being
  // estimated. They feed the kernel (magnitude, length-scale) and serve as
  // the noise sd in normal_rng. Negative values are therefore rejected here,
  // when the data is read, rather than deep inside those calls.
  real<lower=0> alpha;      // kernel magnitude
  real<lower=0> rho;        // kernel length-scale
  real<lower=0> sigma;      // observation noise standard deviation
}
transformed data {
  // Zero mean vector for the GP over the N1 training inputs.
  vector[N1] mu = zeros_vector(N1);
  // Jitter passed to gp_pred_rng and added to the predictive covariance diagonal.
  real delta = 1e-9;
}
parameters {
// The GP hyperparameters are supplied as data in this program, so their
// parameter declarations are left commented out.
//real<lower=0> rho;
//real<lower=0> alpha;
// sigmax appears only in its std_normal prior in the model block and nowhere
// else. Presumably it is a placeholder so the program has at least one
// parameter to sample (confirm).
real<lower=0> sigmax;
}
model {
  // Cholesky factor of the noisy training covariance K(x1, x1) + sigma^2 I.
  // gp_exp_quad_cov replaces the removed cov_exp_quad (same arguments), and
  // add_diag replaces the manual diagonal loop.
  matrix[N1, N1] L_K = cholesky_decompose(
      add_diag(gp_exp_quad_cov(x1, alpha, rho), square(sigma)));

  // Hyperparameter priors, kept commented out together with the
  // commented-out rho/alpha parameter declarations.
  //rho ~ inv_gamma(5, 5);
  //alpha ~ std_normal();
  sigmax ~ std_normal();

  // NOTE: alpha, rho and sigma are data here, so this term does not depend
  // on any parameter. It becomes informative only if rho/alpha are restored
  // as parameters.
  y1 ~ multi_normal_cholesky(mu, L_K);
}
generated quantities {
  // Latent function values at x2, conditioned on (x1, y1) with the fixed
  // hyperparameters.
  vector[N2] f2 = gp_pred_rng(x2, y1, x1, alpha, rho, sigma, delta);
  // Noisy predictions: each element adds Gaussian noise with sd sigma.
  vector[N2] y2;
  for (i in 1:N2)
    y2[i] = normal_rng(f2[i], sigma);
}