Skip to content

Commit

Permalink
feat!: convert AE variants to lightning modules (#110)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Dec 20, 2022
1 parent ed94615 commit 88d26ec
Show file tree
Hide file tree
Showing 30 changed files with 3,118 additions and 1,759 deletions.
5 changes: 2 additions & 3 deletions numalogic/models/autoencoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from numalogic.models.autoencoder.factory import ModelPlFactory
from numalogic.models.autoencoder.pipeline import AutoencoderPipeline, SparseAEPipeline
from numalogic.models.autoencoder.trainer import AutoencoderTrainer

__all__ = ["AutoencoderPipeline", "SparseAEPipeline", "ModelPlFactory"]
__all__ = ["AutoencoderTrainer"]
68 changes: 51 additions & 17 deletions numalogic/models/autoencoder/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,60 @@
from abc import ABCMeta, abstractmethod
from typing import Tuple
from abc import ABCMeta

from torch import nn, Tensor
from torch.utils.data import Dataset
from torchinfo import summary
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import Tensor, optim


class TorchAE(nn.Module, metaclass=ABCMeta):
class BaseAE(pl.LightningModule, metaclass=ABCMeta):
"""
Abstract Base class for all Pytorch based autoencoder models for time-series data.
"""

def __repr__(self) -> str:
return str(summary(self))
def __init__(self, loss_fn: str = "huber", optim_algo: str = "adam", lr: float = 1e-3):
super().__init__()
self.lr = lr
self.optim_algo = optim_algo
self.criterion = self.init_criterion(loss_fn)

def summary(self, input_shape: Tuple[int, ...]) -> None:
print(summary(self, input_size=input_shape))
@staticmethod
def init_criterion(loss_fn: str):
if loss_fn == "huber":
return F.huber_loss
if loss_fn == "l1":
return F.l1_loss
if loss_fn == "mse":
return F.mse_loss
raise NotImplementedError(f"Unsupported loss function provided: {loss_fn}")

@abstractmethod
def construct_dataset(self, x: Tensor, seq_len: int = None) -> Dataset:
"""
Returns a dataset instance to be used for training.
Needs to be overridden.
"""
pass
def init_optimizer(self, optim_algo: str):
if optim_algo == "adam":
return optim.Adam(self.parameters(), lr=self.lr)
if optim_algo == "adagrad":
return optim.Adagrad(self.parameters(), lr=self.lr)
if optim_algo == "rmsprop":
return optim.RMSprop(self.parameters(), lr=self.lr)
raise NotImplementedError(f"Unsupported optimizer value provided: {optim_algo}")

def _get_reconstruction_loss(self, batch):
_, recon = self.forward(batch)
return self.criterion(batch, recon)

def reconstruction(self, batch: Tensor) -> Tensor:
_, recon = self.forward(batch)
return recon

def configure_optimizers(self):
optimizer = self.init_optimizer(self.optim_algo)
return {"optimizer": optimizer}

def training_step(self, batch, batch_idx):
loss = self._get_reconstruction_loss(batch)
return loss

def validation_step(self, batch, batch_idx):
loss = self._get_reconstruction_loss(batch)
return loss

def test_step(self, batch, batch_idx):
loss = self._get_reconstruction_loss(batch)
return loss
19 changes: 0 additions & 19 deletions numalogic/models/autoencoder/factory.py

This file was deleted.

Loading

0 comments on commit 88d26ec

Please sign in to comment.