
chore!: use torch and lightning 2.0
Signed-off-by: Avik Basu <[email protected]>
ab93 committed Apr 8, 2023
1 parent c053fb9 commit 85fb527
Showing 10 changed files with 351 additions and 333 deletions.
19 changes: 19 additions & 0 deletions numalogic/models/autoencoder/base.py
@@ -29,6 +29,23 @@ def __init__(self, loss_fn: str = "huber", optim_algo: str = "adam", lr: float =
         self.optim_algo = optim_algo
         self.criterion = self.init_criterion(loss_fn)
 
+        self._total_train_loss = 0.0
+        self._total_val_loss = 0.0
+
+    @property
+    def total_train_loss(self):
+        return self._total_train_loss
+
+    @property
+    def total_val_loss(self):
+        return self._total_val_loss
+
+    def reset_train_loss(self):
+        self._total_train_loss = 0.0
+
+    def reset_val_loss(self):
+        self._total_val_loss = 0.0
+
     @staticmethod
     def init_criterion(loss_fn: str):
         if loss_fn == "huber":
@@ -68,8 +85,10 @@ def configure_optimizers(self) -> Dict[str, Any]:

     def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
         loss = self._get_reconstruction_loss(batch)
+        self._total_train_loss += loss.detach().item()
         return loss
 
     def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
         loss = self._get_reconstruction_loss(batch)
+        self._total_val_loss += loss.detach().item()
         return loss
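
For context, this commit moves per-epoch loss bookkeeping off Lightning's logged metrics and onto the module itself. Below is a minimal sketch of the accumulator pattern in isolation; the LossTracker class and batch-loss values are hypothetical, not part of the commit:

import torch

class LossTracker:
    """Mimics the running-loss accumulators added to the autoencoder base class."""

    def __init__(self):
        self._total_train_loss = 0.0

    @property
    def total_train_loss(self):
        return self._total_train_loss

    def reset_train_loss(self):
        self._total_train_loss = 0.0

    def training_step(self, loss: torch.Tensor):
        # detach() keeps the accumulated float out of the autograd graph
        self._total_train_loss += loss.detach().item()

tracker = LossTracker()
for batch_loss in (torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.5)):
    tracker.training_step(batch_loss)
print(tracker.total_train_loss / 3)  # mean epoch loss, ~0.7
tracker.reset_train_loss()           # the callback resets this at epoch end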
1 change: 1 addition & 0 deletions numalogic/models/autoencoder/trainer.py
@@ -44,6 +44,7 @@ def __init__(
         warnings.simplefilter("ignore", category=UserWarning)
 
         super().__init__(
+            accelerator="cpu",
             logger=logger,
             max_epochs=max_epochs,
             check_val_every_n_epoch=check_val_every_n_epoch,
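The notable change here is pinning training to the CPU. A sketch of the equivalent Lightning 2.0 Trainer construction; the argument values are illustrative, not numalogic's defaults:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="cpu",          # pin training to CPU, as this commit does
    max_epochs=100,
    check_val_every_n_epoch=5,
    logger=False,
)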
15 changes: 8 additions & 7 deletions numalogic/tools/callbacks.py
@@ -13,13 +13,12 @@
 import logging
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import ProgressBarBase
-
+from pytorch_lightning.callbacks import ProgressBar
 
 _LOGGER = logging.getLogger(__name__)
 
 
-class ProgressDetails(ProgressBarBase):
+class ProgressDetails(ProgressBar):
     r"""
     A lightweight training progress detail producer.
@@ -40,12 +39,14 @@ def disable(self):

     def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         super().on_train_epoch_end(trainer, pl_module)
-        metrics = self.get_metrics(trainer, pl_module)
+        loss = pl_module.total_train_loss / trainer.num_training_batches
         curr_epoch = trainer.current_epoch
         if curr_epoch % self._log_freq == 0:
-            _LOGGER.info("epoch=%s, training_loss=%s", curr_epoch, metrics["loss"])
+            _LOGGER.info("epoch=%s, training_loss=%.5f", curr_epoch, loss)
+        pl_module.reset_train_loss()
 
     def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         super().on_validation_epoch_end(trainer, pl_module)
-        metrics = self.get_metrics(trainer, pl_module)
-        _LOGGER.info("validation_loss=%s", metrics["loss"])
+        loss = pl_module.total_val_loss / trainer.num_val_batches[0]
+        _LOGGER.info("validation_loss=%.5f", loss)
+        pl_module.reset_val_loss()
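Two details worth noting: Lightning 2.0 removed ProgressBarBase in favor of ProgressBar, and trainer.num_val_batches is a list with one entry per validation dataloader, hence the [0] when averaging. A hypothetical usage sketch follows; constructor arguments are omitted since they are not shown in this diff, and it assumes a LightningModule exposing the total_train_loss/reset_train_loss API added in base.py above:

import pytorch_lightning as pl
from numalogic.tools.callbacks import ProgressDetails

trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=50,
    callbacks=[ProgressDetails()],  # periodically logs the averaged epoch losses
)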
26 changes: 13 additions & 13 deletions numalogic/tools/data.py
@@ -101,23 +101,27 @@ class TimeseriesDataModule(pl.LightningDataModule):
     Args:
         seq_len: The length of the sequences to be generated from the input data.
-        train_data: A numpy array containing the training data in the shape of (batch, num_features)
-        val_data: A numpy array containing the validation data in the shape of (batch, num_features)
+        data: A numpy array containing the input data in the shape of (batch, num_features)
+        val_split_ratio: fraction of the data to be used for the validation set
         batch_size: The size of each batch of data. Defaults to 64.
     """
 
     def __init__(
         self,
         seq_len: int,
-        train_data: npt.NDArray[float],
-        val_data: npt.NDArray[float] = None,
+        data: npt.NDArray[float],
+        val_split_ratio: float = 0.1,
         batch_size: int = 64,
     ):
         super().__init__()
         self.batch_size = batch_size
         self.seq_len = seq_len
-        self.train_data = train_data
-        self.val_data = val_data
+        self.data = data
+
+        if not 0.0 < val_split_ratio < 1.0:
+            raise ValueError("val_split_ratio can only accept values between 0.0 and 1.0")
+
+        self.val_split_ratio = val_split_ratio
+
         self.train_dataset = None
         self.val_dataset = None
@@ -127,10 +131,9 @@ def setup(self, stage: str) -> None:
         Sets up the data module by initializing the train and validation datasets.
         """
         if stage == "fit":
-            self.train_dataset = StreamingDataset(self.train_data, self.seq_len)
-            if self.val_data is None:
-                return
-            self.val_dataset = StreamingDataset(self.val_data, self.seq_len)
+            val_size = np.floor(self.val_split_ratio * len(self.data)).astype(int)
+            self.train_dataset = StreamingDataset(self.data[:-val_size, :], self.seq_len)
+            self.val_dataset = StreamingDataset(self.data[-val_size:, :], self.seq_len)
 
     def train_dataloader(self) -> TRAIN_DATALOADERS:
         r"""
@@ -142,9 +145,6 @@ def val_dataloader(self) -> Optional[EVAL_DATALOADERS]:
r"""
Creates and returns a DataLoader for the validation dataset if validation data is provided.
"""
if self.val_data is None:
_LOGGER.warning("Validation data not provided in TimeseriesDataModule.")
return None
return DataLoader(self.val_dataset, batch_size=self.batch_size)

@staticmethod
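A usage sketch of the data module under the new API; the array shape and split ratio are illustrative. With 1000 rows and val_split_ratio=0.1, the last 100 rows form the validation split:

import numpy as np
from numalogic.tools.data import TimeseriesDataModule

x = np.random.randn(1000, 3)  # (batch, num_features)
dm = TimeseriesDataModule(seq_len=12, data=x, val_split_ratio=0.1, batch_size=64)
dm.setup(stage="fit")
train_dl = dm.train_dataloader()  # covers the first 900 rows
val_dl = dm.val_dataloader()      # covers the last 100 rows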
