Skip to content

Commit

Permalink
Merge pull request #344 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Release v0.4, apply SAITS embedding strategy to the newly added models, and update README
  • Loading branch information
WenjieDu authored Apr 9, 2024
2 parents fb7ec06 + 0df20d0 commit eb03a15
Show file tree
Hide file tree
Showing 14 changed files with 227 additions and 106 deletions.
69 changes: 36 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,39 +192,42 @@ The paper references are all listed at the bottom of this readme file. Please re
🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support.
This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework.

| ***`Imputation`*** | 🚥 | 🚥 | 🚥 |
|:----------------------:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|
| **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** |
| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
| Neural Net | Transformer | Attention is All you Need [^2];<br>Self-Attention-based Imputation for Time Series [^1];<br><sub>Note: proposed in [^2], and re-implemented as an imputation model in [^1].</sub> | 2017 |
| Neural Net | Crossformer | Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [^16] | 2023 |
| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 |
| Neural Net | PatchTST | A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers [^18] | 2023 |
| Neural Net | DLinear | Are Transformers Effective for Time Series Forecasting? [^17] | 2023 |
| Neural Net | ETSformer | Exponential Smoothing Transformers for Time-series Forecasting [^19] | 2023 |
| Neural Net | FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [^20] | 2022 |
| Neural Net | Informer | Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [^21] | 2021 |
| Neural Net | Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [^15] | 2021 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 |
| Naive | LOCF/NOCB | Last Observation Carried Forward / Next Observation Carried Backward | - |
| Naive | Median | Median Value Imputation | - |
| Naive | Mean | Mean Value Imputation | - |
| ***`Classification`*** | 🚥 | 🚥 | 🚥 |
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 |
| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 |
| ***`Clustering`*** | 🚥 | 🚥 | 🚥 |
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 |
| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 |
| ***`Forecasting`*** | 🚥 | 🚥 | 🚥 |
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 |
🔥 Note that Transformer, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer are not proposed as imputation methods in their original papers,
and they cannot accept POTS as input. **To make them applicable on POTS data, we apply the embedding strategy the same as we did in [SAITS paper](https://arxiv.org/pdf/2202.08516).**

| ***`Imputation`*** | 🚥 | 🚥 | 🚥 |
|:----------------------:|:-----------:|:-----------------------------------------------------------------------------------------------:|:--------:|
| **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** |
| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
| Neural Net | Transformer | Attention is All you Need [^2] | 2017 |
| Neural Net | Crossformer | Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [^16] | 2023 |
| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 |
| Neural Net | PatchTST | A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers [^18] | 2023 |
| Neural Net | DLinear | Are Transformers Effective for Time Series Forecasting? [^17] | 2023 |
| Neural Net | ETSformer | Exponential Smoothing Transformers for Time-series Forecasting [^19] | 2023 |
| Neural Net | FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [^20] | 2022 |
| Neural Net | Informer | Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [^21] | 2021 |
| Neural Net | Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [^15] | 2021 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 |
| Naive | LOCF/NOCB | Last Observation Carried Forward / Next Observation Carried Backward | - |
| Naive | Median | Median Value Imputation | - |
| Naive | Mean | Mean Value Imputation | - |
| ***`Classification`*** | 🚥 | 🚥 | 🚥 |
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 |
| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 |
| ***`Clustering`*** | 🚥 | 🚥 | 🚥 |
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 |
| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 |
| ***`Forecasting`*** | 🚥 | 🚥 | 🚥 |
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 |


## ❖ Citing PyPOTS
Expand Down
2 changes: 1 addition & 1 deletion pypots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.3.2"
__version__ = "0.4"


from . import imputation, classification, clustering, forecasting, optim, data, utils
Expand Down
22 changes: 15 additions & 7 deletions pypots/imputation/autoformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from .submodules import (
Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(
self.seq_len = n_steps
self.n_layers = n_layers
self.enc_embedding = DataEmbedding(
n_features,
n_features * 2,
d_model,
dropout=dropout,
with_pos=False,
Expand All @@ -63,28 +64,35 @@ def __init__(
)

# for the imputation task, the output dim is the same as input dim
self.projection = nn.Linear(d_model, n_features)
self.output_projection = nn.Linear(d_model, n_features)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, masks = inputs["X"], inputs["missing_mask"]

# embedding
enc_out = self.enc_embedding(X) # [B,T,C]
# WDU: the original Autoformer paper isn't proposed for imputation task. Hence the model doesn't take
# the missing mask into account, which means, in the process, the model doesn't know which part of
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the
# embedding layers to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.

# the same as SAITS, concatenate the time series data and the missing mask for embedding
input_X = torch.cat([X, masks], dim=2)
enc_out = self.enc_embedding(input_X)

# Autoformer encoder processing
enc_out, attns = self.encoder(enc_out)

# project back the original data space
dec_out = self.projection(enc_out)
output = self.output_projection(enc_out)

imputed_data = masks * X + (1 - masks) * dec_out
imputed_data = masks * X + (1 - masks) * output
results = {
"imputed_data": imputed_data,
}

if training:
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
20 changes: 15 additions & 5 deletions pypots/imputation/crossformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
super().__init__()

self.n_features = n_features
self.d_model = d_model

# The padding operation to handle invisible sgemnet length
pad_in_len = ceil(1.0 * n_steps / seg_len) * seg_len
Expand All @@ -49,7 +50,7 @@ def __init__(
0,
)
self.enc_pos_embedding = nn.Parameter(
torch.randn(1, n_features, in_seg_num, d_model)
torch.randn(1, d_model, in_seg_num, d_model)
)
self.pre_norm = nn.LayerNorm(d_model)

Expand All @@ -71,31 +72,40 @@ def __init__(
)

self.head = FlattenHead(head_nf, n_steps, dropout)
self.embedding = nn.Linear(n_features * 2, d_model)
self.output_projection = nn.Linear(d_model, n_features)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, masks = inputs["X"], inputs["missing_mask"]

# WDU: the original Crossformer paper isn't proposed for imputation task. Hence the model doesn't take
# the missing mask into account, which means, in the process, the model doesn't know which part of
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the
# embedding layers to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.
# embedding
x_enc = self.enc_value_embedding(X.permute(0, 2, 1))
input_X = self.embedding(torch.cat([X, masks], dim=2))
x_enc = self.enc_value_embedding(input_X.permute(0, 2, 1))

# Crossformer processing
x_enc = rearrange(
x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.n_features
x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.d_model
)
x_enc += self.enc_pos_embedding
x_enc = self.pre_norm(x_enc)
enc_out, attns = self.encoder(x_enc)
# project back the original data space
dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1)
output = self.output_projection(dec_out)

imputed_data = masks * X + (1 - masks) * dec_out
imputed_data = masks * X + (1 - masks) * output
results = {
"imputed_data": imputed_data,
}

if training:
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
9 changes: 6 additions & 3 deletions pypots/imputation/crossformer/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,12 @@ def __init__(
d_ff,
depth,
dropout,
seg_num=10,
factor=10,
seg_num,
factor,
):
super().__init__()

d_k = d_model // n_heads
if win_size > 1:
self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
else:
Expand All @@ -158,7 +159,9 @@ def __init__(

for i in range(depth):
self.encode_layers.append(
TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, d_ff, dropout)
TwoStageAttentionLayer(
seg_num, factor, d_model, n_heads, d_k, d_k, d_ff, dropout
)
)

def forward(self, x, attn_mask=None, tau=None, delta=None):
Expand Down
Loading

0 comments on commit eb03a15

Please sign in to comment.