RotoGrad hotfixes
adrianjav committed Feb 13, 2022
1 parent 32838ae commit 4365828
Showing 1 changed file with 10 additions and 13 deletions.
rotograd/rotograd.py (23 changes: 10 additions & 13 deletions)
@@ -1,4 +1,4 @@
-from typing import Sequence, Any, Optional
+from typing import Sequence, List, Any, Optional
 
 import torch
 import torch.nn as nn
@@ -136,14 +136,14 @@ def weight(self):
         return self.p.weight[self.item] if hasattr(self.p, 'weight') else 1.
 
     def rotate(self, z):
-        return rotate(z, self.R, self.p.input_size)
+        return rotate(z, self.R, self.p.latent_size)
 
     def rotate_back(self, z):
-        return rotate_back(z, self.R, self.p.input_size)
+        return rotate_back(z, self.R, self.p.latent_size)
 
     def forward(self, z):
         R = self.R.clone().detach()
-        new_z = rotate(z, R, self.p.input_size)
+        new_z = rotate(z, R, self.p.latent_size)
         if self.p.training:
             new_z.register_hook(self.hook)
 
@@ -189,7 +189,7 @@ class RotateOnly(nn.Module):
     heads: Sequence[nn.Module]
     rep: Optional[torch.Tensor]
 
-    def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], latent_size: int, *args,
+    def __init__(self, backbone: nn.Module, heads: List[nn.Module], latent_size: int, *args,
                  normalize_losses: bool = False):
         super(RotateOnly, self).__init__()
         num_tasks = len(heads)
@@ -337,9 +337,8 @@ def backward(self, losses: Sequence[torch.Tensor], backbone_loss=None, **kwargs)
         else:
             backbone_loss.backward(retain_graph=True)
 
-        self.rep.backward(self._rep_grad)
+        self.rep.backward(self._rep_grad())
 
-    @property
     def _rep_grad(self):
         old_grads = self.original_grads # these grads are already rotated, we have to recover the originals
         # with torch.no_grad():
@@ -355,7 +354,7 @@ def _rep_grad(self):
 
         for i, grad in enumerate(grads):
             R = self.rotation[i]
-            loss_rotograd = rotate(mean_grad, R, self.input_size) - grad
+            loss_rotograd = rotate(mean_grad, R, self.latent_size) - grad
             loss_rotograd = torch.einsum('bi,bi->b', loss_rotograd, loss_rotograd)
             loss_rotograd.mean().backward()
 
@@ -425,13 +424,12 @@ def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], latent_size:
         self.initial_grads = None
         self.counter = 0
 
-    @property
     def _rep_grad(self):
-        super()._rep_grad
+        super()._rep_grad()
 
         grad_norms = [torch.norm(g, keepdim=True).clamp_min(1e-15) for g in self.original_grads]
 
-        if self.initial_grads is None or self.counter == self.update_at:
+        if self.initial_grads is None or self.counter == self.burn_in_period:
             self.initial_grads = grad_norms
         self.counter += 1
 
@@ -502,9 +500,8 @@ def weight(self) -> Sequence[torch.Tensor]:
         norm_coef = self.num_tasks / sum(ws)
         return [w * norm_coef for w in ws]
 
-    @property
     def _rep_grad(self):
-        super()._rep_grad
+        super()._rep_grad()
 
         grads_norm = [g.norm(p=2) for g in self.original_grads]
 
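One pattern recurs across the hunks above: `_rep_grad` loses its `@property` decorator, so every bare attribute access (`self._rep_grad`, `super()._rep_grad`) has to become an explicit call, which is exactly what the paired edits do. A minimal standalone sketch of why the two changes must go together (the class names below are illustrative, not part of the library):

class Base:
    def _rep_grad(self):
        # Now a plain method: referencing it without parentheses only
        # returns the bound method object and executes nothing.
        print("base gradient bookkeeping")

class Child(Base):
    def _rep_grad(self):
        super()._rep_grad()  # explicit call required once @property is gone
        print("child gradient bookkeeping")

Child()._rep_grad()  # runs both steps; a bare `super()._rep_grad` would silently skip the base one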
