[Refactor] AdamP optimizer #31

Merged · 2 commits · Oct 6, 2021

Changes from 1 commit
refactor: AdamP
kozistr committed Oct 6, 2021
commit 5e9b37578185d638b50a82748e892671f9309fc1
16 changes: 8 additions & 8 deletions pytorch_optimizer/adamp.py
@@ -52,36 +52,36 @@ def __init__(
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
-        self.eps = eps
         self.wd_ratio = wd_ratio
         self.use_gc = use_gc
+        self.eps = eps

         self.check_valid_parameters()

         defaults: DEFAULTS = dict(
             lr=lr,
             betas=betas,
-            eps=eps,
             weight_decay=weight_decay,
             delta=delta,
             wd_ratio=wd_ratio,
             nesterov=nesterov,
+            eps=eps,
         )
-        super().__init__(params, defaults)
+        super().__init__(params=params, defaults=defaults)

     def check_valid_parameters(self):
         if self.lr < 0.0:
             raise ValueError(f'Invalid learning rate : {self.lr}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
         if not 0.0 <= self.betas[0] < 1.0:
             raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
         if not 0.0 <= self.betas[1] < 1.0:
             raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if not 0.0 <= self.wd_ratio < 1.0:
-            raise ValueError(f'Invalid wd_ratio : {self.wd_ratio}')
+        if self.weight_decay < 0.0:
+            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
         if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
+        if not 0.0 <= self.wd_ratio < 1.0:
+            raise ValueError(f'Invalid wd_ratio : {self.wd_ratio}')

     @staticmethod
     def channel_view(x: torch.Tensor) -> torch.Tensor:
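
For readers skimming the diff, the code being reordered above is the standard torch.optim.Optimizer constructor flow: validate hyperparameters first, then build the defaults dict that the base class copies into every parameter group. Below is a minimal sketch of that pattern; the class name SimpleAdamP is hypothetical, and the argument list is reduced (delta, nesterov, and use_gc omitted for brevity):

import torch
from torch.optim.optimizer import Optimizer

class SimpleAdamP(Optimizer):  # hypothetical name, for illustration only
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.0, wd_ratio=0.1):
        # Validate before handing anything to the base class.
        if lr < 0.0:
            raise ValueError(f'Invalid learning rate : {lr}')
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {betas[0]}')
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {betas[1]}')
        if weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {weight_decay}')
        if eps < 0.0:
            raise ValueError(f'Invalid eps : {eps}')
        if not 0.0 <= wd_ratio < 1.0:
            raise ValueError(f'Invalid wd_ratio : {wd_ratio}')

        # The base class copies these defaults into each param group.
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, wd_ratio=wd_ratio)
        super().__init__(params, defaults)

Running the checks before super().__init__ means an invalid value fails fast, before any parameter-group state exists.
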
@@ -97,7 +97,7 @@ def cosine_similarity(
         y: torch.Tensor,
         eps: float,
         view_func: Callable[[torch.Tensor], torch.Tensor],
-    ):
+    ) -> torch.Tensor:
         x = view_func(x)
         y = view_func(y)
         return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
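
In AdamP, this helper feeds the projection step: when the gradient is nearly parallel to the weight (high absolute cosine similarity, taken per output channel or over the flattened tensor), the parameter is treated as scale-invariant and the radial component of the update is projected out. A rough, self-contained sketch of the helper pair, assuming channel_view flattens each output channel to a row (consistent with its name and the dim=1 reduction above):

import torch
import torch.nn.functional as F

def channel_view(x: torch.Tensor) -> torch.Tensor:
    # One row per output channel, everything else flattened.
    return x.view(x.size(0), -1)

def cosine_similarity(x, y, eps, view_func):
    x = view_func(x)
    y = view_func(y)
    return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()

# Example: per-channel |cos| between a conv weight and its gradient.
w = torch.randn(16, 3, 3, 3)   # 16 output channels
g = torch.randn_like(w)
sim = cosine_similarity(g, w, eps=1e-8, view_func=channel_view)
print(sim.shape)  # torch.Size([16]) -- one value per channel

The in-place .abs_() matters because only the magnitude of the alignment is used to decide whether a channel is scale-invariant, not its sign.
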