
Early stop #2

Closed

Rdfing opened this issue Apr 13, 2022 · 2 comments


Rdfing commented Apr 13, 2022

Wenjie,

I tried PyPOTS with the Beijing air-quality dataset. For the dataset preparation, I followed gene_UCI_BeijingAirQuality_dataset. The following is the PyPOTS setup.

saits_base = SAITS(seq_len=seq_len, n_features=132, 
                   n_layers=2,  # num of group-inner layers
                   d_model=256, # model hidden dim
                   d_inner=128, # hidden size of feed forward layer
                   n_head=4, # head num of self-attention
                   d_k=64, d_v=64, # key dim, value dim
                   dropout=0, 
                   epochs=200,
                   patience=30,
                   batch_size=32,
                   weight_decay=1e-5,
                   ORT_weight=1,
                   MIT_weight=1,
                  )

saits_base.fit(train_set_X)

PyPOTS stops earlier than the specified number of epochs (around epoch 80), without triggering either print('Exceeded the training patience. Terminating the training procedure...') or print('Finished all training epochs.').

epoch 0: training loss 0.9637 
epoch 1: training loss 0.6161 
epoch 2: training loss 0.5177 
epoch 3: training loss 0.4783 
epoch 4: training loss 0.4489 
...
epoch 73: training loss 0.2462 
epoch 74: training loss 0.2460 
epoch 75: training loss 0.2480 
epoch 76: training loss 0.2452 
epoch 77: training loss 0.2452 
epoch 78: training loss 0.2458 
epoch 79: training loss 0.2449 
epoch 80: training loss 0.2423 
epoch 81: training loss 0.2425 
epoch 82: training loss 0.2443 
epoch 83: training loss 0.2403 
epoch 84: training loss 0.2406

Then I evaluated the model's performance on the test set (without knowing why the model stopped early):

test_set_mae = cal_mae(test_set_imputation, test_set_X_intact, test_set_indicating_mask)
# => 0.21866121846582318

I have a few questions:

  1. What could be the cause of the early stop?
  2. Is there any object in saits_base that stores the loss history?
  3. Does cal_mae calculate the same MAE as in your paper? For this Beijing air-quality case, should I be able to tune the hyperparameters to get test_set_mae down to around 0.146?

Thank you,
Haochen

WenjieDu (Owner) commented Apr 14, 2022

Hi Haochen,

Thank you for opening the first issue of PyPOTS; your feedback is much appreciated!

A1: This is weird. Currently, the only early-stopping strategy is training patience. Is this problem reproducible? I gave it a try as well but didn't encounter such a problem.
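
(For context, the patience mechanism is just a counter of epochs without improvement, roughly like the generic sketch below. train_one_epoch is a hypothetical stand-in, not the actual PyPOTS internals. Given that your log shows the loss still decreasing around epoch 83, the patience branch should not have fired yet.)

epochs, patience = 200, 30            # the values from your setup
best_loss, waiting = float('inf'), 0
for epoch in range(epochs):
    loss = train_one_epoch()          # hypothetical stand-in for one epoch of training
    print(f'epoch {epoch}: training loss {loss:.4f}')
    if loss < best_loss:
        best_loss, waiting = loss, 0  # any improvement resets the counter
    else:
        waiting += 1
        if waiting >= patience:
            print('Exceeded the training patience. Terminating the training procedure...')
            break
else:
    print('Finished all training epochs.')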

A2: No, so far PyPOTS does not store training logs. I will add such a function in the near future.
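
In the meantime, since fit() prints one line per epoch in the format shown in your log, a workaround is to capture stdout and parse it yourself. A rough sketch (it assumes that exact log format, and the live output ends up in the buffer instead of the console):

import io
import re
from contextlib import redirect_stdout

buf = io.StringIO()
with redirect_stdout(buf):  # everything fit() prints goes into the buffer
    saits_base.fit(train_set_X)

# recover per-epoch losses from lines like "epoch 3: training loss 0.4783"
loss_history = [float(s) for s in re.findall(r'training loss (\d+\.\d+)', buf.getvalue())]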

A3: Yes, cal_mae calculates MAE in the same way as the paper [1]. You got a higher error not because of cal_mae or the model, but because of overfitting. PyPOTS keeps the sklearn-like training procedure, namely fitting the model on a given X and selecting the best model according to its performance on that same X (this is probably not ideal for NN models because they overfit easily). So your model overfits the training set and obtains a larger MAE on the test set. I will consider changing the parameters of the fit function from (X) to (train_X, val_X). For now, if you want to reproduce the results in the SAITS paper, please use the code in WenjieDu/SAITS.
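
Until that change lands, you can do the selection in your own code with the functions your script already imports: hold out values in your val split, impute it, and compare models on the val MAE instead of the training loss. A sketch:

# artificially mask 10% of the observed values in the val split as ground truth
val_X_intact, val_X, val_missing_mask, val_indicating_mask = mcar(val_set_X, 0.1)
val_X = fill_nan_with_mask(val_X, val_missing_mask)

saits_base.fit(train_set_X)                # fit on the training split only
val_imputation = saits_base.impute(val_X)  # impute the held-out val split
val_mae = cal_mae(val_imputation, val_X_intact, val_indicating_mask)
# tune hyperparameters to minimize val_mae, then report the test MAE once at the end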

Footnotes

  [1] Wenjie Du et al. SAITS: Self-Attention-based Imputation for Time Series. https://arxiv.org/abs/2202.08516

Rdfing (Author) commented Apr 14, 2022

Wenjie,

I see. I will work with your original code then. Here is the script that was used for the Beijing air-quality case, for your convenience.

Thanks,
Haochen

import os

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

from pypots.data import mcar, fill_nan_with_mask
from pypots.imputation import SAITS
from pypots.utils.metrics import cal_mae

# load data
file_path = 'AirQuality/PRSA_Data_20130301-20170228'

df_collector = []
station_name_collector = []
file_list = os.listdir(file_path)
for filename in file_list:
    current_df = pd.read_csv(file_path+'/'+filename)
    current_df['date_time'] = pd.to_datetime(current_df[['year', 'month', 'day', 'hour']])
    station_name_collector.append(current_df.loc[0, 'station'])
    # remove duplicated date info and wind direction, which is a categorical col
    current_df = current_df.drop(['year', 'month', 'day', 'hour', 'wd', 'No', 'station'], axis=1)
    df_collector.append(current_df)

# merge the per-station dataframes into one wide dataframe, one column per station-feature pair
date_time = df_collector[0]['date_time']
df_collector = [i.drop('date_time', axis=1) for i in df_collector]
df = pd.concat(df_collector, axis=1)
feature_names = [station + '_' + feature
                    for station in station_name_collector
                    for feature in df_collector[0].columns]
feature_num = len(feature_names)
df.columns = feature_names
print(f'Original df missing rate: '
      f'{(df[feature_names].isna().sum().sum() / (df.shape[0] * feature_num)):.3f}')

df['date_time'] = date_time
unique_months = df['date_time'].dt.to_period('M').unique()
selected_as_test = unique_months[:10]  # select the first 10 months as the test set
print(f'months selected as test set are {selected_as_test}')
selected_as_val = unique_months[10:20]  # select the 11th - 20th months as the val set
print(f'months selected as val set are {selected_as_val}')
selected_as_train = unique_months[20:]  # use the remaining months as the train set
print(f'months selected as train set are {selected_as_train}')

test_set = df[df['date_time'].dt.to_period('M').isin(selected_as_test)]
val_set = df[df['date_time'].dt.to_period('M').isin(selected_as_val)]
train_set = df[df['date_time'].dt.to_period('M').isin(selected_as_train)]

scaler = StandardScaler()
train_set_X = scaler.fit_transform(train_set.loc[:, feature_names])
val_set_X = scaler.transform(val_set.loc[:, feature_names])
test_set_X = scaler.transform(test_set.loc[:, feature_names])

def window_truncate(feature_vectors, seq_len):
    """Generate time-series samples by truncating windows of a given sequence length.

    Parameters
    ----------
    feature_vectors : array, shape of [total_length, feature_num]
        Time-series data.
    seq_len : int
        Sequence length.

    Returns
    -------
    array
        Truncated time series with the given sequence length.
    """
    start_indices = np.asarray(range(feature_vectors.shape[0] // seq_len)) * seq_len
    sample_collector = []
    for idx in start_indices:
        sample_collector.append(feature_vectors[idx: idx + seq_len])

    return np.asarray(sample_collector).astype('float32')
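
# note: window_truncate keeps only full windows (the integer division above),
# so up to seq_len - 1 trailing time steps of each split are discarded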

seq_len = 24

train_set_X = window_truncate(train_set_X, seq_len)
val_set_X = window_truncate(val_set_X, seq_len)
test_set_X = window_truncate(test_set_X, seq_len)


# hold out 10% observed values as ground truth
train_set_X_intact, train_set_X, train_set_missing_mask, train_set_indicating_mask = mcar(train_set_X, 0.1)
train_set_X = fill_nan_with_mask(train_set_X, train_set_missing_mask)

# hold out 10% observed values as ground truth
test_set_X_intact, test_set_X, test_set_missing_mask, test_set_indicating_mask = mcar(test_set_X, 0.1)
test_set_X = fill_nan_with_mask(test_set_X, test_set_missing_mask)
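
# optional sanity check on the artificial masking; this assumes missing_mask
# marks the entries still observed in test_set_X and indicating_mask marks the
# artificially held-out ones, so the ratio below should be close to 0.1
observed_before_masking = test_set_missing_mask.sum() + test_set_indicating_mask.sum()
print(f'artificially-missing rate: {test_set_indicating_mask.sum() / observed_before_masking:.3f}')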

# Model training. This is PyPOTS showtime.
saits_base = SAITS(seq_len=seq_len, n_features=132, 
                   n_layers=2,  # num of group-inner layers
                   d_model=256, # model hidden dim
                   d_inner=128, # hidden size of feed forward layer
                   n_head=4, # head num of self-attention
                   d_k=64, d_v=64, # key dim, value dim
                   dropout=0, 
                   epochs=200,
                   patience=30,
                   batch_size=32,
                   weight_decay=0.0,
                   ORT_weight=1,
                   MIT_weight=1,
                  )

saits_base.fit(train_set_X)

# impute the originally-missing values and artificially-missing values
train_set_imputation = saits_base.impute(train_set_X) 

# calculate mean absolute error on the ground truth (artificially-missing values)
train_set_mae = cal_mae(train_set_imputation, train_set_X_intact, train_set_indicating_mask)

print(train_set_mae)

# impute the originally-missing values and artificially-missing values
test_set_imputation = saits_base.impute(test_set_X) 

# calculate mean absolute error on the ground truth (artificially-missing values)
test_set_mae = cal_mae(test_set_imputation, test_set_X_intact, test_set_indicating_mask)

print(test_set_mae)
