
[Feature request] Is it possible to "warm-up" the transformer? #25

Closed
b2jia opened this issue Nov 11, 2022 · 9 comments

Labels: new feature

b2jia commented Nov 11, 2022

Thank you for creating this wonderful resource! This is an amazing and useful tool!

Regarding SAITS, is it possible to pass a learning rate scheduler, rather than a fixed learning rate, for the transformer to pre-train?

I ask this because I compared the outputs of training for 100 epochs vs. 1000 epochs. The loss continues to decrease, but the error on held-out timepoints does not change between 100 and 1000 epochs. Strangely, the prediction (after both 100 and 1000 epochs) is less accurate than linear interpolation! I wonder if this is because the transformer has too many parameters and needs some help learning initially.

WenjieDu (Owner) commented:

Hi there,

Thank you so much for your attention to PyPOTS! If you find PyPOTS helpful to your work, please star⭐️ this repository. Your star is your recognition, which can help more people notice PyPOTS and grow the PyPOTS community. It matters and is definitely a kind of contribution.

I have received your message and will respond ASAP. Thank you for your patience! 😃

Best,
Wenjie

MaciejSkrabski (Contributor) commented:

Since we are waiting for Wenjie's feedback, please allow me to chime in.

> I ask this because I compared the outputs of training for 100 epochs vs. 1000 epochs. The loss continues to decrease, but the error on held-out timepoints does not change between 100 and 1000 epochs.

From my (very) limited understanding of transformers, they learn blazingly fast! In my case, where an LSTM needed thousands of epochs to converge, around 40 epochs were sufficient for a transformer. Consider monitoring the validation loss more often in the initial epochs. Also, keep calm and lower the learning rate!

By the way, did you remember to simulate missing data in the training data? The SAITS paper describes that the model needs to see two kinds of data, from which two kinds of errors are calculated: one for reconstructing artificially masked values (MIT), the other for approximating the visible values (ORT). If the model does not train on the actual task it is supposed to solve, it cannot solve it.

Best of luck!

b2jia (Author) commented Nov 12, 2022

This is great insight, @MaciejSkrabski!

> By the way, did you remember to simulate missing data in the training data?

I did, but I'm puzzled. I introduced artificially missing values (5%) into my already-incomplete data (50% missing values). I assume that during training, SAITS takes my artificial+incomplete input and then does something similar to mcar: introduces its own artificially missing values (@WenjieDu, what fraction is this, can you confirm?), imputes them, and evaluates the imputation loss (MIT). Otherwise, how does it train?

I currently use my artificial missing values as a "test" dataset to evaluate the final imputation loss, e.g.:

```python
from pypots.utils.metrics import cal_mae  # PyPOTS metric helper; module path may vary by version

imputation = saits.impute(X)  # impute the originally-missing values and artificially-missing values
mae = cal_mae(imputation, X_intact, indicating_mask)  # mean absolute error on the ground truth (artificially-missing values)
```

Did I misunderstand the API - do I have to explicitly pass in the artificially missing value mask as input as well? I currently don't see an option to do so.
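
For concreteness, here is roughly how I create the artificial hold-out set, following the pattern in the PyPOTS README of this era (a minimal sketch; the `mcar`/`masked_fill` helpers and their exact signatures may differ across versions):

```python
import numpy as np
from pypots.data import mcar, masked_fill  # helpers shown in the PyPOTS README; names may differ by version

X = np.random.randn(100, 48, 10)  # stand-in for my (samples, steps, features) data, already ~50% missing
# hold out an extra 5% of the observed values as artificial missing values
X_intact, X_corrupted, missing_mask, indicating_mask = mcar(X, 0.05)
X_corrupted = masked_fill(X_corrupted, 1 - missing_mask, np.nan)  # NaN-out the held-out values
# train SAITS on X_corrupted, then score with cal_mae(imputation, X_intact, indicating_mask)
```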

b2jia (Author) commented Nov 12, 2022

After revisiting the SAITS paper, I'm wondering: what is the intuition behind MIT_weight and ORT_weight? I find SAITS is able to reconstruct observed data with near-perfect accuracy, but the imputation (at least on my dataset) is only slightly better than linear interpolation. Is it a matter of weighting the imputation more, i.e. MIT_weight=10, ORT_weight=1, or a matter of model architecture (more heads? a deeper network?)?

WenjieDu (Owner) commented:

> Thank you for creating this wonderful resource! This is an amazing and useful tool!
>
> Regarding SAITS, is it possible to pass a learning rate scheduler, rather than a fixed learning rate, for the transformer to pre-train?
>
> I ask this because I compared the outputs of training for 100 epochs vs. 1000 epochs. The loss continues to decrease, but the error on held-out timepoints does not change between 100 and 1000 epochs. Strangely, the prediction (after both 100 and 1000 epochs) is less accurate than linear interpolation! I wonder if this is because the transformer has too many parameters and needs some help learning initially.

Hi Bojing, thank you for raising this issue, and for your patience. And many thanks for your timely help, Maciej @MaciejSkrabski! I'm sorry for my delayed response.

Actually, I tried warm-up for Transformer and SAITS, but I didn't obtain any notable improvement. I thought this trick should work for Transformers on very large datasets, like NLP corpora, but I may be wrong. Of course, you can write a scheduler to give it a try in your experimental settings. For a quick start, you can use the schedulers from the Transformers library. Please let me know if you make new discoveries.
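
If you want to roll your own, here is a minimal sketch of the classic warm-up schedule from the original Transformer paper, written as a plain PyTorch LambdaLR; the `d_model` and `warmup_steps` values are illustrative, not tuned for SAITS:

```python
import torch

# Noam-style warm-up: lr rises linearly for `warmup_steps`, then decays as step^-0.5.
def make_warmup_scheduler(optimizer, d_model=256, warmup_steps=4000):
    def lr_lambda(step):
        step = max(step, 1)  # avoid division by zero on the first call
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

model = torch.nn.Linear(8, 8)  # stand-in for the actual imputation model
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)  # base lr is scaled by lr_lambda
scheduler = make_warmup_scheduler(optimizer)

for step in range(10):
    # forward/backward passes would go here
    optimizer.step()
    scheduler.step()  # one scheduler step per optimizer step
```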

WenjieDu (Owner) commented Nov 16, 2022

> This is great insight, @MaciejSkrabski!
>
> > By the way, did you remember to simulate missing data in the training data?
>
> I did, but I'm puzzled. I introduced artificially missing values (5%) into my already-incomplete data (50% missing values). I assume that during training, SAITS takes my artificial+incomplete input and then does something similar to mcar: introduces its own artificially missing values (@WenjieDu, what fraction is this, can you confirm?), imputes them, and evaluates the imputation loss (MIT). Otherwise, how does it train?
>
> I currently use my artificial missing values as a "test" dataset to evaluate the final imputation loss, e.g.:
>
> imputation = saits.impute(X)  # impute the originally-missing values and artificially-missing values
> mae = cal_mae(imputation, X_intact, indicating_mask)  # mean absolute error on the ground truth (artificially-missing values)
>
> Did I misunderstand the API - do I have to explicitly pass in the artificially missing value mask as input as well? I currently don't see an option to do so.

The default artificial missing rate applied in MIT is 20% (check out the code here). I think your understanding of the API is right. Considering PyPOTS still lacks detailed documentation and is under active development, I recommend you try my code in the SAITS repo to impute your data.


> After revisiting the SAITS paper, I'm wondering: what is the intuition behind MIT_weight and ORT_weight? I find SAITS is able to reconstruct observed data with near-perfect accuracy, but the imputation (at least on my dataset) is only slightly better than linear interpolation. Is it a matter of weighting the imputation more, i.e. MIT_weight=10, ORT_weight=1, or a matter of model architecture (more heads? a deeper network?)?

They are the loss weights of the corresponding tasks. Weighting different losses is a common technique in multi-task learning. Giving a task a higher weight makes the model pay more attention to it, because errors on that task are penalized more. Hope this helps.
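
To make the intuition concrete, here is a rough sketch of how the two losses combine, based on the paper's description; the masks and the `reconstruction` array are illustrative stand-ins, not PyPOTS internals:

```python
import numpy as np

def combined_loss(X, missing_mask, reconstruction, X_intact, indicating_mask,
                  ORT_weight=1.0, MIT_weight=1.0):
    X = np.nan_to_num(X)  # assume missing positions are zero-filled, as in the model input
    # ORT: reconstruction error on values that were observed in the input
    ort = np.abs((reconstruction - X) * missing_mask).sum() / missing_mask.sum()
    # MIT: imputation error on the artificially masked (held-out) values
    mit = np.abs((reconstruction - X_intact) * indicating_mask).sum() / indicating_mask.sum()
    # raising MIT_weight (e.g. to 10) penalizes imputation errors more heavily
    return ORT_weight * ort + MIT_weight * mit
```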

b2jia (Author) commented Nov 16, 2022

Thanks for the response @WenjieDu, I will give the other repo a try. By the way, is it possible to inject temporal information into SAITS? For instance, the values at certain time points are missing, but at least the time points themselves are known (and sometimes they are unevenly spaced).

WenjieDu (Owner) commented Nov 17, 2022

Like most imputation algorithms, the original SAITS assumes the input is sampled at even time intervals. However, you can add the sampling timestamp as an additional feature of the input, or embed the timing into the positional encoding. But if the features in your dataset are sampled irregularly (e.g. not all features are collected in one sampling operation), this may make your data more sparse.
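
For the positional-encoding route, a minimal sketch of driving a sinusoidal encoding with real-valued timestamps instead of integer positions might look like this (the `time_encoding` helper and `d_model` value are illustrative, not part of PyPOTS):

```python
import math
import torch

def time_encoding(timestamps: torch.Tensor, d_model: int = 64) -> torch.Tensor:
    # timestamps: (batch, seq_len) real-valued sampling times, possibly uneven
    half = d_model // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    angles = timestamps.unsqueeze(-1) * freqs       # (batch, seq_len, half)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

# usage: add this encoding to the feature embedding before the transformer layers
ts = torch.tensor([[0.0, 0.4, 1.1, 3.0]])  # unevenly spaced sampling times
pe = time_encoding(ts)                      # shape (1, 4, 64)
```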

It's an interesting question. Could you please give more details about your data and scenario? I'd love to know what kind of application needs such a function. And I can see what I can do to help further.

b2jia (Author) commented Nov 17, 2022

Thank you again for responding! I will close this issue; I have emailed you directly.

@b2jia b2jia closed this as completed Nov 17, 2022
@WenjieDu WenjieDu added the new feature Proposing to add a new feature label Apr 7, 2023