import os
import pickle
import time
from datetime import datetime

import torch

from config import experiment_directory


def model_data_filenames(exp_dir, objective_name):
    """Create a time-stamped experiment folder and return the paths used to store results."""
    folder_name = objective_name + '_' + datetime.now().strftime('%Y%m%d-%H:%M:%S:%f')
    os.makedirs(os.path.join(exp_dir, folder_name))
    logfile_dir = os.path.join(exp_dir, folder_name, 'log')
    os.makedirs(logfile_dir)
    model_filename = os.path.join(exp_dir, folder_name, 'model.pt')
    cfg_data_filename = os.path.join(exp_dir, folder_name, 'data_config.pkl')
    return model_filename, cfg_data_filename, logfile_dir


def load_model_data(path, exp_dir=experiment_directory()):
    """Load a saved model and its pickled data configuration from an experiment folder.

    `path` may be an absolute path or a folder name relative to `exp_dir`.
    """
    if not os.path.exists(path):
        path = os.path.join(exp_dir, path)
    logfile_dir = os.path.join(path, 'log')
    model_filename = os.path.join(path, 'model.pt')
    cfg_data_filename = os.path.join(path, 'data_config.pkl')
    model = torch.load(model_filename)
    # Pickle files must be opened in binary mode; cfg_data is returned as a dict for the
    # caller to unpack (its 'logfile_dir' entry is superseded by the path derived above).
    with open(cfg_data_filename, 'rb') as cfg_data_file:
        cfg_data = pickle.load(cfg_data_file)
    return model, cfg_data, logfile_dir


def save_model_data(model, model_filename, cfg_data, cfg_data_filename):
    """Save the model with torch.save and the data configuration with pickle."""
    torch.save(model, model_filename)
    with open(cfg_data_filename, 'wb') as cfg_data_file:
        pickle.dump(cfg_data, cfg_data_file)


def displaying_and_logging(logfile_dir, eval_inputs, eval_outputs, pred_mean_list, pred_std_list, pred_var_list,
                           time_list, elapse_list, hyper_samples, log_beta_samples, lengthscale_samples,
                           log_order_var_samples, store_data=True):
    """Print and log the evaluation history, and optionally store the raw data as a .pkl file."""
    logfile_name = os.path.join(logfile_dir, str(eval_inputs.size(0)).zfill(4) + '.out')
    with open(logfile_name, 'w') as logfile:
        for i in range(eval_inputs.size(0)):
            # Running minimum over the evaluations seen so far (value and 0-based index).
            min_val, min_ind = torch.min(eval_outputs[:i + 1], 0)
            time_str = time.strftime('%H:%M:%S', time.gmtime(time_list[i])) \
                + '(' + time.strftime('%H:%M:%S', time.gmtime(elapse_list[i])) + ') '
            data_str = ('%3d-th : %+12.4f, mean : %+.4E, std : %.4E, var : %.4E, min : %+8.4f(%3d)'
                        % (i + 1, eval_outputs.squeeze()[i], pred_mean_list[i], pred_std_list[i],
                           pred_var_list[i], min_val.item(), min_ind.item() + 1))
            min_str = ' <==== IMPROVED' if i == min_ind.item() else ''
            print(time_str + data_str + min_str)
            logfile.write(time_str + data_str + min_str + '\n')
    if store_data:
        pklfilename = os.path.join(logfile_dir, str(eval_inputs.size(0)).zfill(4) + '.pkl')
        torch.save({'inputs': eval_inputs, 'outputs': eval_outputs, 'hyper_samples': hyper_samples,
                    'log_beta_samples': log_beta_samples, 'lengthscale_samples': lengthscale_samples,
                    'lov': log_order_var_samples}, pklfilename)
        # torch.save({'inputs': eval_inputs, 'outputs': eval_outputs}, pklfilename)
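

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# It assumes `experiment_directory()` from config points at an existing,
# writable directory; the objective name 'branin' and the nn.Linear "model"
# are hypothetical placeholders for a real surrogate model and its config.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    exp_dir = experiment_directory()
    model_filename, cfg_data_filename, logfile_dir = model_data_filenames(exp_dir, objective_name='branin')

    dummy_model = nn.Linear(2, 1)  # stand-in for the surrogate model saved during optimization
    cfg_data = {'n_eval': 10, 'logfile_dir': logfile_dir}  # example configuration entries
    save_model_data(dummy_model, model_filename, cfg_data, cfg_data_filename)

    # Reload from the experiment folder; the path may be absolute or relative to exp_dir.
    model, cfg_data_loaded, logfile_dir_loaded = load_model_data(os.path.dirname(model_filename), exp_dir=exp_dir)
    print(cfg_data_loaded)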