
Commit

refactor: get_adanorm_gradient
kozistr committed Jul 21, 2024
1 parent f5c109f commit 21c1690
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_optimizer/base/optimizer.py
@@ -226,14 +226,14 @@ def get_adanorm_gradient(
         :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
         :param r: float. Optional[float]. momentum (ratio).
         """
-        if not adanorm:
+        if not adanorm or exp_grad_norm is None:
             return grad

         grad_norm = torch.linalg.norm(grad)

         exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

-        return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
+        return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad

     @staticmethod
     def get_rms(x: torch.Tensor) -> float:
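For context, here is a minimal, self-contained sketch of what the refactored AdaNorm gradient scaling does, pulled out of the optimizer class. The function name `adanorm_grad` and the default `r=0.95` are illustrative assumptions for this sketch, not part of the pytorch_optimizer API.

```python
from typing import Optional

import torch


def adanorm_grad(
    grad: torch.Tensor,
    adanorm: bool,
    exp_grad_norm: Optional[torch.Tensor] = None,
    r: float = 0.95,
) -> torch.Tensor:
    # Pass the gradient through unchanged when AdaNorm is disabled or no
    # running-norm buffer exists, mirroring the new
    # `if not adanorm or exp_grad_norm is None` guard.
    if not adanorm or exp_grad_norm is None:
        return grad

    grad_norm = torch.linalg.norm(grad)

    # Exponential moving average of the gradient norm, updated in place.
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

    # Scale the gradient only when its norm exceeds the running average,
    # so unusually large gradients are damped toward the EMA norm.
    return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad


if __name__ == '__main__':
    g = torch.randn(3, 3)
    buf = torch.zeros(1)  # running gradient-norm state kept by the caller
    print(adanorm_grad(g, adanorm=True, exp_grad_norm=buf, r=0.95))
```

Compared with the old `grad * exp_grad_norm / grad_norm`, the new `grad.mul(exp_grad_norm).div_(grad_norm)` divides the freshly allocated product in place, avoiding one extra intermediate tensor while leaving `grad` itself untouched.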
