Commit 55e6a57: figures+results
RaunakDey committed Jun 7, 2024 (1 parent: 4d8d5f8)
Showing 77 changed files with 15,885 additions and 0 deletions.
Binary file added assets/llmtime_top_fig.png
Binary file added data/__pycache__/metrics.cpython-311.pyc
Binary file not shown.
Binary file added data/__pycache__/metrics.cpython-39.pyc
Binary file not shown.
Binary file added data/__pycache__/serialize.cpython-311.pyc
Binary file not shown.
Binary file added data/__pycache__/serialize.cpython-39.pyc
Binary file not shown.
Binary file added data/__pycache__/small_context.cpython-311.pyc
Binary file not shown.
Binary file added data/__pycache__/small_context.cpython-39.pyc
Binary file not shown.
Binary file added data/__pycache__/synthetic.cpython-39.pyc
Binary file not shown.
588 changes: 588 additions & 0 deletions data/autoformer_dataset.py

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions data/last_val_mae.csv
@@ -0,0 +1,25 @@
dataset,mae
bitcoin,7.777284173521224e+17
wind_4_seconds,0.0
tourism_quarterly,15845.100306204946
traffic_weekly,1.1855844384623289
us_births,1152.6666666666667
sunspot,3.933333396911621
covid_deaths,353.70939849624057
weather,2.362190193902301
nn5_daily,8.262752532958984
solar_10_minutes,2.7269221451758905
fred_md,2825.672461360778
tourism_yearly,99456.0540551959
oikolab_weather,120.03774162319314
cif_2016,386526.36704240675
solar_4_seconds,0.0
australian_electricity_demand,659.600688770839
traffic_hourly,0.026246391982575817
solar_weekly,1729.4092503457175
tourism_monthly,5636.83029361023
nn5_weekly,16.708553516113007
pedestrian_counts,170.8838383838384
kaggle_web_traffic_weekly,2081.781183003247
hospital,24.06573229030856
saugeenday,21.496667098999023
22 changes: 22 additions & 0 deletions data/last_value_results.csv
@@ -0,0 +1,22 @@
dataset,rmse,mse,mae
weather,3.114059003829736,19.799946602304583,2.362190193902301
tourism_yearly,111078.12328993063,363513954611.7144,99456.0540551959
tourism_quarterly,19527.771486861628,6201928361.602383,15845.100306204946
tourism_monthly,7374.891631391369,619111237.8884984,5636.83029361023
cif_2016,513908.01521842513,4215886809397.5137,386526.3670424068
australian_electricity_demand,800.8367705189698,885757.1106004094,652.185211690267
pedestrian_counts,201.73714366476048,151208.0053661616,149.78377525252526
nn5_weekly,20.207279217372406,551.324271498317,16.708553516113007
kaggle_web_traffic_weekly,2751.056545255073,745610075.6193119,2081.781183003247
solar_10_minutes,2.916518075950651,15.696517675162108,2.72692214517589
solar_weekly,1918.441068452442,4732343.391918798,1729.4092503457173
fred_md,3128.761597032152,789147915.1258634,2825.672461360779
traffic_hourly,0.038692458579408555,0.0018708090134443467,0.029967592352576376
traffic_weekly,1.5367087177120735,4.087448585855902,1.1855844384623289
hospital,28.901001578961637,5851.362559756627,24.06573229030856
covid_deaths,403.4146883707207,4435194.603884712,353.7093984962405
saugeenday,39.79399009437644,1583.56164763133,21.496667098999023
us_births,1608.8383179590587,2588360.7333333334,1152.6666666666667
solar_4_seconds,0.0,0.0,0.0
wind_4_seconds,0.0,0.0,0.0
oikolab_weather,130.07108228180007,81269.08468305029,120.03774162319314
137 changes: 137 additions & 0 deletions data/metrics.py
@@ -0,0 +1,137 @@
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap
import openai

from .serialize import serialize_arr, SerializerSettings

def quantile_loss(target, pred, q):
    """Pinball (quantile) loss at level q, summed over the horizon and scaled by 2."""
    q_pred = jnp.quantile(pred, q, axis=0)
    return 2 * jnp.sum(
        jnp.abs((q_pred - target) * ((target <= q_pred) * 1.0 - q))
    )

def calculate_crps(target, pred, num_quantiles=20):
    """Approximates CRPS by averaging the quantile loss over evenly spaced
    quantile levels, normalized by the total absolute value of the target."""
    quantiles = jnp.linspace(0, 1.0, num_quantiles + 1)[1:]
    vec_quantile_loss = vmap(lambda q: quantile_loss(target, pred, q))
    crps = jnp.sum(vec_quantile_loss(quantiles))
    crps = crps / (jnp.sum(jnp.abs(target)) * len(quantiles))
    return crps
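
# Illustrative check (not from the original file): a constant over-prediction of 0.1
# gives a discretized CRPS close to the 0.1 absolute error.
#   target = jnp.ones(8)
#   pred = jnp.full((100, 8), 1.1)   # 100 forecast samples
#   calculate_crps(target, pred)     # ~0.095 with the default 20 quantile levels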


def nll(input_arr, target_arr, model, settings: SerializerSettings, transform, count_seps=True, prompt=None, temp=1):
    """Returns the NLL per dimension (log base e) of the continuous target array according
    to the LM, conditioned on the input array. Applies the relevant log determinant for the
    transform and converts the LLM's discrete NLL to a continuous one by assuming a uniform
    density within each bin.

    Inputs:
        input_arr: (n,) context array
        target_arr: (m,) ground truth array
    Returns:
        NLL/D
    """
input_str = serialize_arr(vmap(transform)(input_arr), settings)
target_str = serialize_arr(vmap(transform)(target_arr), settings)
if prompt:
input_str = prompt + '\n' + input_str
if not input_str.endswith(settings.time_sep):
print('Appending time separator to input... Are you sure you want this?')
prompt = input_str + settings.time_sep + target_str
else:
prompt = input_str + target_str
response = openai.Completion.create(model=model, prompt=prompt, logprobs=5, max_tokens=0, echo=True, temperature=temp)
#print(response['choices'][0])
logprobs = np.array(response['choices'][0].logprobs.token_logprobs, dtype=np.float32)
tokens = np.array(response['choices'][0].logprobs.tokens)
top5logprobs = response['choices'][0].logprobs.top_logprobs
seps = tokens==settings.time_sep
target_start = np.argmax(np.cumsum(seps)==len(input_arr)) + 1
logprobs = logprobs[target_start:]
tokens = tokens[target_start:]
top5logprobs = top5logprobs[target_start:]
seps = tokens==settings.time_sep
assert len(logprobs[seps]) == len(target_arr), f'There should be one separator per target. Got {len(logprobs[seps])} separators and {len(target_arr)} targets.'
    # adjust logprobs by removing probability mass on extraneous tokens and renormalizing
    # (see appendix of paper): logp' = logp - log(1 - p_extra)
allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)]
allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign, settings.bit_sep+settings.decimal_point]
allowed_tokens = {t for t in allowed_tokens if len(t) > 0}

p_extra = np.array([sum(np.exp(ll) for k,ll in top5logprobs[i].items() if not (k in allowed_tokens)) for i in range(len(top5logprobs))])
    if settings.bit_sep == '':
        p_extra = 0  # without an explicit bit separator, skip the renormalization
adjusted_logprobs = logprobs - np.log(1-p_extra)
digits_bits = -adjusted_logprobs[~seps].sum()
seps_bits = -adjusted_logprobs[seps].sum()
BPD = digits_bits/len(target_arr)
if count_seps:
BPD += seps_bits/len(target_arr)
#print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD)
# log p(x) = log p(token) - log bin_width = log p(token) + prec * log base
transformed_nll = BPD - settings.prec*np.log(settings.base)
avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean()
return transformed_nll-avg_logdet_dydx
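
# Illustrative call (a sketch; the argument values are assumptions, and `model` must be a
# legacy OpenAI Completions model that supports echo=True with logprobs):
#   settings = SerializerSettings()
#   nll(history, future, model='text-davinci-003', settings=settings,
#       transform=lambda x: x)  # identity transform => zero log-determinant term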

class Evaluator:
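    """Computes deterministic (MSE/MAE) and probabilistic (CRPS) metrics for sampled
    forecasts, either from dataframes (evaluate_df) or from raw arrays (evaluate)."""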

def __init__(self):
self.non_numerical_cols = [
"serialized_history",
"serialized_target",
"serialized_prediction",
"history_len",
"num_channels",
"example_num",
"sample_num",
]

def evaluate_df(self, gt_df, pred_df):
cols = [c for c in gt_df.columns if c not in self.non_numerical_cols]
num_channels = gt_df["num_channels"].iloc[0]
history_len = gt_df["history_len"].iloc[0]
gt_vals = gt_df[cols].to_numpy().reshape(len(gt_df), -1, num_channels) # (num_examples, history_len + target_len, num_channels)
gt_vals = gt_vals[:, history_len:, :] # (num_examples, target_len, num_channels)

cols = [c for c in pred_df.columns if c not in self.non_numerical_cols]
num_channels = pred_df["num_channels"].iloc[0]
pred_df = pred_df[cols + ["example_num"]]

all_pred_vals = []
for example_num in sorted(pred_df["example_num"].unique()):
pred_vals = pred_df[pred_df["example_num"] == example_num][cols].to_numpy() # (num_samples, target_len * num_channels)
pred_vals = pred_vals.reshape(pred_vals.shape[0], -1, num_channels) # (num_samples, target_len, num_channels)
all_pred_vals.append(pred_vals)

pred_vals = np.stack(all_pred_vals, axis=1) # (num_samples, num_examples, target_len, num_channels)
assert gt_vals.shape == pred_vals.shape[1:]

diff = (gt_vals[None] - pred_vals) # (num_samples, num_examples, target_len, num_channels)
mse = np.mean(diff**2)
mae = np.mean(np.abs(diff))
crps = calculate_crps(gt_vals, pred_vals)

return {
"mse": mse,
"mae": mae,
"crps": crps,
}

def evaluate(self, gt, pred):
'''
gt: (batch_size, steps)
pred: (batch_size, num_samples, steps)
'''
assert gt.shape == (pred.shape[0], pred.shape[2]), f"wrong shapes: gt.shape: {gt.shape}, pred.shape: {pred.shape}"
diff = (gt[:, None, :] - pred) # (batch_size, num_samples, steps)
mse = np.mean(diff**2)
mae = np.mean(np.abs(diff))
        std = np.std(gt, axis=1) + 1e-8  # (batch_size,)
        normalized_diff = diff / std[:, None, None]  # (batch_size, num_samples, steps)
        nmse = np.mean(normalized_diff**2)
        nmae = np.mean(np.abs(normalized_diff))

return {
"nmse": nmse,
"nmae": nmae,
"mse": mse,
"mae": mae,
}
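
# Minimal usage sketch (illustrative, not part of the original file):
#   rng = np.random.default_rng(0)
#   gt = rng.normal(size=(4, 10))          # (batch_size, steps)
#   pred = rng.normal(size=(4, 50, 10))    # (batch_size, num_samples, steps)
#   Evaluator().evaluate(gt, pred)         # -> {'nmse': ..., 'nmae': ..., 'mse': ..., 'mae': ...}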
140 changes: 140 additions & 0 deletions data/monash.py
@@ -0,0 +1,140 @@
import pandas as pd
import numpy as np
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
import datasets
from datasets import load_dataset
import os
import pickle

# Override the inferred prediction length (len(validation) - len(train)) for these datasets.
fix_pred_len = {
'australian_electricity_demand': 336,
'pedestrian_counts': 24,
'traffic_hourly': 168,
}

def get_benchmark_test_sets():
test_set_dir = "datasets/monash"
if not os.path.exists(test_set_dir):
os.makedirs(test_set_dir)

if len(os.listdir(test_set_dir)) > 0:
print(f'Loading test sets from {test_set_dir}')
test_sets = {}
        for file in os.listdir(test_set_dir):
            with open(os.path.join(test_set_dir, file), 'rb') as f:
                test_sets[file.split(".")[0]] = pickle.load(f)
return test_sets
else:
print(f'No files found in {test_set_dir}. You are not using our preprocessed datasets!')

benchmarks = {
"monash_tsf": datasets.get_dataset_config_names("monash_tsf"),
}

test_sets = defaultdict(list)
for path in benchmarks:
pred_lens = [24, 48, 96, 192] if path == "ett" else [None]
for name in benchmarks[path]:
for pred_len in pred_lens:
if pred_len is None:
ds = load_dataset(path, name)
else:
ds = load_dataset(path, name, multivariate=False, prediction_length=pred_len)

train_example = ds['train'][0]['target']
val_example = ds['validation'][0]['target']

if len(np.array(train_example).shape) > 1:
print(f"Skipping {name} because it is multivariate")
continue

pred_len = len(val_example) - len(train_example)
if name in fix_pred_len:
print(f"Fixing pred len for {name}: {pred_len} -> {fix_pred_len[name]}")
pred_len = fix_pred_len[name]

tag = name
print("Processing", tag)

pairs = []
for x in ds['test']:
if np.isnan(x['target']).any():
print(f"Skipping {name} because it has NaNs")
break
history = np.array(x['target'][:-pred_len])
target = np.array(x['target'][-pred_len:])
pairs.append((history, target))
                else:  # for/else: runs only when the loop completes without a NaN break
scaler = None
if path == "ett":
trainset = np.array(ds['train'][0]['target'])
scaler = StandardScaler().fit(trainset[:,None])
test_sets[tag] = (pairs, scaler)

    for name in test_sets:
        try:
            with open(os.path.join(test_set_dir, f"{name}.pkl"), 'wb') as f:
                pickle.dump(test_sets[name], f)
            print(f"Saved {name}")
        except Exception as e:
            print(f"Failed to save {name}: {e}")

return test_sets
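
# The returned dict maps each dataset name to (pairs, scaler): pairs is a list of
# (history, target) numpy array tuples, and scaler is a fitted StandardScaler for
# ETT-style datasets and None otherwise.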

def get_datasets():
benchmarks = get_benchmark_test_sets()
# shuffle the benchmarks
for k, v in benchmarks.items():
x, _scaler = v # scaler is not used
        train, test = zip(*x)  # here "train"/"test" are the (history, target) halves of each pair
np.random.seed(0)
ind = np.arange(len(train))
ind = np.random.permutation(ind)
train = [train[i] for i in ind]
test = [test[i] for i in ind]
benchmarks[k] = [list(train), list(test)]

    df = pd.read_csv('data/last_val_mae.csv')
    df = df.sort_values(by='mae')

    df_paper = pd.read_csv('data/paper_mae_raw.csv')  # pdf text -> csv
    dataset_names = df_paper['Dataset']  # renamed to avoid shadowing the `datasets` module
    name_map = {
        'Aus. Electricity Demand': 'australian_electricity_demand',
        'Kaggle Weekly': 'kaggle_web_traffic_weekly',
        'FRED-MD': 'fred_md',
        'Saugeen River Flow': 'saugeenday',
    }
    dataset_names = [name_map.get(d, d) for d in dataset_names]
    # lower-case and replace spaces with underscores
    dataset_names = [d.lower().replace(' ', '_') for d in dataset_names]
    df_paper['Dataset'] = dataset_names
df_paper = df_paper.reset_index(drop=True)
# for each dataset, add last value mae to df_paper
for dataset in df_paper['Dataset']:
if dataset in df['dataset'].values:
df_paper.loc[df_paper['Dataset'] == dataset, 'Last Value'] = df[df['dataset'] == dataset]['mae'].values[0]
# turn '-' into np.nan
df_paper = df_paper.replace('-', np.nan)
# convert all values to float
for method in df_paper.columns[1:]:
df_paper[method] = df_paper[method].astype(float)
df_paper.to_csv('data/paper_mae.csv', index=False)
# normalize each method by dividing by last value mae
for method in df_paper.columns[1:-1]: # skip dataset and last value
df_paper[method] = df_paper[method] / df_paper['Last Value']
    # sort df by minimum normalized mae across methods
    method_cols = df_paper.columns[1:-1]  # method columns only (exclude Dataset and Last Value)
    df_paper['normalized_min'] = df_paper[method_cols].min(axis=1)
    df_paper['normalized_median'] = df_paper[method_cols].median(axis=1)
df_paper = df_paper.sort_values(by='normalized_min')
df_paper = df_paper.reset_index(drop=True)
# save as csv
df_paper.to_csv('data/paper_mae_normalized.csv', index=False)
return benchmarks
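
# Calling get_datasets() (via main below) also writes data/paper_mae.csv and
# data/paper_mae_normalized.csv as side effects.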

def main():
get_datasets()

if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions data/paper_mae.csv
@@ -0,0 +1,31 @@
Dataset,SES,Theta,TBATS,ETS,(DHR-)ARIMA,PR,CatBoost,FFNN,DeepAR,N-BEATS,WaveNet,Transformer,Last Value
tourism_yearly,95579.23,90653.6,94121.08,94818.89,95033.24,82682.97,79567.22,79593.22,71471.29,70951.8,69905.47,74316.52,99456.0540551959
tourism_quarterly,15014.19,7656.49,9972.42,8925.52,10475.47,9092.58,10267.97,8981.04,9511.37,8640.56,9137.12,9521.67,15845.100306204946
tourism_monthly,5302.1,2069.96,2940.08,2004.51,2536.77,2187.28,2537.04,2022.21,1871.69,2003.02,2095.13,2146.98,5636.83029361023
cif_2016,581875.97,714818.58,855578.4,642421.42,469059.49,563205.57,603551.3,1495923.44,3200418.0,679034.8,5998224.62,4057973.04,386526.3670424068
australian_electricity_demand,659.6,665.04,370.74,1282.99,1045.92,247.18,241.77,258.76,302.41,213.83,227.5,231.45,659.600688770839
dominick,5.7,5.86,7.08,5.81,7.1,8.19,8.09,5.85,5.23,8.28,5.1,5.18,
bitcoin,5.33e+18,5.33e+18,9.9e+17,1.1e+18,3.62e+18,6.66e+17,1.93e+18,1.45e+18,1.95e+18,1.06e+18,2.46e+18,2.61e+18,7.777284173521224e+17
pedestrian_counts,170.87,170.94,222.38,216.5,635.16,44.18,43.41,46.41,44.78,66.84,46.46,47.29,170.8838383838384
vehicle_trips,29.98,30.76,21.21,30.95,30.07,27.24,22.61,22.93,22.0,28.16,24.15,28.01,
kdd_cup,42.04,42.06,39.2,44.88,52.2,36.85,34.82,37.16,48.98,49.1,37.08,44.46,
weather,2.24,2.51,2.3,2.35,2.45,8.17,2.51,2.09,2.02,2.34,2.29,2.03,2.362190193902301
nn5_daily,6.63,3.8,3.7,3.72,4.41,5.47,4.22,4.06,3.94,4.92,3.97,4.16,8.262752532958984
nn5_weekly,15.66,15.3,14.98,15.7,15.38,14.94,15.29,15.02,14.69,14.19,19.34,20.34,16.708553516113007
kaggle_daily,363.43,358.73,415.4,403.23,340.36,,,,,,,,
kaggle_web_traffic_weekly,2337.11,2373.98,2241.84,2668.28,3115.03,4051.75,10715.36,2025.23,2272.58,2051.3,2025.5,3100.32,2081.781183003247
solar_10_minutes,3.28,3.29,8.77,3.28,2.37,3.28,5.69,3.28,3.28,3.52,,3.28,2.7269221451758905
solar_weekly,1202.39,1210.83,908.65,1131.01,839.88,1044.98,1513.49,1050.84,721.59,1172.64,1996.89,576.35,1729.4092503457175
electricity_hourly,845.97,846.03,574.3,1344.61,868.2,537.38,407.14,354.39,329.75,350.37,286.56,398.8,
electricity_weekly,74149.18,74111.14,24347.24,67737.82,28457.18,44882.52,34518.43,27451.83,50312.05,32991.72,61429.32,76382.47,
carparts,0.55,0.53,0.58,0.56,0.56,0.41,0.53,0.39,0.39,0.98,0.4,0.39,
fred_md,2798.22,3492.84,1989.97,2041.42,2957.11,8921.94,2475.68,2339.57,4264.36,2557.8,2508.4,4666.04,2825.672461360778
traffic_hourly,0.03,0.03,0.04,0.03,0.04,0.02,0.02,0.01,0.01,0.02,0.02,0.01,0.0262463919825758
traffic_weekly,1.12,1.13,1.17,1.14,1.22,1.13,1.17,1.15,1.18,1.11,1.2,1.42,1.1855844384623289
rideshare,6.29,7.62,6.45,6.29,3.37,6.3,6.07,6.59,6.28,5.55,2.75,6.29,
hospital,21.76,18.54,17.43,17.97,19.6,19.24,19.17,22.86,18.25,20.18,19.35,36.19,24.06573229030856
covid_deaths,353.71,321.32,96.29,85.59,85.77,347.98,475.15,144.14,201.98,158.81,1049.48,408.66,353.70939849624057
temperature_rain,8.18,8.22,7.14,8.21,7.19,6.13,6.76,5.56,5.37,7.28,5.81,5.24,
sunspot,4.93,4.93,2.57,4.93,2.57,3.83,2.27,7.97,0.77,14.47,0.17,0.13,3.933333396911621
saugeenday,21.5,21.49,22.26,30.69,22.38,25.24,21.28,22.98,23.51,27.92,22.17,28.06,21.496667098999023
us_births,1192.2,586.93,399.0,419.73,526.33,574.93,441.7,557.87,424.93,422.0,504.4,452.87,1152.6666666666667