change forward's output from tuple to dataclass
moskomule committed Feb 4, 2023
1 parent 6e9e8fd commit a6cdeac
Showing 3 changed files with 62 additions and 50 deletions.
38 changes: 19 additions & 19 deletions examples/dataset_distillation.py
@@ -10,7 +10,7 @@

from hypergrad.approx_hypergrad import conjugate_gradient, neumann, nystrom
from hypergrad.optimizers import diff_sgd
from hypergrad.solver import BaseImplicitSolver
from hypergrad.solver import BaseImplicitSolver, ForwardOutput
from hypergrad.utils import Params


@@ -55,16 +55,16 @@ def __init__(self, *args, **kwargs):
def target(self):
return torch.arange(10, device=self.device).repeat(self.outer.num_per_class)

def inner_obj(self,
in_params: Params,
out_params: Params,
input: torch.Tensor,
target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
def inner_forward(self,
in_params: Params,
out_params: Params,
input: torch.Tensor,
target: torch.Tensor
) -> ForwardOutput:
input, = out_params
output = self.inner_func(in_params, input)
loss = F.cross_entropy(output, target)
return loss, output
return ForwardOutput(loss, output)

def outer_update(self) -> None:
in_input, = tuple(self.outer.parameters())
@@ -73,29 +73,29 @@ def outer_update(self) -> None:
in_params = tuple(self.inner_params)
_, out_params = functorch.make_functional(self.outer, disable_autograd_tracking=True)

implicit_grads = self.approx_ihvp(lambda i, o: self.inner_obj(i, o, in_input, in_target)[0],
lambda i, o: self.outer_obj(i, o, out_input, out_target)[0],
implicit_grads = self.approx_ihvp(lambda i, o: self.inner_forward(i, o, in_input, in_target).loss,
lambda i, o: self.outer_forward(i, o, out_input, out_target).loss,
in_params, out_params)
self.outer_grad = implicit_grads
self.outer_optimizer.step()
self.outer.zero_grad(set_to_none=True)

loss, output = self.outer_obj(in_params, out_params, out_input, out_target)
loss, output = self.outer_forward(in_params, out_params, out_input, out_target)
self.recorder.add('outer_loss', loss.detach())
self.recorder.add('outer_acc', accuracy(output, out_target))

def outer_obj(self,
in_params: Params,
out_params: Params,
input: torch.Tensor,
target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
def outer_forward(self,
in_params: Params,
out_params: Params,
input: torch.Tensor,
target: torch.Tensor
) -> ForwardOutput:
output = self.inner_func(in_params, input)
return F.cross_entropy(output, target), output
return ForwardOutput(F.cross_entropy(output, target), output)

def inner_update(self) -> None:
target = self.target()
grads, (loss, output) = functorch.grad_and_value(self.inner_obj,
grads, (loss, output) = functorch.grad_and_value(self.inner_forward,
has_aux=True)(self.inner_params,
tuple(self.outer.parameters()), None,
target)
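
For context, `self.outer` here is the module whose single learnable parameter holds the distilled images (its parameters become `out_params`, which `inner_forward` reads via `input, = out_params`). A hypothetical sketch of such a module, not part of this diff, with an MNIST-like image shape assumed:

import torch
from torch import nn


class DistilledImages(nn.Module):
    # one synthetic image per class, repeated num_per_class times, so the parameter
    # lines up with target() = torch.arange(10).repeat(num_per_class)
    def __init__(self, num_per_class: int, image_shape: tuple = (1, 28, 28)):
        super().__init__()
        self.num_per_class = num_per_class
        self.images = nn.Parameter(torch.randn(10 * num_per_class, *image_shape))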
38 changes: 19 additions & 19 deletions examples/logistic_regression.py
@@ -8,7 +8,7 @@

from hypergrad.approx_hypergrad import conjugate_gradient, neumann, nystrom
from hypergrad.optimizers import diff_sgd
from hypergrad.solver import BaseImplicitSolver
from hypergrad.solver import BaseImplicitSolver, ForwardOutput
from hypergrad.utils import Params


@@ -26,47 +26,47 @@ def generate_data(num_data: int,
class Solver(BaseImplicitSolver):
def inner_update(self) -> None:
input, target = next(self.inner_loader)
grads, (loss, output) = functorch.grad_and_value(self.inner_obj,
grads, (loss, output) = functorch.grad_and_value(self.inner_forward,
has_aux=True)(self.inner_params,
tuple(self.outer.parameters()), input, target)
self.inner_params, self.inner_optim_state = self.inner_optimizer(list(self.inner_params), list(grads),
self.inner_optim_state)
self.recorder.add('inner_loss', loss.detach())

def inner_obj(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> tuple[Tensor, Tensor]:
def inner_forward(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> ForwardOutput:
output = self.inner_func(in_params, input)
loss = F.binary_cross_entropy_with_logits(output, target)
wd = sum([(_in.pow(2) * _out).sum() for _in, _out in zip(in_params, out_params)])
return loss + wd / 2, output
return ForwardOutput(loss + wd / 2, output)
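
The `wd` term above is a per-parameter weight decay whose coefficients are the outer (hyper)parameters. Writing the inner weights as w (`in_params`) and the hyperparameters as lambda (`out_params`), the inner objective computed here is

\mathcal{L}_{\text{inner}}(w, \lambda) = \mathrm{BCE}\big(f_w(x),\, y\big) + \tfrac{1}{2} \sum_j \lambda_j\, w_j^2 ,

a common test problem for implicit-gradient hyperparameter optimization.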

def outer_update(self) -> None:
in_input, in_target = next(self.inner_loader)
out_input, out_target = next(self.outer_loader)
in_params = tuple(self.inner_params)
_, out_params = functorch.make_functional(self.outer, disable_autograd_tracking=True)

self.recorder.add('outer_loss', self.outer_obj(in_params, out_params, out_input, out_target)[0].detach())
implicit_grads = self.approx_ihvp(lambda i, o: self.inner_obj(i, o, in_input, in_target)[0],
lambda i, o: self.outer_obj(i, o, out_input, out_target)[0],
self.recorder.add('outer_loss', self.outer_forward(in_params, out_params, out_input, out_target).loss.detach())
implicit_grads = self.approx_ihvp(lambda i, o: self.inner_forward(i, o, in_input, in_target).loss,
lambda i, o: self.outer_forward(i, o, out_input, out_target).loss,
in_params, out_params)

self.outer_grad = implicit_grads
self.outer_optimizer.step()
self.outer.zero_grad(set_to_none=True)

def outer_obj(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> tuple[Tensor, Tensor]:
def outer_forward(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> ForwardOutput:
output = self.inner_func(in_params, input)
return F.binary_cross_entropy_with_logits(output, target), output
return ForwardOutput(F.binary_cross_entropy_with_logits(output, target), output)

def post_outer_update(self) -> None:
self.inner_params[0].zero_()
36 changes: 24 additions & 12 deletions hypergrad/solver.py
@@ -1,6 +1,7 @@
# Wrapper for bi-level optimization solver
import abc
import collections
import dataclasses
import logging
import math
import statistics
@@ -43,6 +44,17 @@ def archive(self):
return self._archive.copy()


@dataclasses.dataclass
class ForwardOutput:
loss: Tensor
output: Tensor

def __iter__(self):
# serves the same purpose as `dataclasses.astuple`,
# but iterating over __dict__ values is about 50 times faster
return iter(self.__dict__.values())
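
A small usage sketch (assuming `ForwardOutput` is importable from `hypergrad.solver`, as the example scripts above do; the tensor values are made up):

import torch

from hypergrad.solver import ForwardOutput

result = ForwardOutput(loss=torch.tensor(0.25), output=torch.zeros(4))
print(result.loss)       # fields are accessible by name
loss, output = result    # __iter__ keeps tuple-style unpacking working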


class BaseSolver(abc.ABC):

def __init__(self,
@@ -151,12 +163,12 @@ def reset_inner_model(self):
self.inner_func, self.inner_params = functorch.make_functional(self.inner, not self._inner_requires_grad)

@abc.abstractmethod
def inner_obj(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> tuple[Tensor, Tensor]:
def inner_forward(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> ForwardOutput:
# returns ForwardOutput(loss, output)
...

@@ -165,12 +177,12 @@ def inner_update(self) -> None:
...

@abc.abstractmethod
def outer_obj(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> tuple[Tensor, Tensor]:
def outer_forward(self,
in_params: Params,
out_params: Params,
input: Tensor,
target: Tensor
) -> ForwardOutput:
# returns ForwardOutput(loss, output)
...
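
To make the new contract concrete, a hypothetical minimal subclass could implement the two hooks as below. This is only a sketch mirroring the cross-entropy pattern from examples/dataset_distillation.py; the remaining abstract methods, such as inner_update, are omitted:

import torch.nn.functional as F
from torch import Tensor

from hypergrad.solver import BaseImplicitSolver, ForwardOutput
from hypergrad.utils import Params


class MySolver(BaseImplicitSolver):
    def inner_forward(self, in_params: Params, out_params: Params,
                      input: Tensor, target: Tensor) -> ForwardOutput:
        # run the functional inner model and package loss and logits together
        output = self.inner_func(in_params, input)
        return ForwardOutput(F.cross_entropy(output, target), output)

    def outer_forward(self, in_params: Params, out_params: Params,
                      input: Tensor, target: Tensor) -> ForwardOutput:
        # outer (validation) objective evaluated with the current inner parameters
        output = self.inner_func(in_params, input)
        return ForwardOutput(F.cross_entropy(output, target), output)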

