-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 7953ecb
Showing
67 changed files
with
4,290 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import numpy as np | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class Inference(nn.Module):
    """Exact GP inference: posterior prediction and negative log marginal
    likelihood, built on a jitter-stabilized Cholesky factor of the Gram
    (kernel + noise) matrix.

    The wrapped ``model`` is expected to expose ``mean``, ``kernel``,
    ``likelihood``, ``param_to_vec`` and ``vec_to_param``.
    """

    def __init__(self, train_data, model):
        """
        :param train_data: pair (train_x, train_y); train_y is presumably an
            (n, 1) column of targets -- TODO confirm against callers.
        :param model: GP model providing mean/kernel/likelihood modules.
        """
        super(Inference, self).__init__()
        self.model = model
        self.train_x = train_data[0]
        self.train_y = train_data[1]
        # Observed output range (stored but not used within this class).
        self.output_min = torch.min(self.train_y)
        self.output_max = torch.max(self.train_y)
        self.mean_vec = None   # train_y minus prior mean; set in gram_mat_update
        self.gram_mat = None   # K(train, train) + noise diagonal
        # cholesky is lower triangular matrix
        self.cholesky = None
        self.jitter = 0        # diagonal jitter actually used by the factorization

    def gram_mat_update(self, hyper=None):
        """Recompute ``mean_vec`` and ``gram_mat``; if ``hyper`` is given, load
        it into the model's hyperparameters first."""
        if hyper is not None:
            self.model.vec_to_param(hyper)

        self.mean_vec = self.train_y - self.model.mean(self.train_x.float())
        self.gram_mat = self.model.kernel(self.train_x) + torch.diag(self.model.likelihood(self.train_x.float()))

    def cholesky_update(self, hyper):
        """Refresh the Gram matrix for ``hyper`` and factorize it, retrying with
        increasing trace-scaled diagonal jitter until Cholesky succeeds.

        :raises RuntimeError: if factorization fails even at the largest jitter.
        """
        self.gram_mat_update(hyper)
        eye_mat = torch.diag(self.gram_mat.new_ones(self.gram_mat.size(0)))
        for jitter_const in [0, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3]:
            chol_jitter = torch.trace(self.gram_mat).item() * jitter_const
            try:
                # cholesky is lower triangular matrix
                self.cholesky = torch.cholesky(self.gram_mat + eye_mat * chol_jitter, upper=False)
                self.jitter = chol_jitter
                return
            except RuntimeError:
                pass  # not positive definite at this jitter level; try a larger one
        raise RuntimeError('Absolute entry values of Gram matrix are between %.4E~%.4E with trace %.4E' %
                           (torch.min(torch.abs(self.gram_mat)).item(), torch.max(torch.abs(self.gram_mat)).item(),
                            torch.trace(self.gram_mat).item()))

    def predict(self, pred_x, hyper=None, verbose=False, compute_grad = False):
        """Posterior mean and variance at ``pred_x`` (variances clamped to >= 1e-8).

        :param pred_x: prediction inputs; rows are points.
        :param hyper: optional hyperparameter vector used for this call only;
            the previously active hyperparameters are restored afterwards.
        :param verbose: additionally return two numeric-stability flags.
        :param compute_grad: additionally return gradients of the predictive
            mean and variance w.r.t. ``pred_x`` via the kernel's ``grad``.
        """
        if hyper is not None:
            param_original = self.model.param_to_vec()
            self.cholesky_update(hyper)

        k_pred_train = self.model.kernel(pred_x, self.train_x)
        k_pred = self.model.kernel(pred_x, diagonal=True)

        # Solve L X = [K(train, pred) | mean_vec] in a single triangular solve;
        # cholesky is lower triangular matrix
        chol_solver = torch.triangular_solve(torch.cat([k_pred_train.t(), self.mean_vec], 1), self.cholesky, upper=False)[0]
        chol_solve_k = chol_solver[:, :-1]
        chol_solve_y = chol_solver[:, -1:]

        pred_mean = torch.mm(chol_solve_k.t(), chol_solve_y) + self.model.mean(pred_x)
        pred_quad = (chol_solve_k ** 2).sum(0).view(-1, 1)
        pred_var = k_pred - pred_quad

        if verbose:
            numerically_stable = (pred_var >= 0).all()
            # NOTE(review): uses .all(), i.e. True only when *every* predictive
            # variance is <= 0 -- confirm .any() was not intended.
            zero_pred_var = (pred_var <= 0).all()

        if hyper is not None:
            # Restore the hyperparameters that were active before this call.
            # NOTE(review): this also rebuilds self.cholesky/self.mean_vec, so
            # the compute_grad branch below uses the *restored* factors while
            # chol_solve_k was built from `hyper` -- confirm this is intended.
            self.cholesky_update(param_original)

        if compute_grad:
            alpha = torch.cholesky_solve(self.mean_vec, self.cholesky, upper=False)
            grad_cross = self.model.kernel.grad(self.train_x, pred_x)
            grad_xp_m = torch.mm(grad_cross, k_pred_train.t()*alpha)
            gamma = torch.triangular_solve(chol_solve_k, self.cholesky.t(), upper=True)[0]
            grad_xp_v = -2 * torch.mm(gamma.t(), (grad_cross * k_pred_train).t()).t()
            return pred_mean, pred_var.clamp(min=1e-8), grad_xp_m, grad_xp_v
        else:
            if verbose:
                return pred_mean, pred_var.clamp(min=1e-8), numerically_stable, zero_pred_var
            else:
                return pred_mean, pred_var.clamp(min=1e-8)

    def negative_log_likelihood(self, hyper=None):
        """Negative log marginal likelihood of the training data, optionally
        evaluated at ``hyper`` (previous hyperparameters are then restored)."""
        if hyper is not None:
            param_original = self.model.param_to_vec()
            self.cholesky_update(hyper)

        # cholesky is lower triangular matrix
        mean_vec_sol = torch.triangular_solve(self.mean_vec, self.cholesky, upper=False)[0]
        # 0.5 * y^T K^-1 y + 0.5 * log|K| + (n/2) log(2*pi), where
        # 0.5 * log|K| = sum(log(diag(L))) for the lower-triangular factor L.
        nll = 0.5 * torch.sum(mean_vec_sol ** 2) + torch.sum(torch.log(torch.diag(self.cholesky))) + 0.5 * self.train_y.size(0) * np.log(2 * np.pi)
        if hyper is not None:
            self.cholesky_update(param_original)
        return nll
|
||
|
||
if __name__ == '__main__':
    # Manual experiment: probe how Cholesky behaves on a (near-)singular
    # matrix with a trace-scaled jitter term.
    dim = 50
    jitter_scale = 0
    for _ in range(10):
        mat = torch.randn(dim, dim - 2)
        mat = mat.matmul(mat.t()) * 0 + 1e-6
        mat = mat + torch.diag(torch.ones(dim)) * jitter_scale * torch.trace(mat).item()
        rhs = torch.randn(dim, 3)
        chol = torch.cholesky(mat, upper=False)
        assert (torch.diag(chol) > 0).all()
        abs_min, abs_max = torch.min(torch.abs(mat)).item(), torch.max(torch.abs(mat)).item()
        trace = torch.trace(mat).item()
        print(' %.4E~%.4E %.4E' % (abs_min, abs_max, trace))
        print(' jitter:%.4E' % (trace * jitter_scale))
        print('The smallest eigen value : %.4E\n' % torch.min(torch.diag(chol)).item())
        torch.triangular_solve(rhs, chol, upper=False)
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import math | ||
|
||
import torch | ||
from GPmodel.kernels.mixedkernel import MixedKernel | ||
import numpy as np | ||
|
||
class MixedDiffusionKernel(MixedKernel):
    """Kernel over mixed discrete/continuous inputs: per-dimension base
    kernels (a diffusion-style kernel on discrete dims, a Gaussian kernel on
    continuous dims) combined via elementary symmetric polynomials weighted
    by per-interaction-order variances (Newton's identity recursion)."""

    def __init__(self, log_order_variances, grouped_log_beta, fourier_freq_list, fourier_basis_list, lengthscales, num_discrete, num_continuous):
        super(MixedDiffusionKernel, self).__init__(log_order_variances, grouped_log_beta, fourier_freq_list, fourier_basis_list, lengthscales, num_discrete, num_continuous)

    def forward(self, x1, x2=None, diagonal=False):
        """
        :param x1, x2: each row is a vector with vertex numbers starting from 0 for each
            (assumes the first num_discrete columns are discrete and the
            remaining num_continuous columns are continuous -- TODO confirm)
        :param diagonal: if True, return only the diagonal k(x1, x1) as a column.
        :return: kernel matrix (or its diagonal), scaled by exp(log_amp) with a
            small 1e-6 stabilizer on the diagonal.
        """
        if diagonal:
            assert x2 is None
        stabilizer = 0
        if x2 is None:
            x2 = x1
            if diagonal:
                stabilizer = 1e-6 * x1.new_ones(x1.size(0), 1, dtype=torch.float32)
            else:
                stabilizer = torch.diag(1e-6 * x1.new_ones(x1.size(0), dtype=torch.float32))

        # One base kernel per discrete dimension (closed-form two-parameter
        # diffusion kernel on a complete graph with cat_i categories).
        base_kernels = []
        for i in range(len(self.fourier_freq_list)):
            beta = torch.exp(self.grouped_log_beta[i])
            fourier_freq = self.fourier_freq_list[i]
            fourier_basis = self.fourier_basis_list[i]
            cat_i = fourier_freq.size(0)
            discrete_kernel = ((1-torch.exp(-beta*cat_i))/(1+(cat_i-1)*torch.exp(-beta*cat_i)))**((x1[:, i].unsqueeze(1)[:, np.newaxis] != x2[:, i].unsqueeze(1)).sum(axis=-1))
            if diagonal:
                base_kernels.append(torch.diagonal(discrete_kernel).unsqueeze(1))
            else:
                base_kernels.append(discrete_kernel)

        # One Gaussian (squared-exponential) base kernel per continuous
        # dimension, with per-dimension lengthscales.
        lengthscales = torch.exp(self.lengthscales)**2
        temp_x_1 = x1[:, self.num_discrete:]/lengthscales
        temp_x_2 = x2[:, self.num_discrete:]/lengthscales

        for i in range(self.num_continuous):
            normalized_dists = torch.cdist(temp_x_1[:, i].unsqueeze(1), temp_x_2[:, i].unsqueeze(1))
            gaussian_kernel = torch.exp(-0.5 * (normalized_dists) ** 2)
            if not diagonal:
                base_kernels.append(gaussian_kernel)
            else:
                base_kernels.append(torch.diagonal(gaussian_kernel).unsqueeze(1))
        base_kernels = torch.stack(base_kernels)
        if diagonal:
            base_kernels = base_kernels.squeeze(-1)

        # e_n[d] holds the d-th elementary symmetric polynomial of the base
        # kernels (computed below via Newton's identities); e_n[0] = 1.
        num_dimensions = self.num_discrete + self.num_continuous
        if (not diagonal):
            e_n = torch.empty([num_dimensions + 1, \
                base_kernels.size(1), base_kernels.size(2)])
            e_n[0, :, :] = 1.0
            interaction_orders = torch.arange(1, num_dimensions+1).reshape([-1, 1, 1, 1]).float()
            kernel_dim = -3
            shape = [1 for _ in range(3)]
        else:
            e_n = torch.empty([num_dimensions + 1, \
                base_kernels.size(1)])
            e_n[0, :] = 1.0
            interaction_orders = torch.arange(1, num_dimensions+1).reshape([-1, 1, 1]).float()
            kernel_dim = -2
            shape = [1 for _ in range(2)]

        # s_k[k-1] = sum over dimensions of base_kernel ** k (power sums).
        s_k = base_kernels.unsqueeze(kernel_dim - 1).pow(interaction_orders).sum(dim=kernel_dim)
        m1 = torch.tensor([-1.0])
        shape[kernel_dim] = -1

        for deg in range(1, num_dimensions + 1):  # deg goes from 1 to R (it's 1-indexed!)
            ks = torch.arange(1, deg + 1, dtype=torch.float).reshape(*shape)  # use for pow
            kslong = torch.arange(1, deg + 1, dtype=torch.long)  # use for indexing
            # note that s_k is 0-indexed, so we must subtract 1 from kslong
            sum_ = (
                m1.pow(ks - 1) * e_n.index_select(kernel_dim, deg - kslong) * s_k.index_select(kernel_dim, kslong - 1)
            ).sum(dim=kernel_dim) / deg
            if kernel_dim == -3:
                e_n[deg, :, :] = sum_
            else:
                e_n[deg, :] = sum_

        # Weighted sum of orders 1..R of the elementary symmetric polynomials.
        order_variances = torch.exp(self.log_order_variances)
        if kernel_dim == -3:
            # NOTE(review): kernel_mat is never used, and it adds the stabilizer
            # *outside* the exp(log_amp) scaling, unlike the returned expression
            # below -- dead code, candidate for removal.
            kernel_mat = torch.exp(self.log_amp) * ((order_variances.unsqueeze(-1).unsqueeze(-1) * e_n.narrow(kernel_dim, 1, num_dimensions)).sum(dim=kernel_dim)) + stabilizer
            return torch.exp(self.log_amp) * ((order_variances.unsqueeze(-1).unsqueeze(-1) * e_n.narrow(kernel_dim, 1, num_dimensions)).sum(
                dim=kernel_dim) + stabilizer)
        else:
            return torch.exp(self.log_amp) * ((order_variances.unsqueeze(-1) * e_n.narrow(kernel_dim, 1, num_dimensions)).sum(dim=kernel_dim) + stabilizer)
|
||
|
||
# Commented-out draft of a gradient helper; kept for reference.
# def grad(self, x1, x2=None):
#     if x2 is None:
#         x2 = x1
#     diffs = (x1[:, self.num_discrete:] - x2[:, self.num_discrete:])/self.lengthscales
#     return diffs.t()

if __name__ == '__main__':
    # No standalone behavior for this module.
    pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch | ||
|
||
from GPmodel.modules.gp_modules import GPModule | ||
|
||
|
||
class MixedKernel(GPModule):
    """Base class for kernels over mixed discrete/continuous inputs.

    Holds the per-group diffusion parameters, graph Fourier data, continuous
    lengthscales and per-order variances; only the amplitude (log_amp) takes
    part in the flat hyperparameter vector handled here.
    """

    def __init__(self, log_order_variances, grouped_log_beta, fourier_freq_list,
                 fourier_basis_list, lengthscales, num_discrete, num_continuous):
        super(MixedKernel, self).__init__()
        # Uninitialized 1-element tensor; filled later (e.g. by init_param).
        self.log_amp = torch.FloatTensor(1)
        self.num_discrete = num_discrete
        self.num_continuous = num_continuous
        self.log_order_variances = log_order_variances  # one per interaction order
        self.grouped_log_beta = grouped_log_beta
        self.fourier_freq_list = fourier_freq_list
        self.fourier_basis_list = fourier_basis_list
        self.lengthscales = lengthscales
        # Sanity-check that every hyperparameter tensor matches the input layout.
        assert self.log_order_variances.size(0) == self.num_continuous + self.num_discrete, "order variances are not properly initialized"
        assert self.lengthscales.size(0) == self.num_continuous, "lengthscales is not properly initialized"
        assert self.grouped_log_beta.size(0) == self.num_discrete, "beta is not properly initialized"

    def n_params(self):
        """Number of entries this module contributes to the flat hyper vector."""
        return 1

    def param_to_vec(self):
        """Return a copy of the amplitude as a length-1 vector."""
        return self.log_amp.clone()

    def vec_to_param(self, vec):
        """Load the amplitude from (the first entry of) a flat hyper vector."""
        assert vec.numel() == 1  # self.num_discrete + self.num_continuous
        self.log_amp = vec[:1].clone()

    def forward(self, input1, input2=None):
        # Concrete kernels must implement the actual covariance computation.
        raise NotImplementedError
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
|
||
from GPmodel.likelihoods.likelihood import Likelihood | ||
|
||
|
||
class GaussianLikelihood(Likelihood):
    """Homoscedastic Gaussian observation noise with a learnable log-variance."""

    def __init__(self):
        super(GaussianLikelihood, self).__init__()
        # Uninitialized 1-element tensor; filled later (e.g. by init_param).
        self.log_noise_var = torch.FloatTensor(1)
        self.noise_scale = 0.1

    def n_params(self):
        """Contributes a single scalar to the flat hyperparameter vector."""
        return 1

    def param_to_vec(self):
        """Return a copy of the log noise variance."""
        return self.log_noise_var.clone()

    def vec_to_param(self, vec):
        """Load the log noise variance from a flat hyperparameter vector."""
        self.log_noise_var = vec.clone()

    def forward(self, input):
        # Same noise variance for every input row.
        noise_var = torch.exp(self.log_noise_var)
        return noise_var.repeat(input.size(0))

    def __repr__(self):
        return self.__class__.__name__
|
||
|
||
if __name__ == '__main__':
    # Quick manual check: print whatever parameters the module registers.
    likelihood = GaussianLikelihood()
    print(list(likelihood.parameters()))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from GPmodel.modules.gp_modules import GPModule | ||
|
||
|
||
class Likelihood(GPModule):
    """Abstract base class for observation likelihoods (noise models)."""

    def __init__(self):
        super(Likelihood, self).__init__()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
|
||
from GPmodel.means.mean import Mean | ||
|
||
|
||
class ConstantMean(Mean):
    """Mean function returning one learnable constant for every input row."""

    def __init__(self):
        super(ConstantMean, self).__init__()
        # Uninitialized 1-element tensor; filled later (e.g. by init_param).
        self.const_mean = torch.FloatTensor(1)

    def n_params(self):
        """Contributes a single scalar to the flat hyperparameter vector."""
        return 1

    def param_to_vec(self):
        """Return a copy of the constant mean."""
        return self.const_mean.clone()

    def vec_to_param(self, vec):
        """Load the constant mean from a flat hyperparameter vector."""
        self.const_mean = vec.clone()

    def forward(self, input):
        # Broadcast the constant over an (n, 1) column of ones.
        ones_col = input.new_ones(input.size(0), 1).float()
        return self.const_mean * ones_col

    def __repr__(self):
        return self.__class__.__name__
|
||
|
||
if __name__ == '__main__':
    # Quick manual check: print whatever parameters the module registers.
    # Renamed local from `likelihood` -- this object is a mean, not a likelihood.
    constant_mean = ConstantMean()
    print(list(constant_mean.parameters()))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from GPmodel.modules.gp_modules import GPModule | ||
|
||
|
||
class Mean(GPModule):
    """Abstract base class for GP mean functions."""

    def __init__(self):
        super(Mean, self).__init__()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
from GPmodel.modules.gp_modules import GPModule | ||
|
||
|
||
class GP(GPModule):
    """Base Gaussian-process model.

    Aggregates the hyperparameters of its child modules (kernel, mean,
    likelihood, ...) into a single flat vector and distributes a flat
    vector back to them.
    """

    def __init__(self, **kwargs):
        super(GP, self).__init__()

    def init_param(self, output_data):
        """Initialize hyperparameters from training outputs; subclasses implement."""
        raise NotImplementedError

    def n_params(self):
        """Total number of scalar parameters across all registered parameters.

        (Removed a leftover debug ``print("heree")`` from the original.)
        """
        return sum(param.numel() for param in self.parameters())

    def param_to_vec(self):
        """Concatenate each child module's hyperparameter vector into one tensor."""
        flat_param_list = [m.param_to_vec() for m in self.children()]
        return torch.cat(flat_param_list)

    def vec_to_param(self, vec):
        """Distribute a flat hyperparameter vector back to the child modules,
        giving each child the ``n_params()`` entries it declared."""
        ind = 0
        for m in self.children():
            jump = m.n_params()
            m.vec_to_param(vec[ind:ind + jump])
            ind += jump
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
|
||
from GPmodel.likelihoods.gaussian import GaussianLikelihood | ||
from GPmodel.means.constant import ConstantMean | ||
from GPmodel.models.gp import GP | ||
|
||
|
||
class GPRegression(GP):
    """GP regression model: a kernel plus a constant mean and Gaussian noise.

    Fix: the original signature used ``mean=ConstantMean()``, a default
    evaluated once at definition time, so every instance constructed without
    an explicit mean silently shared the same ConstantMean object (the
    mutable-default-argument pitfall). A ``None`` sentinel now yields a fresh
    mean per instance; passing an explicit mean behaves exactly as before.
    """

    def __init__(self, kernel, mean=None):
        """
        :param kernel: covariance module exposing ``log_amp``.
        :param mean: optional mean module; defaults to a fresh ConstantMean.
        """
        super(GPRegression, self).__init__()
        self.kernel = kernel
        self.mean = mean if mean is not None else ConstantMean()
        self.likelihood = GaussianLikelihood()

    def init_param(self, output_data):
        """Seed hyperparameters from the training outputs."""
        output_mean = torch.mean(output_data).item()
        output_log_var = (0.5 * torch.var(output_data)).log().item()
        self.kernel.log_amp.fill_(output_log_var)
        self.mean.const_mean.fill_(output_mean)
        # NOTE(review): the noise term is seeded with output_mean / 1000, not a
        # log-variance -- looks inconsistent with the field name; confirm intent.
        self.likelihood.log_noise_var.fill_(output_mean / 1000.0)
Empty file.
Oops, something went wrong.