feat!: convert AE variants to lightning modules (#110)
Signed-off-by: Avik Basu <[email protected]>
Showing 30 changed files with 3,118 additions and 1,759 deletions.
numalogic/models/autoencoder/__init__.py
@@ -1,4 +1,3 @@
-from numalogic.models.autoencoder.factory import ModelPlFactory
-from numalogic.models.autoencoder.pipeline import AutoencoderPipeline, SparseAEPipeline
+from numalogic.models.autoencoder.trainer import AutoencoderTrainer
 
-__all__ = ["AutoencoderPipeline", "SparseAEPipeline", "ModelPlFactory"]
+__all__ = ["AutoencoderTrainer"]
numalogic/models/autoencoder/base.py
@@ -1,26 +1,60 @@
-from abc import ABCMeta, abstractmethod
-from typing import Tuple
+from abc import ABCMeta
 
-from torch import nn, Tensor
-from torch.utils.data import Dataset
-from torchinfo import summary
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from torch import Tensor, optim
 
 
-class TorchAE(nn.Module, metaclass=ABCMeta):
+class BaseAE(pl.LightningModule, metaclass=ABCMeta):
     """
     Abstract Base class for all Pytorch based autoencoder models for time-series data.
     """
 
-    def __repr__(self) -> str:
-        return str(summary(self))
+    def __init__(self, loss_fn: str = "huber", optim_algo: str = "adam", lr: float = 1e-3):
+        super().__init__()
+        self.lr = lr
+        self.optim_algo = optim_algo
+        self.criterion = self.init_criterion(loss_fn)
 
-    def summary(self, input_shape: Tuple[int, ...]) -> None:
-        print(summary(self, input_size=input_shape))
+    @staticmethod
+    def init_criterion(loss_fn: str):
+        if loss_fn == "huber":
+            return F.huber_loss
+        if loss_fn == "l1":
+            return F.l1_loss
+        if loss_fn == "mse":
+            return F.mse_loss
+        raise NotImplementedError(f"Unsupported loss function provided: {loss_fn}")
 
-    @abstractmethod
-    def construct_dataset(self, x: Tensor, seq_len: int = None) -> Dataset:
-        """
-        Returns a dataset instance to be used for training.
-        Needs to be overridden.
-        """
-        pass
+    def init_optimizer(self, optim_algo: str):
+        if optim_algo == "adam":
+            return optim.Adam(self.parameters(), lr=self.lr)
+        if optim_algo == "adagrad":
+            return optim.Adagrad(self.parameters(), lr=self.lr)
+        if optim_algo == "rmsprop":
+            return optim.RMSprop(self.parameters(), lr=self.lr)
+        raise NotImplementedError(f"Unsupported optimizer value provided: {optim_algo}")
+
+    def _get_reconstruction_loss(self, batch):
+        _, recon = self.forward(batch)
+        return self.criterion(batch, recon)
+
+    def reconstruction(self, batch: Tensor) -> Tensor:
+        _, recon = self.forward(batch)
+        return recon
+
+    def configure_optimizers(self):
+        optimizer = self.init_optimizer(self.optim_algo)
+        return {"optimizer": optimizer}
+
+    def training_step(self, batch, batch_idx):
+        loss = self._get_reconstruction_loss(batch)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self._get_reconstruction_loss(batch)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        loss = self._get_reconstruction_loss(batch)
+        return loss
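For context, here is a minimal sketch of how a concrete subclass of the new BaseAE might be defined and trained with a stock pytorch_lightning Trainer. The VanillaAE class, its layer sizes, and the dummy data are illustrative assumptions, as is the exact base module path; the one real constraint, visible in _get_reconstruction_loss above, is that forward must return a (latent, reconstruction) pair.

from typing import Tuple

import pytorch_lightning as pl
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from numalogic.models.autoencoder.base import BaseAE  # module path assumed


class VanillaAE(BaseAE):
    """Toy dense autoencoder; illustrative only, not part of this commit."""

    def __init__(self, n_features: int = 8, latent_dim: int = 2, **kwargs):
        # loss_fn / optim_algo / lr are consumed by BaseAE.__init__
        super().__init__(**kwargs)
        self.encoder = nn.Sequential(nn.Linear(n_features, latent_dim), nn.Tanh())
        self.decoder = nn.Linear(latent_dim, n_features)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        # BaseAE._get_reconstruction_loss unpacks (latent, reconstruction)
        z = self.encoder(x)
        return z, self.decoder(z)


# A plain tensor acts as a map-style dataset, so each batch is a raw Tensor,
# matching BaseAE.training_step, which passes the batch straight into forward().
loader = DataLoader(torch.randn(256, 8), batch_size=32)

model = VanillaAE(n_features=8, latent_dim=2, loss_fn="mse", lr=1e-3)
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, loader)

Note the defaults carried by BaseAE itself: Huber loss and the Adam optimizer, per the __init__ signature in the diff.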
(The remaining 28 changed files in this commit, including at least one deleted file, are not shown in this capture.)