Skip to content

Commit

Permalink
fix: vae nsamples
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Jan 17, 2024
1 parent 16c02f1 commit 03faf8f
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions numalogic/models/vae/variants/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ def forward(self, x: Tensor) -> tuple[MultivariateNormal, Tensor]:
x = self.configure_shape(x)
z_mu, z_logvar = self.encoder(x)
p = MultivariateNormal(loc=z_mu, covariance_matrix=torch.diag_embed(z_logvar.exp()))
samples = p.rsample(sample_shape=torch.Size([self.nsamples]))
z = torch.mean(samples, dim=0)
z = p.rsample()
x_recon = self.decoder(z)
return p, x_recon

Expand Down

0 comments on commit 03faf8f

Please sign in to comment.