Skip to content

Commit

Permalink
Merge pull request #442 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Add `inverse_sliding_window()` and enable TimesNet to work with len>5000 samples
  • Loading branch information
WenjieDu committed Jun 20, 2024
2 parents 0fbdcd6 + 27c8305 commit d563164
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 247 deletions.
24 changes: 14 additions & 10 deletions pypots/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
gene_complete_random_walk_for_anomaly_detection,
gene_complete_random_walk_for_classification,
gene_random_walk,
gene_physionet2012,
)
from .load_specific_datasets import (
list_supported_datasets,
load_specific_dataset,
from .saving import (
save_dict_into_h5,
load_dict_from_h5,
pickle_dump,
pickle_load,
)
from .utils import (
parse_delta,
sliding_window,
inverse_sliding_window,
)
from .saving import save_dict_into_h5
from .utils import parse_delta, sliding_window

__all__ = [
# base dataset classes
Expand All @@ -29,13 +33,13 @@
"gene_complete_random_walk_for_anomaly_detection",
"gene_complete_random_walk_for_classification",
"gene_random_walk",
"gene_physionet2012",
# list and load datasets
"list_supported_datasets",
"load_specific_dataset",
# utils
"parse_delta",
"sliding_window",
"inverse_sliding_window",
# saving
"save_dict_into_h5",
"load_dict_from_h5",
"pickle_dump",
"pickle_load",
]
108 changes: 0 additions & 108 deletions pypots/data/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state

from .load_specific_datasets import load_specific_dataset


def gene_complete_random_walk(
n_samples: int = 1000,
Expand Down Expand Up @@ -320,109 +318,3 @@ def gene_random_walk(
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data


def gene_physionet2012(artificially_missing_rate: float = 0.1):
"""Generate a fully-prepared PhysioNet-2012 dataset for model testing.
Parameters
----------
artificially_missing_rate : float, default=0.1
The rate of artificially missing values to generate for model evaluation.
This ratio is calculated based on the number of observed values, i.e. if artificially_missing_rate = 0.1,
then 10% of the observed values will be randomly masked as missing data and hold out for model evaluation.
Returns
-------
data: dict,
A dictionary containing the generated PhysioNet-2012 dataset.
"""
assert (
0 <= artificially_missing_rate < 1
), "artificially_missing_rate must be in [0,1)"

# generate samples
dataset = load_specific_dataset("physionet_2012")
X = dataset["X"]
y = dataset["y"]
ICUType = dataset["ICUType"]

all_recordID = X["RecordID"].unique()
train_set_ids, test_set_ids = train_test_split(all_recordID, test_size=0.2)
train_set_ids, val_set_ids = train_test_split(train_set_ids, test_size=0.2)
train_set_ids.sort()
val_set_ids.sort()
test_set_ids.sort()
train_set = X[X["RecordID"].isin(train_set_ids)].sort_values(["RecordID", "Time"])
val_set = X[X["RecordID"].isin(val_set_ids)].sort_values(["RecordID", "Time"])
test_set = X[X["RecordID"].isin(test_set_ids)].sort_values(["RecordID", "Time"])

train_set = train_set.drop(["RecordID", "Time"], axis=1)
val_set = val_set.drop(["RecordID", "Time"], axis=1)
test_set = test_set.drop(["RecordID", "Time"], axis=1)
train_X, val_X, test_X = (
train_set.to_numpy(),
val_set.to_numpy(),
test_set.to_numpy(),
)

# normalization
scaler = StandardScaler()
train_X = scaler.fit_transform(train_X)
val_X = scaler.transform(val_X)
test_X = scaler.transform(test_X)

# reshape into time series samples
train_X = train_X.reshape(len(train_set_ids), 48, -1)
val_X = val_X.reshape(len(val_set_ids), 48, -1)
test_X = test_X.reshape(len(test_set_ids), 48, -1)

train_y = y[y.index.isin(train_set_ids)].sort_index()
val_y = y[y.index.isin(val_set_ids)].sort_index()
test_y = y[y.index.isin(test_set_ids)].sort_index()
train_y, val_y, test_y = train_y.to_numpy(), val_y.to_numpy(), test_y.to_numpy()

train_ICUType = ICUType[ICUType.index.isin(train_set_ids)].sort_index()
val_ICUType = ICUType[ICUType.index.isin(val_set_ids)].sort_index()
test_ICUType = ICUType[ICUType.index.isin(test_set_ids)].sort_index()
train_ICUType, val_ICUType, test_ICUType = (
train_ICUType.to_numpy(),
val_ICUType.to_numpy(),
test_ICUType.to_numpy(),
)

data = {
"n_classes": 2,
"n_steps": 48,
"n_features": train_X.shape[-1],
"train_X": train_X,
"train_y": train_y.flatten(),
"train_ICUType": train_ICUType.flatten(),
"val_X": val_X,
"val_y": val_y.flatten(),
"val_ICUType": val_ICUType.flatten(),
"test_X": test_X,
"test_y": test_y.flatten(),
"test_ICUType": test_ICUType.flatten(),
"scaler": scaler,
}

if artificially_missing_rate > 0:
# mask values in the validation set as ground truth
val_X_ori = val_X
val_X = mcar(val_X, artificially_missing_rate)
# mask values in the test set as ground truth
test_X_ori = test_X
test_X = mcar(test_X, artificially_missing_rate)

data["val_X"] = val_X
data["val_X_ori"] = val_X_ori

# test_X is for model input
data["test_X"] = test_X
# test_X_ori is for error calc, not for model input, hence mustn't have NaNs
data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data
54 changes: 0 additions & 54 deletions pypots/data/load_preprocessing.py

This file was deleted.

71 changes: 0 additions & 71 deletions pypots/data/load_specific_datasets.py

This file was deleted.

60 changes: 58 additions & 2 deletions pypots/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import math
from typing import Union

import numpy as np
import torch

from ..utils.logging import logger


def turn_data_into_specified_dtype(
data: Union[np.ndarray, torch.Tensor, list],
Expand Down Expand Up @@ -194,8 +197,13 @@ def sliding_window(time_series, window_len, sliding_len=None):
start_indices = np.asarray(range(total_len // sliding_len)) * sliding_len

# remove the last one if left length is not enough
if total_len - start_indices[-1] * sliding_len < window_len:
start_indices = start_indices[:-1]
if total_len - start_indices[-1] < window_len:
to_drop = math.ceil(window_len / sliding_len)
left_len = total_len - start_indices[-1]
start_indices = start_indices[:-to_drop]
logger.warning(
f"The last {to_drop} samples are dropped due to the left length {left_len} is not enough."
)

sample_collector = []
for idx in start_indices:
Expand All @@ -204,3 +212,51 @@ def sliding_window(time_series, window_len, sliding_len=None):
samples = np.asarray(sample_collector).astype("float32")

return samples


def inverse_sliding_window(X, sliding_len):
"""Restore the original time-series data from the generated sliding window samples.
Note that this is the inverse operation of the `sliding_window` function, but there is no guarantee that
the restored data is the same as the original data considering that
1. the sliding length may be larger than the window size and there will be gaps between restored data;
2. if values in the samples get changed, the overlap part may not be the same as the original data after averaging;
3. some incomplete samples at the tail may be dropped during the sliding window operation, hence the restored data
may be shorter than the original data.
Parameters
----------
X :
The generated time-series samples with sliding window method, shape of [n_samples, n_steps, n_features],
where n_steps is the window size of the used sliding window method.
sliding_len :
The sliding length of the window for each moving step in the sliding window method used to generate X.
Returns
-------
restored_data :
The restored time-series data with shape of [total_length, n_features].
"""
assert len(X.shape) == 3, f"X should be a 3D array, but got {X.shape}"
n_samples, window_size, n_features = X.shape

if sliding_len >= window_size:
if sliding_len > window_size:
logger.warning(
f"sliding_len {sliding_len} is larger than the window size {window_size}, "
f"hence there will be gaps between restored data."
)
restored_data = X.reshape(n_samples * window_size, n_features)
else:
collector = [X[0][:sliding_len]]
overlap = X[0][sliding_len:]
for x in X[1:]:
overlap_avg = (overlap + x[:-sliding_len]) / 2
collector.append(overlap_avg[:sliding_len])
overlap = np.concatenate(
[overlap_avg[sliding_len:], x[-sliding_len:]], axis=0
)
collector.append(overlap)
restored_data = np.concatenate(collector, axis=0)
return restored_data
1 change: 1 addition & 0 deletions pypots/imputation/timesnet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
n_features,
d_model,
dropout=dropout,
n_max_steps=n_steps,
)
self.model = BackboneTimesNet(
n_layers,
Expand Down
5 changes: 4 additions & 1 deletion pypots/nn/modules/transformer/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,17 @@ def __init__(
freq="h",
dropout=0.1,
with_pos=True,
n_max_steps=1000,
):
super().__init__()

self.with_pos = with_pos

self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
if with_pos:
self.position_embedding = PositionalEncoding(d_hid=d_model)
self.position_embedding = PositionalEncoding(
d_hid=d_model, n_positions=n_max_steps
)
self.temporal_embedding = (
TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
if embed_type != "timeF"
Expand Down
Loading

0 comments on commit d563164

Please sign in to comment.