feat: vanilla ae with independent channels (#392)
* VanillaICAE model
* VanillaAE supports flattening internally for multiple features
---------

Signed-off-by: Avik Basu <[email protected]>
Co-authored-by: mboussarov <[email protected]>
Co-authored-by: Avik Basu <[email protected]>
3 people committed Jun 13, 2024
1 parent 0e1a928 commit 8f6c8ef
Showing 10 changed files with 1,042 additions and 85 deletions.
691 changes: 691 additions & 0 deletions examples/vanilla_ic.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions numalogic/config/factory.py
@@ -132,6 +132,7 @@ class ModelFactory(_ObjectFactory):
        SparseLSTMAE,
        TransformerAE,
        SparseTransformerAE,
        VanillaICAE,
    )
    from numalogic.models.vae.variants import Conv1dVAE

@@ -146,6 +147,7 @@ class ModelFactory(_ObjectFactory):
        "TransformerAE": TransformerAE,
        "SparseTransformerAE": SparseTransformerAE,
        "Conv1dVAE": Conv1dVAE,
        "VanillaICAE": VanillaICAE,
    }


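With `VanillaICAE` registered above, the model can be built by name through the factory. A minimal sketch, assuming the base `_ObjectFactory` exposes a `get_cls`-style lookup (the lookup method itself is not shown in this diff); constructor arguments mirror the class defaults shown further below:

```python
# Illustrative only: resolve the newly registered model by name.
# `get_cls` is assumed here; it is not part of this diff.
from numalogic.config.factory import ModelFactory

model_cls = ModelFactory().get_cls("VanillaICAE")
model = model_cls(seq_len=12, n_channels=3)
```
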
8 changes: 7 additions & 1 deletion numalogic/models/autoencoder/variants/__init__.py
@@ -1,4 +1,9 @@
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE, MultichannelAE
from numalogic.models.autoencoder.variants.vanilla import (
    VanillaAE,
    SparseVanillaAE,
    MultichannelAE,
)
from numalogic.models.autoencoder.variants.icvanilla import VanillaICAE
from numalogic.models.autoencoder.variants.conv import Conv1dAE, SparseConv1dAE
from numalogic.models.autoencoder.variants.lstm import LSTMAE, SparseLSTMAE
from numalogic.models.autoencoder.variants.transformer import TransformerAE, SparseTransformerAE
@@ -16,4 +21,5 @@
    "TransformerAE",
    "SparseTransformerAE",
    "BaseAE",
    "VanillaICAE",
]
187 changes: 187 additions & 0 deletions numalogic/models/autoencoder/variants/icvanilla.py
@@ -0,0 +1,187 @@
from collections.abc import Sequence

import torch
from torch import nn, Tensor

from numalogic.models.autoencoder.base import BaseAE
from numalogic.tools.exceptions import LayerSizeMismatchError
from numalogic.tools.layer import IndependentChannelLinear


class _VanillaEncoder(nn.Module):
r"""Encoder module for the VanillaAE.
Args:
----
seq_len: sequence length / window length
n_features: num of features
layersizes: encoder layer size
dropout_p: the dropout value
"""

    def __init__(
        self,
        seq_len: int,
        n_features: int,
        layersizes: Sequence[int],
        dropout_p: float,
        batchnorm: bool,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.n_features = n_features
        self.dropout_p = dropout_p
        self.bnorm = batchnorm

        layers = self._construct_layers(layersizes)
        self.encoder = nn.Sequential(*layers)

    def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
        r"""Utility function to generate a simple feedforward network layer.
        Args:
        ----
            layersizes: layer size
        Returns
        -------
            A simple feedforward network layer of type nn.ModuleList
        """
        layers = nn.ModuleList()
        start_layersize = self.seq_len

        for lsize in layersizes[:-1]:
            _l = [IndependentChannelLinear(start_layersize, lsize, self.n_features)]
            if self.bnorm:
                _l.append(nn.BatchNorm1d(self.n_features))
            layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])
            start_layersize = lsize

        _l = [IndependentChannelLinear(start_layersize, layersizes[-1], self.n_features)]
        if self.bnorm:
            _l.append(nn.BatchNorm1d(self.n_features))
        layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])

        return layers

    def forward(self, x: Tensor) -> Tensor:
        return self.encoder(x)


class _Decoder(nn.Module):
r"""Decoder module for the autoencoder module.
Args:
----
seq_len: sequence length / window length
n_features: num of features
layersizes: decoder layer size
dropout_p: the dropout value
"""

    def __init__(
        self,
        seq_len: int,
        n_features: int,
        layersizes: Sequence[int],
        dropout_p: float,
        batchnorm: bool,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.n_features = n_features
        self.dropout_p = dropout_p
        self.bnorm = batchnorm

        layers = self._construct_layers(layersizes)
        self.decoder = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.decoder(x)

    def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
        r"""Utility function to generate a simple feedforward network layer.
        Args:
        ----
            layersizes: layer size
        Returns
        -------
            A simple feedforward network layer
        """
        layers = nn.ModuleList()

        for idx, _ in enumerate(layersizes[:-1]):
            _l = [IndependentChannelLinear(layersizes[idx], layersizes[idx + 1], self.n_features)]
            if self.bnorm:
                _l.append(nn.BatchNorm1d(self.n_features))
            layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])

        layers.append(IndependentChannelLinear(layersizes[-1], self.seq_len, self.n_features))
        return layers


class VanillaICAE(BaseAE):
r"""Vanilla Autoencoder model with Independent isolated Channels based
on the vanilla encoder and decoder. Each channel is an isolated neural network.
Args:
----
seq_len: sequence length / window length
n_channels: num of features/channel, each channel is a separate neural network
encoder_layersizes: encoder layer size (default = Sequence[int] = (16, 8))
decoder_layersizes: decoder layer size (default = Sequence[int] = (8, 16))
dropout_p: the dropout value (default=0.25)
batchnorm: Flag to enable batch normalization (default=False)
**kwargs: BaseAE kwargs
"""

    def __init__(
        self,
        seq_len: int,
        n_channels: int = 1,
        encoder_layersizes: Sequence[int] = (16, 8),
        decoder_layersizes: Sequence[int] = (8, 16),
        dropout_p: float = 0.25,
        batchnorm: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.dropout_prob = dropout_p
        self.n_channels = n_channels

        if encoder_layersizes[-1] != decoder_layersizes[0]:
            raise LayerSizeMismatchError(
                f"Last layersize of encoder: {encoder_layersizes[-1]} "
                f"does not match first layersize of decoder: {decoder_layersizes[0]}"
            )

        self.encoder = _VanillaEncoder(
            seq_len=seq_len,
            n_features=n_channels,
            layersizes=encoder_layersizes,
            dropout_p=dropout_p,
            batchnorm=batchnorm,
        )
        self.decoder = _Decoder(
            seq_len=seq_len,
            n_features=n_channels,
            layersizes=decoder_layersizes,
            dropout_p=dropout_p,
            batchnorm=batchnorm,
        )

    def forward(self, batch: Tensor) -> tuple[Tensor, Tensor]:
        batch = torch.swapdims(batch, 1, 2)
        encoded = self.encoder(batch)
        decoded = self.decoder(encoded)
        return encoded, torch.swapdims(decoded, 1, 2)

    def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0):
        """Returns reconstruction for streaming input."""
        recon = self.reconstruction(batch)
        return self.criterion(batch, recon, reduction="none")
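
A minimal usage sketch for the new variant (shapes follow the `forward` above; the layer sizes are the class defaults):

```python
import torch

from numalogic.models.autoencoder.variants import VanillaICAE

# Toy multivariate window: (batch, seq_len, n_channels)
model = VanillaICAE(seq_len=12, n_channels=3)
x = torch.randn(8, 12, 3)

encoded, decoded = model(x)
print(decoded.shape)  # torch.Size([8, 12, 3]) -- reconstruction matches the input layout
```

The channel isolation comes from `IndependentChannelLinear` (defined in `numalogic.tools.layer`, not shown in this diff). A hypothetical, simplified version of such a layer, just to illustrate the idea of one weight matrix per channel:

```python
import torch
from torch import nn, Tensor


class PerChannelLinear(nn.Module):
    """Hypothetical sketch: a separate linear map per channel, applied to
    inputs of shape (batch, n_channels, in_features); channels never mix."""

    def __init__(self, in_features: int, out_features: int, n_channels: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_channels, in_features, out_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(n_channels, out_features))

    def forward(self, x: Tensor) -> Tensor:
        # (B, C, I) x (C, I, O) -> (B, C, O)
        return torch.einsum("bci,cio->bco", x, self.weight) + self.bias
```
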
32 changes: 14 additions & 18 deletions numalogic/models/autoencoder/variants/vanilla.py
@@ -19,8 +19,6 @@
from numalogic.models.autoencoder.base import BaseAE
from numalogic.tools.exceptions import LayerSizeMismatchError

EMPTY_TENSOR = torch.empty(0)


class _VanillaEncoder(nn.Module):
    r"""Encoder module for the VanillaAE.
@@ -31,7 +29,7 @@ class _VanillaEncoder(nn.Module):
        n_features: num of features
        layersizes: encoder layer size
        dropout_p: the dropout value
        batchnorm: Flag to enable/disable batch normalization
    """

    def __init__(
@@ -43,13 +41,13 @@ def __init__(
        batchnorm: bool,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.input_size = seq_len * n_features
        self.n_features = n_features
        self.dropout_p = dropout_p
        self.bnorm = batchnorm

        layers = self._construct_layers(layersizes)
        self.encoder = nn.Sequential(*layers)
        self.encoder = nn.Sequential(nn.Flatten(), *layers)

    def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
        r"""Utility function to generate a simple feedforward network layer.
@@ -63,18 +61,18 @@ def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
            A simple feedforward network layer of type nn.ModuleList
        """
        layers = nn.ModuleList()
        start_layersize = self.seq_len
        start_layersize = self.input_size

        for lsize in layersizes[:-1]:
            _l = [nn.Linear(start_layersize, lsize)]
            if self.bnorm:
                _l.append(nn.BatchNorm1d(self.n_features))
                _l.append(nn.BatchNorm1d(lsize))
            layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])
            start_layersize = lsize

        _l = [nn.Linear(start_layersize, layersizes[-1])]
        if self.bnorm:
            _l.append(nn.BatchNorm1d(self.n_features))
            _l.append(nn.BatchNorm1d(layersizes[-1]))
        layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])

        return layers
@@ -92,7 +90,7 @@ class _Decoder(nn.Module):
        n_features: num of features
        layersizes: decoder layer size
        dropout_p: the dropout value
        batchnorm: flag to enable/disable batch normalization
    """

    def __init__(
@@ -104,13 +102,13 @@ def __init__(
        batchnorm: bool,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.out_size = seq_len * n_features
        self.n_features = n_features
        self.dropout_p = dropout_p
        self.bnorm = batchnorm

        layers = self._construct_layers(layersizes)
        self.decoder = nn.Sequential(*layers)
        self.decoder = nn.Sequential(*layers, nn.Unflatten(-1, (n_features, seq_len)))

    def forward(self, x: Tensor) -> Tensor:
        return self.decoder(x)
@@ -131,10 +129,10 @@ def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
        for idx, _ in enumerate(layersizes[:-1]):
            _l = [nn.Linear(layersizes[idx], layersizes[idx + 1])]
            if self.bnorm:
                _l.append(nn.BatchNorm1d(self.n_features))
                _l.append(nn.BatchNorm1d(layersizes[idx + 1]))
            layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])

        layers.append(nn.Linear(layersizes[-1], self.seq_len))
        layers.append(nn.Linear(layersizes[-1], self.out_size))
        return layers


@@ -221,8 +219,6 @@ class MultichannelAE(BaseAE):
        decoder_layersizes: decoder layer size (default = Sequence[int] = (8, 16))
        dropout_p: the dropout value (default=0.25)
        batchnorm: Flag to enable batch normalization (default=False)
        encoderinfo: Flag to enable returning encoder information in the "forward" step
            (default=False)
        **kwargs: BaseAE kwargs
    """

@@ -319,18 +315,18 @@ def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
            A simple feedforward network layer of type nn.ModuleList
        """
        layers = nn.ModuleList()
        start_layersize = self.seq_len
        start_layersize = self.input_size

        for lsize in layersizes[:-1]:
            _l = [nn.Linear(start_layersize, lsize)]
            if self.bnorm:
                _l.append(nn.BatchNorm1d(self.n_features))
                _l.append(nn.BatchNorm1d(lsize))
            layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])
            start_layersize = lsize

        _l = [nn.Linear(start_layersize, layersizes[-1])]
        if self.bnorm:
            _l.append(nn.BatchNorm1d(self.n_features))
            _l.append(nn.BatchNorm1d(layersizes[-1]))
        layers.extend([*_l, nn.ReLU()])

        return layers
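
The `nn.Flatten()` / `nn.Unflatten` pair above is what lets a single `VanillaAE` handle multi-feature windows internally (the second bullet of the commit message), and it explains why the `BatchNorm1d` sizes now track layer widths rather than `n_features`. A small shape sketch; the channel-first layout is illustrative, since `VanillaAE.forward` is not shown in this diff:

```python
import torch
from torch import nn

seq_len, n_features, batch = 12, 3, 8

flatten = nn.Flatten()                               # prepended in the encoder's nn.Sequential
x = torch.randn(batch, n_features, seq_len)          # illustrative layout
print(flatten(x).shape)                              # torch.Size([8, 36]) -> fed to nn.Linear(seq_len * n_features, ...)

unflatten = nn.Unflatten(-1, (n_features, seq_len))  # appended in the decoder
z = torch.randn(batch, n_features * seq_len)
print(unflatten(z).shape)                            # torch.Size([8, 3, 12])
```
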
(Diffs for the remaining changed files are not rendered.)
