from functools import partial
import torch
from torch.optim import Optimizer


def Adam(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, multi_tensor: bool = False):
    """``torch.optim.Adam`` factory; returns a partial to be called with the model parameters."""
    locs = locals()
    locs.pop("multi_tensor")
    # `torch.optim._multi_tensor` is PyTorch's private multi-tensor ("foreach") implementation
    opt = torch.optim._multi_tensor.Adam if multi_tensor else torch.optim.Adam
    return partial(opt, **locs)


def SGD(lr=1e-1, momentum=0, dampening=0, weight_decay=0, nesterov=False, multi_tensor: bool = False):
    """``torch.optim.SGD`` factory; returns a partial to be called with the model parameters."""
    locs = locals()
    locs.pop("multi_tensor")
    opt = torch.optim._multi_tensor.SGD if multi_tensor else torch.optim.SGD
    return partial(opt, **locs)


def RMSprop(lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, multi_tensor: bool = False):
    """``torch.optim.RMSprop`` factory; returns a partial to be called with the model parameters."""
    locs = locals()
    locs.pop("multi_tensor")
    opt = torch.optim._multi_tensor.RMSprop if multi_tensor else torch.optim.RMSprop
    return partial(opt, **locs)


def AdamW(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, multi_tensor: bool = False):
    """``torch.optim.AdamW`` factory; returns a partial to be called with the model parameters."""
    locs = locals()
    locs.pop("multi_tensor")
    opt = torch.optim._multi_tensor.AdamW if multi_tensor else torch.optim.AdamW
    return partial(opt, **locs)
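
# Usage sketch (illustrative; `net` is a hypothetical toy model): each factory above
# returns a `functools.partial`, which is finalized by passing the model parameters
# to it, e.g.
#
#   net = torch.nn.Linear(10, 2)
#   optimizer = Adam(lr=3e-4, weight_decay=1e-2)(net.parameters())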


class LARC(object):
    """Layer-wise Adaptive Rate Scaling (LARC), based on NVIDIA Apex's implementation.

    LARC is designed to wrap a given optimizer; the optimizer should be wrapped
    only after the LR scheduler has been initialized.

    Args:
        optimizer: the optimizer to wrap.
        trust_coefficient: trust coefficient used to compute the layer-wise adaptive lr.
        no_clip: if ``True``, use the raw adaptive lr instead of clipping it at the group lr.
        eps: small constant added to the denominator for numerical stability.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 trust_coefficient: float = 0.02,
                 no_clip: bool = False,
                 eps: float = 1e-8):
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.clip = not no_clip
        self.eps = eps

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    @property
    def state(self):
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    @property
    def param_groups(self):
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    @torch.no_grad()
    def step(self):
        weight_decays = []
        for group in self.optim.param_groups:
            # absorb weight decay control from the wrapped optimizer
            weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
            weight_decays.append(weight_decay)
            group['weight_decay'] = 0

            params = []
            grads = []
            lrs = []
            for p in group['params']:
                if p.grad is None:
                    continue
                param_norm = torch.norm(p.data)
                grad_norm = torch.norm(p.grad.data)

                if param_norm != 0 and grad_norm != 0:
                    # layer-wise adaptive lr:
                    #   trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)
                    # .item() may be sub-optimal, but is required because the _foreach_*
                    # ops do not support broadcasting at the moment
                    adaptive_lr = (self.trust_coefficient * param_norm /
                                   (grad_norm + param_norm * weight_decay + self.eps)).item()

                    # clip the learning rate for LARC: rescale adaptive_lr so that,
                    # once multiplied by the group lr, it equals `min(adaptive_lr, lr)`
                    if self.clip:
                        adaptive_lr = min(adaptive_lr / group['lr'], 1.0)

                    params.append(p.data)
                    grads.append(p.grad.data)
                    lrs.append(adaptive_lr)

            # fused equivalent of the per-parameter update:
            #   p.grad.data += weight_decay * p.data
            #   p.grad.data *= adaptive_lr
            # skip groups in which no parameter received a gradient
            if len(grads) > 0:
                torch._foreach_add_(grads, params, alpha=weight_decay)
                torch._foreach_mul_(grads, lrs)

        self.optim.step()

        # return weight decay control to the wrapped optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
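

if __name__ == "__main__":
    # Smoke-test sketch (illustrative; `net` and the tensor shapes below are arbitrary).
    # LARC wraps an already-constructed optimizer; per the docstring above, create any
    # LR scheduler on the inner optimizer before wrapping it.
    net = torch.nn.Linear(10, 2)
    base_optimizer = SGD(lr=0.1, momentum=0.9)(net.parameters())
    optimizer = LARC(base_optimizer, trust_coefficient=0.02)

    loss = net(torch.randn(4, 10)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # scales gradients layer-wise, then runs the wrapped SGD step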