From 2dfd84c2d020dfc1c07b7dea07987eda270384ef Mon Sep 17 00:00:00 2001
From: Avik Basu <3485425+ab93@users.noreply.github.com>
Date: Tue, 8 Aug 2023 13:59:22 -0700
Subject: [PATCH] feat: convolutional vae for multivariate time series (#237)

- Introduce variational autoencoder for multivariate time series data
- Customized trainer for VAE
- Causal convolutional modules
- Improved training console logging

Exploration in https://github.com/numaproj/numalogic-benchmarks/pull/3

---------

Signed-off-by: Avik Basu
---
 Makefile                                  |   2 +-
 numalogic/models/vae/__init__.py          |   3 +
 numalogic/models/vae/base.py              |  63 ++++++
 numalogic/models/vae/layer.py             |  62 ++++++
 numalogic/models/vae/trainer.py           |  69 ++++++
 numalogic/models/vae/variants/__init__.py |   3 +
 numalogic/models/vae/variants/conv.py     | 248 ++++++++++++++++++++++
 numalogic/tools/callbacks.py              |  57 +++++
 numalogic/tools/data.py                   |   4 +
 numalogic/tools/exceptions.py             |   2 +-
 tests/models/vae/__init__.py              |   0
 tests/models/vae/test_conv.py             | 122 +++++++++++
 12 files changed, 633 insertions(+), 2 deletions(-)
 create mode 100644 numalogic/models/vae/__init__.py
 create mode 100644 numalogic/models/vae/base.py
 create mode 100644 numalogic/models/vae/layer.py
 create mode 100644 numalogic/models/vae/trainer.py
 create mode 100644 numalogic/models/vae/variants/__init__.py
 create mode 100644 numalogic/models/vae/variants/conv.py
 create mode 100644 tests/models/vae/__init__.py
 create mode 100644 tests/models/vae/test_conv.py

diff --git a/Makefile b/Makefile
index 6224b908..ad3fc65f 100644
--- a/Makefile
+++ b/Makefile
@@ -16,7 +16,7 @@ clean:
 	@find . -type f -name "*.py[co]" -exec rm -rf {} +
 
 format: clean
-	poetry run black numalogic/ examples/ tests/ benchmarks/
+	poetry run black numalogic/ examples/ tests/
 
 lint: format
 	poetry run ruff check --fix .
diff --git a/numalogic/models/vae/__init__.py b/numalogic/models/vae/__init__.py
new file mode 100644
index 00000000..c6f64d4e
--- /dev/null
+++ b/numalogic/models/vae/__init__.py
@@ -0,0 +1,3 @@
+from numalogic.models.vae.trainer import VAETrainer
+
+__all__ = ["VAETrainer"]
diff --git a/numalogic/models/vae/base.py b/numalogic/models/vae/base.py
new file mode 100644
index 00000000..3103b4d6
--- /dev/null
+++ b/numalogic/models/vae/base.py
@@ -0,0 +1,63 @@
+from typing import Callable
+
+import torch.nn.functional as F
+from torch import Tensor, optim
+
+from numalogic.base import TorchModel
+
+
+def _init_criterion(loss_fn: str) -> Callable:
+    if loss_fn == "huber":
+        return F.huber_loss
+    if loss_fn == "l1":
+        return F.l1_loss
+    if loss_fn == "mse":
+        return F.mse_loss
+    raise ValueError(f"Unsupported loss function provided: {loss_fn}")
+
+
+class BaseVAE(TorchModel):
+    """
+    Abstract base class for all PyTorch-based variational autoencoder models.
+
+    Args:
+    ----
+        lr: learning rate (default: 3e-4)
+        weight_decay: weight decay factor for regularization (default: 0.0)
+        loss_fn: loss function used to train the model
+            supported values include: {mse, l1, huber}
+    """
+
+    def __init__(
+        self,
+        lr: float = 3e-4,
+        weight_decay: float = 0.0,
+        loss_fn: str = "mse",
+    ):
+        super().__init__()
+        self._lr = lr
+        self.weight_decay = weight_decay
+        self.criterion = _init_criterion(loss_fn)
+
+    def configure_shape(self, x: Tensor) -> Tensor:
+        """Method to configure the batch shape for each type of model architecture."""
+        return x
+
+    def configure_optimizers(self) -> dict:
+        optimizer = optim.Adam(self.parameters(), lr=self._lr, weight_decay=self.weight_decay)
+        return {"optimizer": optimizer}
+
+    def recon_loss(self, batch: Tensor, recon: Tensor, reduction: str = "sum") -> Tensor:
+        return self.criterion(batch, recon, reduction=reduction)
+
+    def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
+        """Validation step for the model."""
+        p, recon = self.forward(batch)
+        loss = self.recon_loss(batch, recon)
+        self.log("val_loss", loss)
+        return loss
+
+    def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
+        """Prediction step for the model."""
+        p, recon = self.forward(batch)
+        return self.recon_loss(batch, recon, reduction="none")
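The contract a concrete subclass must satisfy is deliberately small: forward()
returns a (latent distribution, reconstruction) pair, while BaseVAE supplies the
Adam optimizer and the reconstruction criterion selected by loss_fn. A minimal
sketch of such a subclass (ToyVAE is hypothetical, purely to illustrate the hooks):

    from torch import Tensor, nn

    from numalogic.models.vae.base import BaseVAE


    class ToyVAE(BaseVAE):
        """Hypothetical subclass illustrating the BaseVAE contract."""

        def __init__(self, seq_len: int, n_features: int, **kwargs):
            super().__init__(**kwargs)  # wires up lr, weight_decay and the loss criterion
            self.net = nn.Linear(seq_len * n_features, seq_len * n_features)

        def forward(self, x: Tensor) -> tuple[None, Tensor]:
            # A real variant returns the latent distribution as the first element;
            # the inherited validation_step/predict_step only use the reconstruction.
            recon = self.net(x.flatten(start_dim=1)).view_as(x)
            return None, recon

        def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
            _, recon = self.forward(batch)
            return self.recon_loss(batch, recon)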
diff --git a/numalogic/models/vae/layer.py b/numalogic/models/vae/layer.py
new file mode 100644
index 00000000..54364200
--- /dev/null
+++ b/numalogic/models/vae/layer.py
@@ -0,0 +1,62 @@
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+
+class CausalConv1d(nn.Conv1d):
+    """Temporal convolutional layer with causal padding."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=0,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+        self.__padding = (kernel_size - 1) * dilation
+
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(F.pad(x, (self.__padding, 0)))
+
+
+class CausalConvBlock(nn.Module):
+    """Basic convolutional block consisting of:
+    - causal 1D convolutional layer
+    - batch norm
+    - ReLU activation.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        self.conv = CausalConv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            dilation=dilation,
+        )
+        self.bnorm = nn.BatchNorm1d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        return self.relu(self.bnorm(self.conv(input_)))
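A quick sanity check of the padding arithmetic: left-only padding of
(kernel_size - 1) * dilation keeps the output the same length as the input, and
output step t depends only on inputs at steps <= t (no future leakage). A small
sketch against the layer as defined in this patch:

    import torch

    from numalogic.models.vae.layer import CausalConv1d

    conv = CausalConv1d(in_channels=1, out_channels=1, kernel_size=3, dilation=2)
    x = torch.randn(1, 1, 10)  # (batch, channels, seq_len)

    out = conv(x)
    assert out.shape == x.shape  # left pad of (3 - 1) * 2 = 4 preserves length

    # Causality: perturbing the last step must not change any earlier output.
    x2 = x.clone()
    x2[..., -1] += 1.0
    out2 = conv(x2)
    assert torch.allclose(out[..., :-1], out2[..., :-1])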
diff --git a/numalogic/models/vae/trainer.py b/numalogic/models/vae/trainer.py
new file mode 100644
index 00000000..0800bd36
--- /dev/null
+++ b/numalogic/models/vae/trainer.py
@@ -0,0 +1,69 @@
+import sys
+import warnings
+from typing import Optional
+
+import torch
+from torch import Tensor
+from pytorch_lightning import Trainer, LightningModule
+
+from numalogic.tools.callbacks import ConsoleLogger
+from numalogic.tools.data import inverse_window
+
+
+class VAETrainer(Trainer):
+    """A PyTorch Lightning Trainer for VAE models.
+
+    Args:
+    ----
+        max_epochs: The maximum number of epochs to train for. (default: 100)
+        logger: Whether to use a console logger to log metrics. (default: True)
+        log_freq: The number of epochs between logging. (default: 5)
+        check_val_every_n_epoch: The number of epochs between validation checks. (default: 5)
+        enable_checkpointing: Whether to enable checkpointing. (default: False)
+        enable_progress_bar: Whether to enable the progress bar. (default: False)
+        enable_model_summary: Whether to enable the model summary. (default: False)
+        **trainer_kw: Additional keyword arguments to pass to the Lightning Trainer.
+    """
+
+    def __init__(
+        self,
+        max_epochs: int = 100,
+        logger: bool = True,
+        log_freq: int = 5,
+        check_val_every_n_epoch: int = 5,
+        enable_checkpointing: bool = False,
+        enable_progress_bar: bool = False,
+        enable_model_summary: bool = False,
+        **trainer_kw
+    ):
+        if not sys.warnoptions:
+            warnings.simplefilter("ignore", category=UserWarning)
+
+        if logger:
+            logger = ConsoleLogger(log_freq=log_freq)
+
+        super().__init__(
+            logger=logger,
+            max_epochs=max_epochs,
+            check_val_every_n_epoch=check_val_every_n_epoch,
+            enable_checkpointing=enable_checkpointing,
+            enable_progress_bar=enable_progress_bar,
+            enable_model_summary=enable_model_summary,
+            **trainer_kw
+        )
+
+    def predict(self, model: Optional[LightningModule] = None, unbatch: bool = True, **kwargs) -> Tensor:
+        r"""Predicts the output of the model.
+
+        Args:
+        ----
+            model: The model to predict with. (default: None)
+            unbatch: Whether to inverse window the output. (default: True)
+            **kwargs: Additional keyword arguments to pass to the Lightning
+                trainer's predict method.
+        """
+        recon_err = super().predict(model, **kwargs)
+        recon_err = torch.vstack(recon_err)
+        if unbatch:
+            return inverse_window(recon_err, method="keep_last")
+        return recon_err
diff --git a/numalogic/models/vae/variants/__init__.py b/numalogic/models/vae/variants/__init__.py
new file mode 100644
index 00000000..419e5838
--- /dev/null
+++ b/numalogic/models/vae/variants/__init__.py
@@ -0,0 +1,3 @@
+from numalogic.models.vae.variants.conv import Conv1dVAE
+
+__all__ = ["Conv1dVAE"]
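For context, a minimal usage sketch of the trainer with the Conv1dVAE variant
introduced below; the random array is a placeholder for scaled training data, and
the hyperparameter values are illustrative only:

    import numpy as np
    from torch.utils.data import DataLoader

    from numalogic.models.vae import VAETrainer
    from numalogic.models.vae.variants import Conv1dVAE
    from numalogic.tools.data import StreamingDataset

    x_train = np.random.randn(1000, 2)  # placeholder for scaled (n, 2) data

    model = Conv1dVAE(seq_len=12, n_features=2, latent_dim=8)
    trainer = VAETrainer(max_epochs=50, log_freq=5)
    trainer.fit(
        model,
        train_dataloaders=DataLoader(StreamingDataset(x_train, 12), batch_size=64),
    )

    # predict() vstacks the per-window reconstruction errors and, by default,
    # inverse-windows them back to one row per input timestep.
    recon_err = trainer.predict(
        model, dataloaders=DataLoader(StreamingDataset(x_train, 12), batch_size=64)
    )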
diff --git a/numalogic/models/vae/variants/conv.py b/numalogic/models/vae/variants/conv.py
new file mode 100644
index 00000000..303315c0
--- /dev/null
+++ b/numalogic/models/vae/variants/conv.py
@@ -0,0 +1,248 @@
+from collections.abc import Sequence
+from typing import Final
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from torch.distributions import MultivariateNormal, kl_divergence
+
+from numalogic.models.vae.base import BaseVAE
+from numalogic.models.vae.layer import CausalConvBlock
+from numalogic.tools.exceptions import ModelInitializationError
+
+_DEFAULT_KERNEL_SIZE: Final[int] = 3
+_DEFAULT_STRIDE: Final[int] = 2
+
+
+class Encoder(nn.Module):
+    """
+    Encoder module for Convolutional Variational Autoencoder.
+
+    Args:
+    ----
+        seq_len: sequence length / window length
+        n_features: number of features
+        latent_dim: latent dimension
+        conv_channels: number of convolutional channels
+        num_samples: number of samples to draw from the latent distribution
+    """
+
+    def __init__(
+        self,
+        seq_len: int,
+        n_features: int,
+        latent_dim: int,
+        conv_channels: Sequence[int] = (16,),
+        num_samples: int = 10,
+    ):
+        super().__init__()
+
+        self.seq_len = seq_len
+        self.nsamples = num_samples
+
+        conv_layer = CausalConvBlock(
+            in_channels=n_features,
+            out_channels=conv_channels[0],
+            kernel_size=_DEFAULT_KERNEL_SIZE,
+            stride=_DEFAULT_STRIDE,
+            dilation=1,
+        )
+        layers = self._construct_conv_layers(conv_channels)
+        if layers:
+            self.conv_layers = nn.Sequential(conv_layer, *layers)
+        else:
+            self.conv_layers = conv_layer
+
+        self.flatten = nn.Flatten(start_dim=1)
+        self.fc = nn.LazyLinear(latent_dim)
+        self.mu = nn.Linear(latent_dim, latent_dim)
+        self.logvar = nn.Linear(latent_dim, latent_dim)
+
+    @staticmethod
+    def _construct_conv_layers(conv_channels: Sequence[int]) -> nn.ModuleList:
+        """Construct dilated causal convolutional layers."""
+        layers = nn.ModuleList()
+        layer_idx = 1
+        while layer_idx < len(conv_channels):
+            layers.append(
+                CausalConvBlock(
+                    conv_channels[layer_idx - 1],
+                    conv_channels[layer_idx],
+                    kernel_size=_DEFAULT_KERNEL_SIZE,
+                    stride=_DEFAULT_STRIDE,
+                    dilation=2**layer_idx,
+                )
+            )
+            layer_idx += 1
+        return layers
+
+    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
+        """
+        Forward pass returning the mean and
+        log variance of the latent distribution.
+
+        Args:
+        ----
+            x: input tensor of shape (batch_size, n_features, seq_len)
+
+        Returns
+        -------
+        A tuple of:
+            mu: mean of the latent distribution
+            logvar: log variance of the latent distribution
+        """
+        out = self.conv_layers(x)
+        out = self.flatten(out)
+        out = torch.relu(self.fc(out))
+        mu = self.mu(out)
+        logvar = F.softplus(self.logvar(out))
+        return mu, logvar
+
+
+class Decoder(nn.Module):
+    """
+    Decoder (non-probabilistic) module for Convolutional Variational Autoencoder.
+
+    Args:
+    ----
+        seq_len: sequence length / window length
+        n_features: number of features
+        num_conv_filters: number of convolutional filters
+        latent_dim: latent dimension
+    """
+
+    def __init__(self, seq_len: int, n_features: int, num_conv_filters: int, latent_dim: int):
+        super().__init__()
+        self.seq_len = seq_len
+        self.n_features = n_features
+        self.fc = nn.Linear(latent_dim, num_conv_filters * 6)
+        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(num_conv_filters, 6))
+        self.conv_tr = nn.ConvTranspose1d(
+            in_channels=num_conv_filters,
+            out_channels=n_features,
+            kernel_size=_DEFAULT_KERNEL_SIZE,
+            stride=_DEFAULT_STRIDE,
+            padding=1,
+            output_padding=1,
+        )
+        self.bnorm = nn.BatchNorm1d(n_features)
+        self.fc_out = nn.LazyLinear(seq_len)
+        self.td_linear = nn.Linear(n_features, n_features)
+
+    def forward(self, z: Tensor) -> Tensor:
+        out = torch.relu(self.fc(z))
+        out = self.unflatten(out)
+        out = torch.relu(self.bnorm(self.conv_tr(out)))
+        out = torch.relu(self.fc_out(out))
+        out = out.view(-1, self.seq_len, self.n_features)
+        return self.td_linear(out)
+
+
+class Conv1dVAE(BaseVAE):
+    """
+    Convolutional Variational Autoencoder for time series data.
+
+    Uses causal convolutions to preserve temporal information in
+    the encoded latent space. The decoder is non-probabilistic, and
+    consists of transposed convolution and linear layers.
+
+    Note: The model expects input data of shape (batch_size, seq_len, n_features),
+    which is reshaped internally to (batch_size, n_features, seq_len) for the
+    convolutional layers.
+
+    Args:
+    ----
+        seq_len: sequence length / window length
+        n_features: number of features
+        latent_dim: latent dimension
+        conv_channels: number of convolutional channels
+        num_samples: number of samples to draw from the latent distribution
+
+    Raises
+    ------
+        ValueError: if an unsupported loss function is provided
+        ModelInitializationError: if initialization of the model fails due to invalid input,
+            invalid hyperparameters, or invalid convolutional kernel size / stride
+    """
+
+    def __init__(
+        self,
+        seq_len: int,
+        n_features: int,
+        latent_dim: int,
+        conv_channels: Sequence[int] = (16,),
+        num_samples: int = 10,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.seq_len = seq_len
+        self.z_dim = latent_dim
+        self.n_features = n_features
+        self.nsamples = num_samples
+
+        self.encoder = Encoder(
+            seq_len=seq_len,
+            n_features=n_features,
+            conv_channels=conv_channels,
+            latent_dim=latent_dim,
+            num_samples=num_samples,
+        )
+        self.decoder = Decoder(
+            seq_len=seq_len,
+            n_features=n_features,
+            num_conv_filters=conv_channels[0],
+            latent_dim=latent_dim,
+        )
+
+        # Do a dry run to initialize lazy modules
+        try:
+            self.forward(torch.rand(1, seq_len, n_features))
+        except (ValueError, RuntimeError) as err:
+            raise ModelInitializationError(
+                "Model forward pass failed. "
+                "Please validate input arguments and the expected input shape."
+            ) from err
+
+    def forward(self, x: Tensor) -> tuple[MultivariateNormal, Tensor]:
+        x = self.configure_shape(x)
+        z_mu, z_logvar = self.encoder(x)
+        p = MultivariateNormal(loc=z_mu, covariance_matrix=torch.diag_embed(z_logvar.exp()))
+        samples = p.rsample(sample_shape=torch.Size([self.nsamples]))
+        z = torch.mean(samples, dim=0)
+        x_recon = self.decoder(z)
+        return p, x_recon
+
+    def configure_shape(self, x: Tensor) -> Tensor:
+        """Method to configure the batch shape for each type of model architecture."""
+        return x.view(-1, self.n_features, self.seq_len)
+
+    def kld_loss(self, p: MultivariateNormal) -> Tensor:
+        """
+        Computes the reverse KL divergence between the latent distribution and
+        the standard multivariate Gaussian prior.
+
+        Args:
+        ----
+            p: MultivariateNormal latent distribution
+
+        Returns
+        -------
+        kld: reverse KL divergence between the prior q and the latent distribution p
+        """
+        q = MultivariateNormal(torch.zeros(self.z_dim), torch.eye(self.z_dim))
+        kld = kl_divergence(q, p)
+        return kld.sum()
+
+    def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
+        """Training step for the model."""
+        p, recon = self.forward(batch)
+        kld_loss = self.kld_loss(p)
+        recon_loss = self.recon_loss(batch, recon)
+        self.log_dict(
+            {
+                "train_kld_loss": kld_loss,
+                "train_recon_loss": recon_loss,
+            },
+            on_epoch=True,
+            on_step=False,
+        )
+        return kld_loss + recon_loss
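A shape-level sketch of the forward contract and of how the two loss terms in
training_step combine, assuming the default conv_channels; the input tensor is
illustrative:

    import torch

    from numalogic.models.vae.variants import Conv1dVAE

    model = Conv1dVAE(seq_len=12, n_features=2, latent_dim=4)
    x = torch.randn(8, 12, 2)  # (batch_size, seq_len, n_features)

    p, recon = model(x)  # MultivariateNormal over the latent space, plus reconstruction
    assert recon.shape == x.shape

    # The training objective is the sum of both terms:
    loss = model.recon_loss(x, recon) + model.kld_loss(p)
    loss.backward()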
diff --git a/numalogic/tools/callbacks.py b/numalogic/tools/callbacks.py
index 55b63f7b..ac1a2af3 100644
--- a/numalogic/tools/callbacks.py
+++ b/numalogic/tools/callbacks.py
@@ -11,9 +11,12 @@
 
 import logging
+from typing import Optional, Union
 
 import pytorch_lightning as pl
+from lightning_utilities.core.rank_zero import rank_zero_only
 from pytorch_lightning.callbacks import ProgressBar
+from pytorch_lightning.loggers import Logger
 
 _LOGGER = logging.getLogger(__name__)
@@ -50,3 +53,57 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo
         loss = pl_module.total_val_loss / trainer.num_val_batches[0]
         _LOGGER.info("validation_loss=%.5f", loss)
         pl_module.reset_val_loss()
+
+
+class ConsoleLogger(Logger):
+    """
+    A lightweight console logger for training metrics.
+
+    Args:
+    ----
+        log_freq: Interval of epochs to log
+        experiment_id: Optional experiment id, e.g. a uuid
+        experiment_name: Optional experiment name, e.g. a model name
+    """
+
+    def __init__(
+        self,
+        log_freq: int = 5,
+        experiment_id: Optional[str] = None,
+        experiment_name: Optional[str] = None,
+    ):
+        self.log_freq = log_freq
+        self._id = experiment_id
+        self._name = experiment_name
+
+    @property
+    def version(self) -> Optional[Union[int, str]]:
+        return self._id
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._name
+
+    @property
+    def experiment(self) -> Optional[str]:
+        return self._name
+
+    def log_hyperparams(self, params, *args, **kwargs):
+        raise NotImplementedError("ConsoleLogger does not log hyperparameters")
+
+    @rank_zero_only
+    def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
+        epoch = metrics.pop("epoch", None)
+        if epoch is None:
+            _LOGGER.info("metrics=%s", metrics)
+            return
+        epoch += 1
+        if epoch == 1:
+            _LOGGER.info(self._format_metrics(epoch, metrics))
+        elif not (epoch % self.log_freq):
+            _LOGGER.info(self._format_metrics(epoch, metrics))
+
+    @staticmethod
+    def _format_metrics(epoch, metrics: dict[str, float]) -> str:
+        log_msg = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
+        return f"epoch={epoch}, {log_msg}"
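Since ConsoleLogger implements the Lightning Logger interface, it can also be
handed to a vanilla pytorch_lightning.Trainer; a brief sketch (the logged line
format in the comment is approximate):

    import logging

    import pytorch_lightning as pl

    from numalogic.tools.callbacks import ConsoleLogger

    logging.basicConfig(level=logging.INFO)

    # Emits lines like "epoch=5, train_kld_loss=0.1234, train_recon_loss=0.5678"
    # on the first epoch and on every log_freq-th epoch thereafter.
    trainer = pl.Trainer(logger=ConsoleLogger(log_freq=5), max_epochs=20)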
diff --git a/numalogic/tools/data.py b/numalogic/tools/data.py
index 7a221d3b..08b7e555 100644
--- a/numalogic/tools/data.py
+++ b/numalogic/tools/data.py
@@ -125,6 +125,10 @@ def as_array(self) -> npt.NDArray[float]:
         """Returns the full data in a sequence of shape (batch, seq_len, num_features)."""
         return self[:]
 
+    def as_tensor(self) -> Tensor:
+        """Returns the full data as a tensor of shape (batch, seq_len, num_features)."""
+        return torch.from_numpy(self[:]).contiguous()
+
     def create_seq(self, input_: npt.NDArray[float]) -> Generator[npt.NDArray[float], None, None]:
         r"""Yields sequences of specified length from the input data.
diff --git a/numalogic/tools/exceptions.py b/numalogic/tools/exceptions.py
index 1dde7c08..82f2f196 100644
--- a/numalogic/tools/exceptions.py
+++ b/numalogic/tools/exceptions.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 
-class ModelInitializationError(Exception):
+class ModelInitializationError(RuntimeError):
     """Raised when a model is not initialized properly."""
 
     pass
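The new as_tensor() helper mirrors as_array() but yields a contiguous
torch.Tensor, which the tests below use for shape comparisons; a short sketch
with placeholder data:

    import numpy as np

    from numalogic.tools.data import StreamingDataset

    ds = StreamingDataset(np.random.randn(100, 2), seq_len=12)

    arr = ds.as_array()    # ndarray of shape (batch, seq_len, num_features)
    tens = ds.as_tensor()  # the same data as a contiguous torch.Tensor
    assert tuple(tens.shape) == arr.shape  # (89, 12, 2)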
diff --git a/tests/models/vae/__init__.py b/tests/models/vae/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/models/vae/test_conv.py b/tests/models/vae/test_conv.py
new file mode 100644
index 00000000..f78648dc
--- /dev/null
+++ b/tests/models/vae/test_conv.py
@@ -0,0 +1,122 @@
+import logging
+import os
+import unittest
+
+import pandas as pd
+import torch
+from sklearn.preprocessing import StandardScaler
+from torch import nn, Tensor
+from torch.utils.data import DataLoader
+
+from numalogic._constants import TESTS_DIR
+from numalogic.models.vae import VAETrainer
+from numalogic.models.vae.variants import Conv1dVAE
+from numalogic.tools.data import TimeseriesDataModule, StreamingDataset
+from numalogic.tools.exceptions import ModelInitializationError
+
+ROOT_DIR = os.path.join(TESTS_DIR, "resources", "data")
+DATA_FILE = os.path.join(ROOT_DIR, "interactionstatus.csv")
+EPOCHS = 2
+BATCH_SIZE = 32
+SEQ_LEN = 12
+LR = 0.001
+ACCELERATOR = "cuda" if torch.cuda.is_available() else "cpu"
+torch.manual_seed(42)
+
+
+logging.basicConfig(level=logging.INFO)
+
+
+class TestConv1dVAE(unittest.TestCase):
+    x_train = None
+    x_val = None
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        df = pd.read_csv(DATA_FILE)
+        df = df[["success", "failure"]]
+        scaler = StandardScaler()
+        cls.x_train = scaler.fit_transform(df[:-240])
+        cls.x_val = scaler.transform(df[-240:])
+
+    def test_model_01(self):
+        model = Conv1dVAE(seq_len=SEQ_LEN, n_features=2, latent_dim=1, loss_fn="l1")
+        datamodule = TimeseriesDataModule(SEQ_LEN, self.x_train, batch_size=BATCH_SIZE)
+        trainer = VAETrainer(accelerator=ACCELERATOR, max_epochs=EPOCHS, fast_dev_run=True)
+        trainer.fit(model, datamodule=datamodule)
+
+        streamloader = DataLoader(StreamingDataset(self.x_val, SEQ_LEN), batch_size=BATCH_SIZE)
+        stream_trainer = VAETrainer(accelerator=ACCELERATOR)
+        test_reconerr = stream_trainer.predict(model, dataloaders=streamloader)
+        test_reconerr_w_seq = stream_trainer.predict(model, dataloaders=streamloader, unbatch=False)
+
+        self.assertTupleEqual(self.x_val.shape, test_reconerr.shape)
+        self.assertTupleEqual(streamloader.dataset.as_tensor().shape, test_reconerr_w_seq.shape)
+
+    def test_model_02(self):
+        model = Conv1dVAE(seq_len=SEQ_LEN, n_features=2, latent_dim=1, conv_channels=(8, 4))
+        trainer = VAETrainer(accelerator=ACCELERATOR, max_epochs=EPOCHS, log_freq=1)
+        trainer.fit(
+            model,
+            train_dataloaders=DataLoader(
+                StreamingDataset(self.x_train, SEQ_LEN), batch_size=BATCH_SIZE
+            ),
+        )
+
+        test_ds = StreamingDataset(self.x_val, SEQ_LEN)
+
+        model.eval()
+        with torch.no_grad():
+            _, recon = model(test_ds.as_tensor())
+
+        self.assertTupleEqual(test_ds.as_tensor().size(), recon.shape)
+        self.assertEqual(recon.dim(), 3)
+
+    def test_native_train(self):
+        model = Conv1dVAE(
+            seq_len=SEQ_LEN,
+            n_features=2,
+            latent_dim=1,
+            loss_fn="huber",
+        )
+        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
+        criterion = nn.HuberLoss(delta=0.5)
+
+        train_loader = DataLoader(
+            StreamingDataset(self.x_train, seq_len=SEQ_LEN), batch_size=BATCH_SIZE
+        )
+
+        model.train()
+        loss = Tensor([0.0])
+        for epoch in range(1, EPOCHS + 1):
+            for _X_batch in train_loader:
+                optimizer.zero_grad()
+                _, decoded = model(_X_batch)
+                decoded = decoded.view(-1, SEQ_LEN, self.x_train.shape[1])
+
+                loss = criterion(decoded, _X_batch)
+                loss.backward()
+                optimizer.step()
+
+            if epoch % 5 == 0:
+                print(f"epoch : {epoch}, loss_mean : {loss.item():.7f}")
+
+    def test_err(self):
+        with self.assertRaises(ValueError):
+            Conv1dVAE(
+                seq_len=SEQ_LEN,
+                n_features=2,
+                latent_dim=1,
+                loss_fn="random",
+            )
+        with self.assertRaises(ModelInitializationError):
+            Conv1dVAE(
+                seq_len=SEQ_LEN,
+                n_features=2,
+                latent_dim=1,
+                conv_channels=(8, 4, 2, 1),
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()