Merge pull request #43 from kozistr/refactor/api
[Fix] Ranger optimizer
kozistr committed Jan 28, 2022
2 parents 3c0ba47 + 9c2650e commit 75463dc
Showing 4 changed files with 22 additions and 22 deletions.
16 changes: 8 additions & 8 deletions pytorch_optimizer/ranger.py
@@ -121,15 +121,15 @@ def step(self, _: CLOSURE = None) -> LOSS:
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

- bias_correction1 = 1 - beta1 ** state['step']

if self.use_gc and grad.dim() > self.gc_gradient_threshold:
grad = centralize_gradient(grad, gc_conv_only=False)

state['step'] += 1

- exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
- exp_avg.mul_(beta1).add_(1 - beta1, grad)
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

+ bias_correction1 = 1 - beta1 ** state['step']

buffered = self.buffer[int(state['step'] % 10)]
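Note: the hunk above swaps the deprecated positional-scalar forms of add_ / addcmul_ for their keyword-argument forms and moves the bias_correction1 computation after the step increment, so it is based on the updated counter. A minimal standalone sketch of the updated moment accumulation, with illustrative values that are not part of the diff:

import torch

grad = torch.randn(4)
exp_avg = torch.zeros(4)     # first moment: EMA of gradients
exp_avg_sq = torch.zeros(4)  # second moment: EMA of squared gradients
beta1, beta2 = 0.9, 0.999
step = 1                     # already incremented, matching the new ordering

# keyword-argument forms used by the new code
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

bias_correction1 = 1 - beta1 ** step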

@@ -162,19 +162,19 @@ def step(self, _: CLOSURE = None) -> LOSS:
buffered[2] = step_size

if group['weight_decay'] != 0:
- p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+ p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

if n_sma > self.n_sma_threshold:
denom = exp_avg_sq.sqrt().add_(group['eps'])
- p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+ p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
else:
- p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+ p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

p.data.copy_(p_data_fp32)

if state['step'] % group['k'] == 0:
slow_p = state['slow_buffer']
- slow_p.add_(self.alpha, p.data - slow_p)
+ slow_p.add_(p.data - slow_p, alpha=self.alpha)
p.data.copy_(slow_p)

return loss
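Note: the same keyword-argument migration is applied to the parameter update itself: the weight-decay term, the RAdam-style adaptive step, and the Lookahead slow-weight interpolation. A rough standalone sketch with made-up stand-ins for the optimizer state (not the optimizer's actual code path):

import torch

p_data_fp32 = torch.randn(4)                         # fp32 copy of the parameter
exp_avg, exp_avg_sq = torch.randn(4), torch.rand(4)
slow_p = p_data_fp32.clone()                         # lookahead slow buffer
lr, weight_decay, eps = 1e-3, 1e-3, 1e-8
step_size, alpha = 1.0, 0.5
n_sma, n_sma_threshold = 6.0, 5

if weight_decay != 0:
    # p <- p - lr * weight_decay * p, written with the keyword alpha argument
    p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)

if n_sma > n_sma_threshold:
    denom = exp_avg_sq.sqrt().add_(eps)
    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * lr)
else:
    p_data_fp32.add_(exp_avg, alpha=-step_size * lr)

# every k steps, pull the slow weights toward the fast weights
slow_p.add_(p_data_fp32 - slow_p, alpha=alpha)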
18 changes: 9 additions & 9 deletions pytorch_optimizer/ranger21.py
@@ -59,20 +59,20 @@ def __init__(
norm_loss_factor: float = 1e-4,
eps: float = 1e-8,
):
"""
"""Ranger21
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
- :param lr: float. learning rate.
- :param beta0: float. Manages the amplitude of the noise introduced by positive negative momentum.
- While 0.9 is a recommended default value, you can use -0.5 to minimize the noise.
+ :param lr: float. learning rate
+ :param beta0: float. Manages the amplitude of the noise introduced by positive negative momentum
+ While 0.9 is a recommended default value, you can use -0.5 to minimize the noise
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param use_softplus: bool. use softplus to smooth
:param beta_softplus: float. beta
- :param agc_clipping_value: float.
- :param agc_eps: float.
+ :param agc_clipping_value: float
+ :param agc_eps: float
:param centralize_gradients: bool. use GC both convolution & fc layers
:param normalize_gradients: bool. use gradient normalization
- :param lookahead_merge_time: int.
- :param lookahead_blending_alpha: float.
+ :param lookahead_merge_time: int. merge time
+ :param lookahead_blending_alpha: float. blending alpha
:param weight_decay: float. weight decay (L2 penalty)
:param norm_loss_factor: float. norm loss factor
:param eps: float. term added to the denominator to improve numerical stability
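Note: the docstring cleanup above spells out the Ranger21 constructor arguments. A hedged usage sketch, assuming the package-level export exercised in the tests below and borrowing the num_iterations value from that test configuration:

import torch
from pytorch_optimizer import Ranger21

model = torch.nn.Linear(4, 1)
# num_iterations should cover the planned number of optimizer steps;
# the remaining arguments keep the defaults documented above.
optimizer = Ranger21(model.parameters(), num_iterations=1000, lr=1e-3, weight_decay=1e-3)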
@@ -237,8 +237,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

# Phase 2 - Apply weight decay and step
for group in self.param_groups:
- step = group['step']
lr = group['lr']
+ step = self.state[group['params'][0]]['step']

# warm up
lr = self.warm_up_dampening(lr, step)
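Note: the Phase 2 fix reads the step counter from the per-parameter optimizer state (presumably maintained earlier in step()) rather than from the param-group dict. A schematic of where the two values live, assuming an already-constructed Ranger21 instance named optimizer:

# hyper-parameters such as 'lr' live in optimizer.param_groups;
# per-parameter counters such as 'step' live in optimizer.state, keyed by the tensor
for group in optimizer.param_groups:
    lr = group['lr']
    step = optimizer.state[group['params'][0]]['step']
    lr = optimizer.warm_up_dampening(lr, step)  # warm-up scaling, as in the diff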
2 changes: 1 addition & 1 deletion pytorch_optimizer/version.py
@@ -1 +1 @@
- __VERSION__ = '0.3.1'
+ __VERSION__ = '0.3.2'
8 changes: 4 additions & 4 deletions tests/test_optimizers.py
@@ -11,16 +11,15 @@
SGDP,
AdaBelief,
AdaBound,
AdaHessian,
AdamP,
DiffGrad,
DiffRGrad,
Lamb,
Lookahead,
RAdam,
Ranger,
Ranger21,
)
- from pytorch_optimizer.types import BETAS

__REFERENCE__ = 'https://github.com/jettify/pytorch-optimizer/blob/master/tests/test_optimizer_with_nn.py'

@@ -67,7 +66,7 @@ def build_lookahead(*parameters, **kwargs):
return Lookahead(AdamP(*parameters, **kwargs))


- OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int, BETAS]], int]] = [
+ OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(build_lookahead, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
(AdaBelief, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
(AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
@@ -78,7 +77,8 @@ def build_lookahead(*parameters, **kwargs):
(MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
(RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
(SGDP, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
- # (Ranger, {'lr': 1e-3, 'weight_decay': 1e-3, 'betas': (0.99, 0.999)}, 200),
+ (Ranger, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+ (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 1000}, 500),
# (AdaHessian, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
]
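Note: the re-enabled Ranger entry and the new Ranger21 entry extend the parametrized optimizer matrix above. The test body itself is not part of this diff; a hypothetical consumer of OPTIMIZERS could look like the following (the function name and toy objective are illustrative, not the repository's actual test):

import pytest
import torch

@pytest.mark.parametrize('optimizer_class, config, iterations', OPTIMIZERS)
def test_optimizer_converges(optimizer_class, config, iterations):
    # minimise a toy quadratic with each (optimizer, config, iterations) triple
    weight = torch.randn(2, requires_grad=True)
    optimizer = optimizer_class([weight], **config)

    loss = None
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = (weight ** 2).sum()
        loss.backward()
        optimizer.step()

    assert loss.item() < 1.0  # loose convergence bound, for illustration only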

