diff --git a/transformer/Modules.py b/transformer/Modules.py index c1d1d4c..fdd2a59 100644 --- a/transformer/Modules.py +++ b/transformer/Modules.py @@ -47,8 +47,8 @@ def forward(self, z): if z.size(1) == 1: return z - mu = torch.mean(z, dim=1) - sigma = torch.std(z, dim=1) + mu = torch.mean(z, dim=-1) + sigma = torch.std(z, dim=-1) ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)