Skip to content

Commit

Permalink
feat: add beta parameter for disentanglement
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Jan 16, 2024
1 parent 5ec5883 commit 16c02f1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
34 changes: 34 additions & 0 deletions numalogic/models/vae/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,37 @@ def __init__(

def forward(self, input_: Tensor) -> Tensor:
return self.relu(self.bnorm(self.conv(input_)))


class ConvTransposeBlock(nn.Module):
"""Basic transpose convolutional block consisting of:
- transpose convolutional layer
- batch norm
- relu activation.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 1,
dilation: int = 1,
output_padding: int = 0,
):
super().__init__()
self.convtranspose = nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=output_padding,
)
self.bnorm = nn.BatchNorm1d(out_channels)
self.relu = nn.ReLU(inplace=True)

def forward(self, input_: Tensor) -> Tensor:
return self.relu(self.bnorm(self.convtranspose(input_)))
12 changes: 4 additions & 8 deletions numalogic/models/vae/variants/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ def __init__(
n_features: int,
latent_dim: int,
conv_channels: Sequence[int] = (16,),
num_samples: int = 10,
):
super().__init__()

self.seq_len = seq_len
self.nsamples = num_samples

conv_layer = CausalConvBlock(
in_channels=n_features,
Expand Down Expand Up @@ -101,7 +99,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:

class Decoder(nn.Module):
"""
Decoder (non-probabilistic) module for Convolutional Variational Autoencoder.
Decoder module for Convolutional Variational Autoencoder.
Args:
----
Expand Down Expand Up @@ -155,7 +153,6 @@ class Conv1dVAE(BaseVAE):
n_features: num of features
conv_channels: number of convolutional channels
latent_dim: latent dimension
num_samples: number of samples to draw from the latent distribution
Raises
------
Expand All @@ -170,21 +167,20 @@ def __init__(
n_features: int,
latent_dim: int,
conv_channels: Sequence[int] = (16,),
num_samples: int = 10,
beta: float = 1.0,
**kwargs,
):
super().__init__(**kwargs)
self.seq_len = seq_len
self.z_dim = latent_dim
self.n_features = n_features
self.nsamples = num_samples
self.beta = beta

self.encoder = Encoder(
seq_len=seq_len,
n_features=n_features,
conv_channels=conv_channels,
latent_dim=latent_dim,
num_samples=num_samples,
)
self.decoder = Decoder(
seq_len=seq_len,
Expand Down Expand Up @@ -245,4 +241,4 @@ def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
on_epoch=True,
on_step=False,
)
return kld_loss + recon_loss
return recon_loss + (self.beta * kld_loss)

0 comments on commit 16c02f1

Please sign in to comment.