Source code for homura.utils.grad_tools

from __future__ import annotations

from collections.abc import Iterable

import torch
from torch import nn
from torch.autograd import grad


def param_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    # materialize the iterable so it can be traversed twice (e.g., model.parameters())
    parameters = list(parameters)
    devices = [p.data.device for p in parameters]
    if len(devices) > 1 and any(devices[0] != d for d in devices[1:]):
        raise RuntimeError("Inconsistent devices in parameters!")
    return torch.cat([param.reshape(-1) for param in parameters])
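
# Usage sketch (illustration only, not part of the original module): the helper
# below, with a hypothetical name, flattens the parameters of a small, arbitrary
# model into a single 1-D vector.
def _example_param_to_vector() -> torch.Tensor:
    model = nn.Linear(3, 2)
    vec = param_to_vector(model.parameters())
    assert vec.shape == (8,)  # 2 * 3 weight entries followed by 2 bias entries
    return vec
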
def vjp(f: torch.Tensor,
        p: nn.Parameter | list[nn.Parameter],
        v: torch.Tensor,
        *,
        only_retain_graph: bool = False
        ) -> torch.Tensor:
    """ vector-Jacobian product """

    return param_to_vector(grad(f, p, v, create_graph=not only_retain_graph, retain_graph=True))  # dim p
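
# Usage sketch (illustration only): a vector-Jacobian product v^T J, where J is the
# Jacobian of the flattened model output w.r.t. the parameters. The helper name,
# model, and input sizes are hypothetical.
def _example_vjp() -> torch.Tensor:
    model = nn.Linear(3, 2)
    x = torch.randn(5, 3)
    out = model(x).view(-1)        # dim y = 10
    v = torch.randn(out.numel())   # vector in output space
    return vjp(out, list(model.parameters()), v)  # dim p = 8
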
def jvp(f: torch.Tensor,
        p: nn.Parameter | list[nn.Parameter],
        v: torch.Tensor
        ) -> torch.Tensor:
    """ Jacobian-vector product """

    dummy = torch.ones_like(f, requires_grad=True)
    g = vjp(f, p, dummy)
    # note that we don't need higher-order gradients w.r.t. dummy
    return vjp(g, dummy, v, only_retain_graph=True)  # dim y
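
# Usage sketch (illustration only): a Jacobian-vector product Jv, i.e. the
# directional derivative of the flattened output along a direction v in parameter
# space, computed via the double-backward (dummy-variable) trick above. The helper
# name, model, and sizes are hypothetical.
def _example_jvp() -> torch.Tensor:
    model = nn.Linear(3, 2)
    x = torch.randn(5, 3)
    out = model(x).view(-1)                           # dim y = 10
    params = list(model.parameters())
    v = torch.randn(sum(p.numel() for p in params))   # direction in parameter space, dim p = 8
    return jvp(out, params, v)                        # dim y = 10
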
def hvp(loss: torch.Tensor,
        f,
        p: nn.Parameter | list[nn.Parameter],
        v: torch.Tensor
        ) -> torch.Tensor:
    """ Hessian-vector product """

    # `f` is not used in the computation
    df_dp = param_to_vector(grad(loss, p, create_graph=True))
    return vjp(df_dp, p, v)  # dim p
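
# Usage sketch (illustration only): a Hessian-vector product of a scalar loss with
# respect to the parameters. The second argument of hvp is unused, so None is
# passed. The helper name, model, and data are hypothetical.
def _example_hvp() -> torch.Tensor:
    model = nn.Linear(3, 1)
    x, y = torch.randn(5, 3), torch.randn(5, 1)
    loss = (model(x) - y).pow(2).mean()
    params = list(model.parameters())
    v = torch.randn(sum(p.numel() for p in params))   # dim p = 4
    return hvp(loss, None, params, v)                 # dim p = 4
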
def ggnvp(loss: torch.Tensor,
          f: torch.Tensor,
          p: nn.Parameter | list[nn.Parameter],
          v: torch.Tensor
          ) -> torch.Tensor:
    """ Generalized Gauss-Newton vector product.

    In case of loss=F.cross_entropy(output, target), the GGN matrix is equivalent to the Fisher matrix.
    """

    jv = jvp(f, p, v)  # dim y
    hjv = hvp(loss, None, f, jv)  # dim y
    jhjv = vjp(f.view(-1), p, hjv)  # dim p
    return jhjv
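
# Usage sketch (illustration only): a GGN-vector product for a softmax classifier,
# the case mentioned in the docstring where the GGN coincides with the Fisher
# matrix. The helper name, model, and data are hypothetical.
def _example_ggnvp() -> torch.Tensor:
    import torch.nn.functional as F

    model = nn.Linear(3, 4)
    x = torch.randn(5, 3)
    target = torch.randint(0, 4, (5,))
    output = model(x)
    loss = F.cross_entropy(output, target)
    params = list(model.parameters())
    v = torch.randn(sum(p.numel() for p in params))   # dim p = 16
    return ggnvp(loss, output, params, v)             # dim p = 16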