Skip to content

Commit

Permalink
Added multichannel autoencoder with test cases (#382)
Browse files Browse the repository at this point in the history
Explain what this PR does.

---------

Signed-off-by: mboussarov <[email protected]>
  • Loading branch information
mboussarov committed May 31, 2024
1 parent 934d8cf commit 309f16a
Show file tree
Hide file tree
Showing 6 changed files with 1,171 additions and 3 deletions.
1,054 changes: 1,054 additions & 0 deletions examples/multichannel_ae.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class ModelFactory(_ObjectFactory):
from numalogic.models.autoencoder.variants import (
VanillaAE,
SparseVanillaAE,
MultichannelAE,
Conv1dAE,
SparseConv1dAE,
LSTMAE,
Expand All @@ -135,6 +136,7 @@ class ModelFactory(_ObjectFactory):
_CLS_MAP: ClassVar[dict] = {
"VanillaAE": VanillaAE,
"SparseVanillaAE": SparseVanillaAE,
"MultichannelAE": MultichannelAE,
"Conv1dAE": Conv1dAE,
"SparseConv1dAE": SparseConv1dAE,
"LSTMAE": LSTMAE,
Expand Down
3 changes: 2 additions & 1 deletion numalogic/models/autoencoder/variants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE, MultichannelAE
from numalogic.models.autoencoder.variants.conv import Conv1dAE, SparseConv1dAE
from numalogic.models.autoencoder.variants.lstm import LSTMAE, SparseLSTMAE
from numalogic.models.autoencoder.variants.transformer import TransformerAE, SparseTransformerAE
Expand All @@ -7,6 +7,7 @@

__all__ = [
"VanillaAE",
"MultichannelAE",
"SparseVanillaAE",
"Conv1dAE",
"SparseConv1dAE",
Expand Down
99 changes: 99 additions & 0 deletions numalogic/models/autoencoder/variants/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from numalogic.models.autoencoder.base import BaseAE
from numalogic.tools.exceptions import LayerSizeMismatchError

EMPTY_TENSOR = torch.empty(0)


class _VanillaEncoder(nn.Module):
r"""Encoder module for the VanillaAE.
Expand Down Expand Up @@ -207,6 +209,103 @@ def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0):
return self.criterion(batch, recon, reduction="none")


class MultichannelAE(BaseAE):
r"""Multichannel Vanilla Autoencoder model based on the vanilla encoder and decoder.
Each channel is an isolated neural network.
Args:
----
seq_len: sequence length / window length
n_channels: num of channels, each channel is a separate neural network
encoder_layersizes: encoder layer size (default = Sequence[int] = (16, 8))
decoder_layersizes: decoder layer size (default = Sequence[int] = (8, 16))
dropout_p: the dropout value (default=0.25)
batchnorm: Flag to enable batch normalization (default=False)
encoderinfo: Flag to enable returning encoder information in the "forward" step
(default=False)
**kwargs: BaseAE kwargs
"""

def __init__(
self,
seq_len: int,
n_channels: int,
encoder_layersizes: Sequence[int] = (16, 8),
decoder_layersizes: Sequence[int] = (8, 16),
dropout_p: float = 0.25,
batchnorm: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.seq_len = seq_len
self.dropout_prob = dropout_p
self.n_channels = n_channels
# The number of features per channel default to 1 in this architecture
self.n_features = 1

if encoder_layersizes[-1] != decoder_layersizes[0]:
raise LayerSizeMismatchError(
f"Last layersize of encoder: {encoder_layersizes[-1]} "
f"does not match first layersize of decoder: {decoder_layersizes[0]}"
)

for i in range(self.n_channels):
encoder = _VanillaEncoder(
seq_len=seq_len,
n_features=self.n_features,
layersizes=encoder_layersizes,
dropout_p=dropout_p,
batchnorm=batchnorm,
)
decoder = _Decoder(
seq_len=seq_len,
n_features=self.n_features,
layersizes=decoder_layersizes,
dropout_p=dropout_p,
batchnorm=batchnorm,
)

encoder.apply(self.init_weights)
decoder.apply(self.init_weights)
setattr(self, f"channel_encoder{i}", encoder)
setattr(self, f"channel_decoder{i}", decoder)

@staticmethod
def init_weights(m: nn.Module) -> None:
"""Initialize the parameters in the model."""
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)

def forward(self, batch: Tensor) -> tuple[Tensor, Tensor]:
encoded_all, decoded_all = [], []
batch = torch.swapdims(batch, 1, 2)

for i in range(self.n_channels):
encoder = getattr(self, f"channel_encoder{i}")
decoder = getattr(self, f"channel_decoder{i}")

batch_channel = batch[:, [i]]

encoded = encoder(batch_channel)
decoded = decoder(encoded)

encoded_all.append(encoded)
decoded_all.append(decoded)

encoded_all = torch.stack(encoded_all, dim=-1)
encoded_all = torch.squeeze(encoded_all, 1)

decoded_all = torch.stack(decoded_all, dim=-1)
decoded_all = torch.squeeze(decoded_all, 1)

return encoded_all, decoded_all

def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0):
"""Returns reconstruction for streaming input."""
recon = self.reconstruction(batch)
return self.criterion(batch, recon, reduction="none")


class _SparseVanillaEncoder(_VanillaEncoder):
def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
r"""Utility function to generate a simple feedforward network layer.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.10.0a1"
version = "0.10.1"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
14 changes: 13 additions & 1 deletion tests/models/autoencoder/variants/test_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numalogic._constants import TESTS_DIR
from numalogic.tools.data import StreamingDataset, TimeseriesDataModule
from numalogic.tools.trainer import TimeseriesTrainer
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE, MultichannelAE
from numalogic.tools.exceptions import LayerSizeMismatchError

ROOT_DIR = os.path.join(TESTS_DIR, "resources", "data")
Expand Down Expand Up @@ -58,6 +58,18 @@ def test_sparse_vanilla(self):
test_reconerr = stream_trainer.predict(model, dataloaders=streamloader, unbatch=False)
self.assertTupleEqual((229, SEQ_LEN, self.X_train.shape[1]), test_reconerr.size())

def test_multichannel(self):
model = MultichannelAE(seq_len=SEQ_LEN, n_channels=2)

datamodule = TimeseriesDataModule(SEQ_LEN, self.X_train, batch_size=BATCH_SIZE)
trainer = TimeseriesTrainer(fast_dev_run=True, enable_progress_bar=True)
trainer.fit(model, datamodule=datamodule)

streamloader = DataLoader(StreamingDataset(self.X_val, SEQ_LEN), batch_size=BATCH_SIZE)
stream_trainer = TimeseriesTrainer()
test_reconerr = stream_trainer.predict(model, dataloaders=streamloader, unbatch=False)
self.assertTupleEqual((229, SEQ_LEN, self.X_train.shape[1]), test_reconerr.size())

def test_native_train(self):
model = VanillaAE(
SEQ_LEN,
Expand Down

0 comments on commit 309f16a

Please sign in to comment.