diff --git a/rotograd/rotograd.py b/rotograd/rotograd.py
index 80b82c9..331062f 100644
--- a/rotograd/rotograd.py
+++ b/rotograd/rotograd.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Any, Optional
+from typing import Sequence, List, Any, Optional
 
 import torch
 import torch.nn as nn
@@ -136,14 +136,14 @@ def weight(self):
         return self.p.weight[self.item] if hasattr(self.p, 'weight') else 1.
 
     def rotate(self, z):
-        return rotate(z, self.R, self.p.input_size)
+        return rotate(z, self.R, self.p.latent_size)
 
     def rotate_back(self, z):
-        return rotate_back(z, self.R, self.p.input_size)
+        return rotate_back(z, self.R, self.p.latent_size)
 
     def forward(self, z):
         R = self.R.clone().detach()
-        new_z = rotate(z, R, self.p.input_size)
+        new_z = rotate(z, R, self.p.latent_size)
         if self.p.training:
             new_z.register_hook(self.hook)
         return new_z
@@ -189,7 +189,7 @@ class RotateOnly(nn.Module):
     heads: Sequence[nn.Module]
     rep: Optional[torch.Tensor]
 
-    def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], latent_size: int, *args,
+    def __init__(self, backbone: nn.Module, heads: List[nn.Module], latent_size: int, *args,
                  normalize_losses: bool = False):
         super(RotateOnly, self).__init__()
         num_tasks = len(heads)
@@ -337,9 +337,8 @@ def backward(self, losses: Sequence[torch.Tensor], backbone_loss=None, **kwargs)
             else:
                 backbone_loss.backward(retain_graph=True)
 
-        self.rep.backward(self._rep_grad)
+        self.rep.backward(self._rep_grad())
 
-    @property
     def _rep_grad(self):
         old_grads = self.original_grads  # these grads are already rotated, we have to recover the originals
         # with torch.no_grad():
@@ -355,7 +354,7 @@ def _rep_grad(self):
         for i, grad in enumerate(grads):
             R = self.rotation[i]
 
-            loss_rotograd = rotate(mean_grad, R, self.input_size) - grad
+            loss_rotograd = rotate(mean_grad, R, self.latent_size) - grad
             loss_rotograd = torch.einsum('bi,bi->b', loss_rotograd, loss_rotograd)
             loss_rotograd.mean().backward()
 
@@ -425,13 +424,12 @@ def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], latent_size:
         self.initial_grads = None
         self.counter = 0
 
-    @property
     def _rep_grad(self):
-        super()._rep_grad
+        super()._rep_grad()
 
         grad_norms = [torch.norm(g, keepdim=True).clamp_min(1e-15) for g in self.original_grads]
 
-        if self.initial_grads is None or self.counter == self.update_at:
+        if self.initial_grads is None or self.counter == self.burn_in_period:
             self.initial_grads = grad_norms
         self.counter += 1
 
@@ -502,9 +500,8 @@ def weight(self) -> Sequence[torch.Tensor]:
         norm_coef = self.num_tasks / sum(ws)
         return [w * norm_coef for w in ws]
 
-    @property
     def _rep_grad(self):
-        super()._rep_grad
+        super()._rep_grad()
 
         grads_norm = [g.norm(p=2) for g in self.original_grads]