-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
77 changed files
with
15,885 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
dataset,mae | ||
bitcoin,7.777284173521224e+17 | ||
wind_4_seconds,0.0 | ||
tourism_quarterly,15845.100306204946 | ||
traffic_weekly,1.1855844384623289 | ||
us_births,1152.6666666666667 | ||
sunspot,3.933333396911621 | ||
covid_deaths,353.70939849624057 | ||
weather,2.362190193902301 | ||
nn5_daily,8.262752532958984 | ||
solar_10_minutes,2.7269221451758905 | ||
fred_md,2825.672461360778 | ||
tourism_yearly,99456.0540551959 | ||
oikolab_weather,120.03774162319314 | ||
cif_2016,386526.36704240675 | ||
solar_4_seconds,0.0 | ||
australian_electricity_demand,659.600688770839 | ||
traffic_hourly,0.026246391982575817 | ||
solar_weekly,1729.4092503457175 | ||
tourism_monthly,5636.83029361023 | ||
nn5_weekly,16.708553516113007 | ||
pedestrian_counts,170.8838383838384 | ||
kaggle_web_traffic_weekly,2081.781183003247 | ||
hospital,24.06573229030856 | ||
saugeenday,21.496667098999023 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
dataset,rmse,mse,mae | ||
weather,3.114059003829736,19.799946602304583,2.362190193902301 | ||
tourism_yearly,111078.12328993063,363513954611.7144,99456.0540551959 | ||
tourism_quarterly,19527.771486861628,6201928361.602383,15845.100306204946 | ||
tourism_monthly,7374.891631391369,619111237.8884984,5636.83029361023 | ||
cif_2016,513908.01521842513,4215886809397.5137,386526.3670424068 | ||
australian_electricity_demand,800.8367705189698,885757.1106004094,652.185211690267 | ||
pedestrian_counts,201.73714366476048,151208.0053661616,149.78377525252526 | ||
nn5_weekly,20.207279217372406,551.324271498317,16.708553516113007 | ||
kaggle_web_traffic_weekly,2751.056545255073,745610075.6193119,2081.781183003247 | ||
solar_10_minutes,2.916518075950651,15.696517675162108,2.72692214517589 | ||
solar_weekly,1918.441068452442,4732343.391918798,1729.4092503457173 | ||
fred_md,3128.761597032152,789147915.1258634,2825.672461360779 | ||
traffic_hourly,0.038692458579408555,0.0018708090134443467,0.029967592352576376 | ||
traffic_weekly,1.5367087177120735,4.087448585855902,1.1855844384623289 | ||
hospital,28.901001578961637,5851.362559756627,24.06573229030856 | ||
covid_deaths,403.4146883707207,4435194.603884712,353.7093984962405 | ||
saugeenday,39.79399009437644,1583.56164763133,21.496667098999023 | ||
us_births,1608.8383179590587,2588360.7333333334,1152.6666666666667 | ||
solar_4_seconds,0.0,0.0,0.0 | ||
wind_4_seconds,0.0,0.0,0.0 | ||
oikolab_weather,130.07108228180007,81269.08468305029,120.03774162319314 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import numpy as np | ||
from jax import vmap | ||
import jax.numpy as jnp | ||
|
||
def quantile_loss(target, pred, q): | ||
q_pred = jnp.quantile(pred, q, axis=0) | ||
return 2 * jnp.sum( | ||
jnp.abs((q_pred - target) * ((target <= q_pred) * 1.0 - q)) | ||
) | ||
|
||
def calculate_crps(target, pred, num_quantiles=20): | ||
quantiles = jnp.linspace(0, 1.0, num_quantiles+1)[1:] | ||
vec_quantile_loss = vmap(lambda q: quantile_loss(target, pred, q)) | ||
crps = jnp.sum(vec_quantile_loss(quantiles)) | ||
crps = crps / (jnp.sum(np.abs(target)) * len(quantiles)) | ||
return crps | ||
|
||
import jax | ||
from jax import grad,vmap | ||
from .serialize import serialize_arr, SerializerSettings | ||
import openai | ||
|
||
def nll(input_arr, target_arr, model, settings:SerializerSettings, transform, count_seps=True, prompt=None, temp=1): | ||
""" Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM | ||
conditioned on the input array. Applies relevant log determinant for transforms and | ||
converts from discrete NLL of the LLM to continuous by assuming uniform within the bins. | ||
inputs: | ||
input_arr: (n,) context array | ||
target_arr: (n,) ground truth array | ||
Returns: NLL/D | ||
""" | ||
input_str = serialize_arr(vmap(transform)(input_arr), settings) | ||
target_str = serialize_arr(vmap(transform)(target_arr), settings) | ||
if prompt: | ||
input_str = prompt + '\n' + input_str | ||
if not input_str.endswith(settings.time_sep): | ||
print('Appending time separator to input... Are you sure you want this?') | ||
prompt = input_str + settings.time_sep + target_str | ||
else: | ||
prompt = input_str + target_str | ||
response = openai.Completion.create(model=model, prompt=prompt, logprobs=5, max_tokens=0, echo=True, temperature=temp) | ||
#print(response['choices'][0]) | ||
logprobs = np.array(response['choices'][0].logprobs.token_logprobs, dtype=np.float32) | ||
tokens = np.array(response['choices'][0].logprobs.tokens) | ||
top5logprobs = response['choices'][0].logprobs.top_logprobs | ||
seps = tokens==settings.time_sep | ||
target_start = np.argmax(np.cumsum(seps)==len(input_arr)) + 1 | ||
logprobs = logprobs[target_start:] | ||
tokens = tokens[target_start:] | ||
top5logprobs = top5logprobs[target_start:] | ||
seps = tokens==settings.time_sep | ||
assert len(logprobs[seps]) == len(target_arr), f'There should be one separator per target. Got {len(logprobs[seps])} separators and {len(target_arr)} targets.' | ||
#adjust logprobs by removing extraneous and renormalizing (see appendix of paper) | ||
# logp' = logp - log(1-pk*pextra) | ||
allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)] | ||
allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign, settings.bit_sep+settings.decimal_point] | ||
allowed_tokens = {t for t in allowed_tokens if len(t) > 0} | ||
|
||
p_extra = np.array([sum(np.exp(ll) for k,ll in top5logprobs[i].items() if not (k in allowed_tokens)) for i in range(len(top5logprobs))]) | ||
if settings.bit_sep == '': | ||
p_extra = 0 | ||
adjusted_logprobs = logprobs - np.log(1-p_extra) | ||
digits_bits = -adjusted_logprobs[~seps].sum() | ||
seps_bits = -adjusted_logprobs[seps].sum() | ||
BPD = digits_bits/len(target_arr) | ||
if count_seps: | ||
BPD += seps_bits/len(target_arr) | ||
#print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD) | ||
# log p(x) = log p(token) - log bin_width = log p(token) + prec * log base | ||
transformed_nll = BPD - settings.prec*np.log(settings.base) | ||
avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean() | ||
return transformed_nll-avg_logdet_dydx | ||
|
||
class Evaluator: | ||
|
||
def __init__(self): | ||
self.non_numerical_cols = [ | ||
"serialized_history", | ||
"serialized_target", | ||
"serialized_prediction", | ||
"history_len", | ||
"num_channels", | ||
"example_num", | ||
"sample_num", | ||
] | ||
|
||
def evaluate_df(self, gt_df, pred_df): | ||
cols = [c for c in gt_df.columns if c not in self.non_numerical_cols] | ||
num_channels = gt_df["num_channels"].iloc[0] | ||
history_len = gt_df["history_len"].iloc[0] | ||
gt_vals = gt_df[cols].to_numpy().reshape(len(gt_df), -1, num_channels) # (num_examples, history_len + target_len, num_channels) | ||
gt_vals = gt_vals[:, history_len:, :] # (num_examples, target_len, num_channels) | ||
|
||
cols = [c for c in pred_df.columns if c not in self.non_numerical_cols] | ||
num_channels = pred_df["num_channels"].iloc[0] | ||
pred_df = pred_df[cols + ["example_num"]] | ||
|
||
all_pred_vals = [] | ||
for example_num in sorted(pred_df["example_num"].unique()): | ||
pred_vals = pred_df[pred_df["example_num"] == example_num][cols].to_numpy() # (num_samples, target_len * num_channels) | ||
pred_vals = pred_vals.reshape(pred_vals.shape[0], -1, num_channels) # (num_samples, target_len, num_channels) | ||
all_pred_vals.append(pred_vals) | ||
|
||
pred_vals = np.stack(all_pred_vals, axis=1) # (num_samples, num_examples, target_len, num_channels) | ||
assert gt_vals.shape == pred_vals.shape[1:] | ||
|
||
diff = (gt_vals[None] - pred_vals) # (num_samples, num_examples, target_len, num_channels) | ||
mse = np.mean(diff**2) | ||
mae = np.mean(np.abs(diff)) | ||
crps = calculate_crps(gt_vals, pred_vals) | ||
|
||
return { | ||
"mse": mse, | ||
"mae": mae, | ||
"crps": crps, | ||
} | ||
|
||
def evaluate(self, gt, pred): | ||
''' | ||
gt: (batch_size, steps) | ||
pred: (batch_size, num_samples, steps) | ||
''' | ||
assert gt.shape == (pred.shape[0], pred.shape[2]), f"wrong shapes: gt.shape: {gt.shape}, pred.shape: {pred.shape}" | ||
diff = (gt[:, None, :] - pred) # (batch_size, num_samples, steps) | ||
mse = np.mean(diff**2) | ||
mae = np.mean(np.abs(diff)) | ||
std = np.std(gt, axis=1) + 1e-8 # (batch_size,) | ||
normlized_diff = diff / std[:, None, None] # (batch_size, num_samples, steps) | ||
nmse = np.mean(normlized_diff**2) | ||
nmae = np.mean(np.abs(normlized_diff)) | ||
|
||
return { | ||
"nmse": nmse, | ||
"nmae": nmae, | ||
"mse": mse, | ||
"mae": mae, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import pandas as pd | ||
import numpy as np | ||
from collections import defaultdict | ||
from sklearn.preprocessing import StandardScaler | ||
import datasets | ||
from datasets import load_dataset | ||
import os | ||
import pickle | ||
|
||
fix_pred_len = { | ||
'australian_electricity_demand': 336, | ||
'pedestrian_counts': 24, | ||
'traffic_hourly': 168, | ||
} | ||
|
||
def get_benchmark_test_sets(): | ||
test_set_dir = "datasets/monash" | ||
if not os.path.exists(test_set_dir): | ||
os.makedirs(test_set_dir) | ||
|
||
if len(os.listdir(test_set_dir)) > 0: | ||
print(f'Loading test sets from {test_set_dir}') | ||
test_sets = {} | ||
for file in os.listdir(test_set_dir): | ||
test_sets[file.split(".")[0]] = pickle.load(open(os.path.join(test_set_dir, file), 'rb')) | ||
return test_sets | ||
else: | ||
print(f'No files found in {test_set_dir}. You are not using our preprocessed datasets!') | ||
|
||
benchmarks = { | ||
"monash_tsf": datasets.get_dataset_config_names("monash_tsf"), | ||
} | ||
|
||
test_sets = defaultdict(list) | ||
for path in benchmarks: | ||
pred_lens = [24, 48, 96, 192] if path == "ett" else [None] | ||
for name in benchmarks[path]: | ||
for pred_len in pred_lens: | ||
if pred_len is None: | ||
ds = load_dataset(path, name) | ||
else: | ||
ds = load_dataset(path, name, multivariate=False, prediction_length=pred_len) | ||
|
||
train_example = ds['train'][0]['target'] | ||
val_example = ds['validation'][0]['target'] | ||
|
||
if len(np.array(train_example).shape) > 1: | ||
print(f"Skipping {name} because it is multivariate") | ||
continue | ||
|
||
pred_len = len(val_example) - len(train_example) | ||
if name in fix_pred_len: | ||
print(f"Fixing pred len for {name}: {pred_len} -> {fix_pred_len[name]}") | ||
pred_len = fix_pred_len[name] | ||
|
||
tag = name | ||
print("Processing", tag) | ||
|
||
pairs = [] | ||
for x in ds['test']: | ||
if np.isnan(x['target']).any(): | ||
print(f"Skipping {name} because it has NaNs") | ||
break | ||
history = np.array(x['target'][:-pred_len]) | ||
target = np.array(x['target'][-pred_len:]) | ||
pairs.append((history, target)) | ||
else: | ||
scaler = None | ||
if path == "ett": | ||
trainset = np.array(ds['train'][0]['target']) | ||
scaler = StandardScaler().fit(trainset[:,None]) | ||
test_sets[tag] = (pairs, scaler) | ||
|
||
for name in test_sets: | ||
try: | ||
with open(os.path.join(test_set_dir,f"{name}.pkl"), 'wb') as f: | ||
pickle.dump(test_sets[name], f) | ||
print(f"Saved {name}") | ||
except: | ||
print(f"Failed to save {name}") | ||
|
||
return test_sets | ||
|
||
def get_datasets(): | ||
benchmarks = get_benchmark_test_sets() | ||
# shuffle the benchmarks | ||
for k, v in benchmarks.items(): | ||
x, _scaler = v # scaler is not used | ||
train, test = zip(*x) | ||
np.random.seed(0) | ||
ind = np.arange(len(train)) | ||
ind = np.random.permutation(ind) | ||
train = [train[i] for i in ind] | ||
test = [test[i] for i in ind] | ||
benchmarks[k] = [list(train), list(test)] | ||
|
||
df = pd.read_csv('data/last_val_mae.csv') | ||
df.sort_values(by='mae') | ||
|
||
df_paper = pd.read_csv('data/paper_mae_raw.csv') # pdf text -> csv | ||
datasets = df_paper['Dataset'] | ||
name_map = { | ||
'Aus. Electricity Demand' :'australian_electricity_demand', | ||
'Kaggle Weekly': 'kaggle_web_traffic_weekly', | ||
'FRED-MD': 'fred_md', | ||
'Saugeen River Flow': 'saugeenday', | ||
|
||
} | ||
datasets = [name_map.get(d, d) for d in datasets] | ||
# lower case and repalce spaces with underscores | ||
datasets = [d.lower().replace(' ', '_') for d in datasets] | ||
df_paper['Dataset'] = datasets | ||
df_paper = df_paper.reset_index(drop=True) | ||
# for each dataset, add last value mae to df_paper | ||
for dataset in df_paper['Dataset']: | ||
if dataset in df['dataset'].values: | ||
df_paper.loc[df_paper['Dataset'] == dataset, 'Last Value'] = df[df['dataset'] == dataset]['mae'].values[0] | ||
# turn '-' into np.nan | ||
df_paper = df_paper.replace('-', np.nan) | ||
# convert all values to float | ||
for method in df_paper.columns[1:]: | ||
df_paper[method] = df_paper[method].astype(float) | ||
df_paper.to_csv('data/paper_mae.csv', index=False) | ||
# normalize each method by dividing by last value mae | ||
for method in df_paper.columns[1:-1]: # skip dataset and last value | ||
df_paper[method] = df_paper[method] / df_paper['Last Value'] | ||
# sort df by minimum mae across methods | ||
df_paper['normalized_min'] = df_paper[df_paper.columns[1:-1]].min(axis=1) | ||
df_paper['normalized_median'] = df_paper[df_paper.columns[1:-1]].median(axis=1) | ||
df_paper = df_paper.sort_values(by='normalized_min') | ||
df_paper = df_paper.reset_index(drop=True) | ||
# save as csv | ||
df_paper.to_csv('data/paper_mae_normalized.csv', index=False) | ||
return benchmarks | ||
|
||
def main(): | ||
get_datasets() | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
Dataset,SES,Theta,TBATS,ETS,(DHR-)ARIMA,PR,CatBoost,FFNN,DeepAR,N-BEATS,WaveNet,Transformer,Last Value | ||
tourism_yearly,95579.23,90653.6,94121.08,94818.89,95033.24,82682.97,79567.22,79593.22,71471.29,70951.8,69905.47,74316.52,99456.0540551959 | ||
tourism_quarterly,15014.19,7656.49,9972.42,8925.52,10475.47,9092.58,10267.97,8981.04,9511.37,8640.56,9137.12,9521.67,15845.100306204946 | ||
tourism_monthly,5302.1,2069.96,2940.08,2004.51,2536.77,2187.28,2537.04,2022.21,1871.69,2003.02,2095.13,2146.98,5636.83029361023 | ||
cif_2016,581875.97,714818.58,855578.4,642421.42,469059.49,563205.57,603551.3,1495923.44,3200418.0,679034.8,5998224.62,4057973.04,386526.3670424068 | ||
australian_electricity_demand,659.6,665.04,370.74,1282.99,1045.92,247.18,241.77,258.76,302.41,213.83,227.5,231.45,659.600688770839 | ||
dominick,5.7,5.86,7.08,5.81,7.1,8.19,8.09,5.85,5.23,8.28,5.1,5.18, | ||
bitcoin,5.33e+18,5.33e+18,9.9e+17,1.1e+18,3.62e+18,6.66e+17,1.93e+18,1.45e+18,1.95e+18,1.06e+18,2.46e+18,2.61e+18,7.777284173521224e+17 | ||
pedestrian_counts,170.87,170.94,222.38,216.5,635.16,44.18,43.41,46.41,44.78,66.84,46.46,47.29,170.8838383838384 | ||
vehicle_trips,29.98,30.76,21.21,30.95,30.07,27.24,22.61,22.93,22.0,28.16,24.15,28.01, | ||
kdd_cup,42.04,42.06,39.2,44.88,52.2,36.85,34.82,37.16,48.98,49.1,37.08,44.46, | ||
weather,2.24,2.51,2.3,2.35,2.45,8.17,2.51,2.09,2.02,2.34,2.29,2.03,2.362190193902301 | ||
nn5_daily,6.63,3.8,3.7,3.72,4.41,5.47,4.22,4.06,3.94,4.92,3.97,4.16,8.262752532958984 | ||
nn5_weekly,15.66,15.3,14.98,15.7,15.38,14.94,15.29,15.02,14.69,14.19,19.34,20.34,16.708553516113007 | ||
kaggle_daily,363.43,358.73,415.4,403.23,340.36,,,,,,,, | ||
kaggle_web_traffic_weekly,2337.11,2373.98,2241.84,2668.28,3115.03,4051.75,10715.36,2025.23,2272.58,2051.3,2025.5,3100.32,2081.781183003247 | ||
solar_10_minutes,3.28,3.29,8.77,3.28,2.37,3.28,5.69,3.28,3.28,3.52,,3.28,2.7269221451758905 | ||
solar_weekly,1202.39,1210.83,908.65,1131.01,839.88,1044.98,1513.49,1050.84,721.59,1172.64,1996.89,576.35,1729.4092503457175 | ||
electricity_hourly,845.97,846.03,574.3,1344.61,868.2,537.38,407.14,354.39,329.75,350.37,286.56,398.8, | ||
electricity_weekly,74149.18,74111.14,24347.24,67737.82,28457.18,44882.52,34518.43,27451.83,50312.05,32991.72,61429.32,76382.47, | ||
carparts,0.55,0.53,0.58,0.56,0.56,0.41,0.53,0.39,0.39,0.98,0.4,0.39, | ||
fred_md,2798.22,3492.84,1989.97,2041.42,2957.11,8921.94,2475.68,2339.57,4264.36,2557.8,2508.4,4666.04,2825.672461360778 | ||
traffic_hourly,0.03,0.03,0.04,0.03,0.04,0.02,0.02,0.01,0.01,0.02,0.02,0.01,0.0262463919825758 | ||
traffic_weekly,1.12,1.13,1.17,1.14,1.22,1.13,1.17,1.15,1.18,1.11,1.2,1.42,1.1855844384623289 | ||
rideshare,6.29,7.62,6.45,6.29,3.37,6.3,6.07,6.59,6.28,5.55,2.75,6.29, | ||
hospital,21.76,18.54,17.43,17.97,19.6,19.24,19.17,22.86,18.25,20.18,19.35,36.19,24.06573229030856 | ||
covid_deaths,353.71,321.32,96.29,85.59,85.77,347.98,475.15,144.14,201.98,158.81,1049.48,408.66,353.70939849624057 | ||
temperature_rain,8.18,8.22,7.14,8.21,7.19,6.13,6.76,5.56,5.37,7.28,5.81,5.24, | ||
sunspot,4.93,4.93,2.57,4.93,2.57,3.83,2.27,7.97,0.77,14.47,0.17,0.13,3.933333396911621 | ||
saugeenday,21.5,21.49,22.26,30.69,22.38,25.24,21.28,22.98,23.51,27.92,22.17,28.06,21.496667098999023 | ||
us_births,1192.2,586.93,399.0,419.73,526.33,574.93,441.7,557.87,424.93,422.0,504.4,452.87,1152.6666666666667 |
Oops, something went wrong.