
Commit

update
moskomule committed Feb 4, 2023
1 parent 95196d0 commit 9f6b83c
Showing 3 changed files with 18 additions and 18 deletions.
examples/dataset_distillation.py (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ def inner_update(self) -> None:
                                                )(self.inner_params,
                                                  tuple(self.outer.parameters()), None,
                                                  target)
-        self.inner_params, self.inner_optim_state = self.inner_optimizer(list(self.inner_params), list(grads),
+        self.inner_params, self.inner_optim_state = self.inner_optimizer(self.inner_params, grads,
                                                                           self.inner_optim_state)
         self.recorder.add('inner_loss', loss.detach())
         self.recorder.add('inner_acc', accuracy(output, target))
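
Note on the change above: functorch.grad_and_value already returns gradients in the same tuple structure as self.inner_params, and diff_sgd now normalizes its inputs itself (see hypergrad/optimizers.py below), so the list(...) copies were redundant. A minimal sketch of that structure guarantee, using a toy loss that is not the example's distilled-data model:

    import functorch
    import torch


    def loss_fn(params):
        # toy scalar loss over a tuple of tensors; not the example's model
        w, b = params
        return (w.sum() + b).pow(2)


    params = (torch.randn(3), torch.zeros(()))
    # grad_and_value returns (grads, loss); grads mirrors the tuple structure of params
    grads, loss = functorch.grad_and_value(loss_fn)(params)
    assert isinstance(grads, tuple) and len(grads) == len(params)
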
examples/logistic_regression.py (22 changes: 11 additions & 11 deletions)
@@ -8,7 +8,7 @@
 
 from hypergrad.approx_hypergrad import conjugate_gradient, neumann, nystrom
 from hypergrad.optimizers import diff_sgd
-from hypergrad.solver import BaseImplicitSolver, ForwardOutput
+from hypergrad.solver import BaseImplicitSolver
 from hypergrad.utils import Params


@@ -26,9 +26,9 @@ def generate_data(num_data: int,
 class Solver(BaseImplicitSolver):
     def inner_update(self) -> None:
         input, target = next(self.inner_loader)
-        grads, loss = functorch.grad_and_value(lambda *args: self.inner_forward(*args).loss
-                                               )(self.inner_params, tuple(self.outer.parameters()), input, target)
-        self.inner_params, self.inner_optim_state = self.inner_optimizer(list(self.inner_params), list(grads),
+        grads, loss = functorch.grad_and_value(self.inner_forward)(self.inner_params, tuple(self.outer.parameters()),
+                                                                    input, target)
+        self.inner_params, self.inner_optim_state = self.inner_optimizer(self.inner_params, grads,
                                                                           self.inner_optim_state)
         self.recorder.add('inner_loss', loss.detach())

@@ -37,21 +37,21 @@ def inner_forward(self,
                       out_params: Params,
                       input: Tensor,
                       target: Tensor
-                      ) -> ForwardOutput:
+                      ) -> Tensor:
         output = self.inner_func(in_params, input)
         loss = F.binary_cross_entropy_with_logits(output, target)
         wd = sum([(_in.pow(2) * _out).sum() for _in, _out in zip(in_params, out_params)])
-        return ForwardOutput(loss + wd / 2, output)
+        return loss + wd / 2
 
     def outer_update(self) -> None:
         in_input, in_target = next(self.inner_loader)
         out_input, out_target = next(self.outer_loader)
         in_params = tuple(self.inner_params)
         _, out_params = functorch.make_functional(self.outer, disable_autograd_tracking=True)
 
-        self.recorder.add('outer_loss', self.outer_forward(in_params, out_params, out_input, out_target).loss.detach())
-        implicit_grads = self.approx_ihvp(lambda i, o: self.inner_forward(i, o, in_input, in_target).loss,
-                                          lambda i, o: self.outer_forward(i, o, out_input, out_target).loss,
+        self.recorder.add('outer_loss', self.outer_forward(in_params, out_params, out_input, out_target).detach())
+        implicit_grads = self.approx_ihvp(lambda i, o: self.inner_forward(i, o, in_input, in_target),
+                                          lambda i, o: self.outer_forward(i, o, out_input, out_target),
                                           in_params, out_params)
 
         self.outer_grad = implicit_grads
@@ -63,9 +63,9 @@ def outer_forward(self,
                       out_params: Params,
                       input: Tensor,
                       target: Tensor
-                      ) -> ForwardOutput:
+                      ) -> Tensor:
         output = self.inner_func(in_params, input)
-        return ForwardOutput(F.binary_cross_entropy_with_logits(output, target), output)
+        return F.binary_cross_entropy_with_logits(output, target)
 
     def post_outer_update(self) -> None:
         self.inner_params[0].zero_()
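
Note on the ForwardOutput removal: with inner_forward and outer_forward returning a plain loss Tensor, both can be handed directly to functorch.grad_and_value and approx_ihvp, and the .loss-unwrapping lambdas disappear. A self-contained sketch of that pattern, assuming a bare linear model with a single weight tensor (not the example's actual Solver):

    import functorch
    import torch
    import torch.nn.functional as F


    def inner_forward(in_params, out_params, input, target):
        # linear logits; the per-weight L2 penalty is weighted by the outer params
        w, = in_params
        output = input @ w
        loss = F.binary_cross_entropy_with_logits(output, target)
        wd = sum((_in.pow(2) * _out).sum() for _in, _out in zip(in_params, out_params))
        return loss + wd / 2  # plain Tensor, no ForwardOutput wrapper


    in_params = (torch.randn(5),)
    out_params = (torch.full((5,), 1e-3),)
    input, target = torch.randn(8, 5), torch.randint(0, 2, (8,)).float()

    # gradient w.r.t. the inner parameters (argnums=0 is the default)
    grads, loss = functorch.grad_and_value(inner_forward)(in_params, out_params, input, target)
    # gradient w.r.t. the outer parameters
    hyper_grads = functorch.grad(inner_forward, argnums=1)(in_params, out_params, input, target)
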
hypergrad/optimizers.py (12 changes: 6 additions & 6 deletions)
@@ -1,8 +1,4 @@
-from typing import TypeAlias
-
-import torch
-
-Params: TypeAlias = list[torch.Tensor]
+from hypergrad.utils import Params
 
 
 def diff_sgd(params: Params,
@@ -15,8 +11,12 @@ def diff_sgd(params: Params,
              nesterov: bool = False,
              ) -> tuple[Params, Params]:
     # differentiable SGD
+    params = list(params)
+    grads = list(grads)
     if state is None:
         state = [None for _ in params]
+    else:
+        state = list(state)
     for i, param in enumerate(params):
         grad = grads[i]
         if weight_decay != 0:
@@ -35,4 +35,4 @@ def diff_sgd(params: Params,
             grad = state[i]
 
         params[i] = param.add(grad, alpha=-lr)
-    return params, state
+    return tuple(params), tuple(state)
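
Note on diff_sgd: after this change it copies params, grads, and state into lists internally and returns tuples, so callers can pass the tuples produced by functorch directly. A minimal usage sketch; passing state as the third positional argument and the lr keyword follow the call sites and hunks above, while the toy loss and the lr value are assumptions:

    import functorch
    import torch

    from hypergrad.optimizers import diff_sgd


    def loss_fn(params):
        w, b = params
        return (w.sum() + b - 1.0).pow(2)


    params = (torch.randn(4), torch.zeros(()))
    grads, loss = functorch.grad_and_value(loss_fn)(params)

    # first step: no optimizer state yet, so pass None; later steps reuse the returned state
    params, state = diff_sgd(params, grads, None, lr=0.1)
    assert isinstance(params, tuple) and isinstance(state, tuple)

Returning tuples matches what the solvers store in self.inner_params, so the update chains into the next functional call without further conversion.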
