diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 7205d3a08a0c..6a5c416373f0 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -34,7 +34,7 @@ def __call__(self, x): else: batch_mean, batch_var = self.running_mean, self.running_var # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this - if not hasattr(self, "batch_invstd"): + if not hasattr(self, "batch_invstd") or not self.batch_invstd: self.batch_invstd = batch_var.add(self.eps)**-0.5 batch_invstd = self.batch_invstd