[Test] Add more test cases #45

Merged
16 commits, merged on Jan 29, 2022
18 changes: 16 additions & 2 deletions pytorch_optimizer/adabelief.py
@@ -38,14 +38,14 @@ def __init__(
adamd_debias_term: bool = False,
eps: float = 1e-16,
):
"""
"""AdaBelief
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param weight_decay: float. weight decay (L2 penalty)
:param n_sma_threshold: (recommended is 5)
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
:param fixed_decay: bool.
:param fixed_decay: bool. fix weight decay (do not scale the decay term by the learning rate)
:param rectify: bool. perform the rectified update similar to RAdam
:param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
:param amsgrad: bool. whether to use the AMSBound variant
@@ -63,6 +63,8 @@ def __init__(
self.adamd_debias_term = adamd_debias_term
self.eps = eps

self.check_valid_parameters()

buffer: BUFFER = [[None, None, None] for _ in range(10)]

if is_valid_parameters(params):
@@ -81,6 +83,18 @@
)
super().__init__(params, defaults)

def check_valid_parameters(self):
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
if not 0.0 <= self.betas[1] < 1.0:
raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

def __setstate__(self, state: STATE):
super().__setstate__(state)
for group in self.param_groups:
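For illustration, the new check_valid_parameters guard surfaces at construction time. A minimal sketch, assuming the load_optimizers helper exercised by the new tests later in this PR resolves an optimizer class by name:

from pytorch_optimizer import load_optimizers

# a negative learning rate is now rejected before any parameter group is built
adabelief_cls = load_optimizers('adabelief')
try:
    adabelief_cls(None, lr=-1e-2)
except ValueError as err:
    print(err)  # Invalid learning rate : -0.01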
6 changes: 4 additions & 2 deletions pytorch_optimizer/adabound.py
@@ -37,15 +37,15 @@ def __init__(
adamd_debias_term: bool = False,
eps: float = 1e-8,
):
"""
"""AdaBound
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param final_lr: float. final learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param gamma: float. convergence speed of the bound functions
:param weight_decay: float. weight decay (L2 penalty)
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
:param fixed_decay: bool.
:param fixed_decay: bool. fix weight decay (do not scale the decay term by the learning rate)
:param amsbound: bool. whether to use the AMSBound variant
:param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
:param eps: float. term added to the denominator to improve numerical stability
@@ -57,6 +57,8 @@
self.fixed_decay = fixed_decay
self.eps = eps

self.check_valid_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
35 changes: 17 additions & 18 deletions pytorch_optimizer/pcgrad.py
@@ -1,6 +1,6 @@
import random
from copy import deepcopy
from typing import Iterable, List
from typing import Iterable, List, Tuple

import numpy as np
import torch
@@ -35,12 +35,12 @@ def check_valid_parameters(self):
raise ValueError(f'invalid reduction : {self.reduction}')

@staticmethod
def flatten_grad(grads) -> torch.Tensor:
def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
return torch.cat([g.flatten() for g in grads])

@staticmethod
def un_flatten_grad(grads, shapes) -> List[torch.Tensor]:
un_flatten_grad = []
un_flatten_grad: List[torch.Tensor] = []
idx: int = 0
for shape in shapes:
length = np.prod(shape)
@@ -54,39 +54,40 @@ def zero_grad(self):
def step(self):
return self.optimizer.step()

def set_grad(self, grads):
def set_grad(self, grads: List[torch.Tensor]):
idx: int = 0
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = grads[idx]
idx += 1

def retrieve_grad(self):
def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
"""get the gradient of the parameters of the network with specific objective"""
grad, shape, has_grad = [], [], []
for group in self.optimizer.param_groups:
for p in group['params']:
if p.grad is None:
shape.append(p.shape)
grad.append(torch.zeros_like(p).to(p.device))
has_grad.append(torch.zeros_like(p).to(p.device))
grad.append(torch.zeros_like(p, device=p.device))
has_grad.append(torch.zeros_like(p, device=p.device))
continue

shape.append(p.grad.shape)
grad.append(p.grad.clone())
has_grad.append(torch.ones_like(p).to(p.device))
has_grad.append(torch.ones_like(p, device=p.device))

return grad, shape, has_grad

def pack_grad(self, objectives: Iterable[nn.Module]):
def pack_grad(
self, objectives: Iterable[nn.Module]
) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
"""pack the gradient of the parameters of the network for each objective
:param objectives: Iterable[float]. a list of objectives
:param objectives: Iterable[nn.Module]. a list of objectives
:return:
"""
grads, shapes, has_grads = [], [], []
for objective in objectives:
self.zero_grad()

self.optimizer.zero_grad(set_to_none=True)
objective.backward(retain_graph=True)

grad, shape, has_grad = self.retrieve_grad()
@@ -98,7 +99,7 @@ def pack_grad(self, objectives: Iterable[nn.Module]):
return grads, shapes, has_grads

def project_conflicting(self, grads, has_grads) -> torch.Tensor:
"""
"""project conflicting
:param grads: a list of the gradient of the parameters
:param has_grads: a list of mask represent whether the parameter has gradient
:return:
@@ -114,12 +115,10 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:
g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
merged_grad[shared] = torch.stack([g[shared] for g in pc_grad])

if self.reduction == 'mean':
merged_grad = merged_grad.mean(dim=0)
else: # self.reduction == 'sum'
merged_grad = merged_grad.sum(dim=0)
merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
else:
merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)

merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

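The refactored project_conflicting keeps the same gradient-surgery rule: a task gradient that conflicts with another one (negative inner product) has its projection onto that gradient removed before the results are merged. A minimal standalone sketch of that rule on flat tensors, omitting the random task ordering and the shared/has_grads masking of the real method:

from typing import List

import torch

def project_conflicting_sketch(grads: List[torch.Tensor], reduction: str = 'mean') -> torch.Tensor:
    # work on copies so the original task gradients stay untouched
    pc_grad = [g.clone() for g in grads]
    for g_i in pc_grad:
        for g_j in grads:
            dot = torch.dot(g_i, g_j)
            if dot < 0:  # conflicting directions: remove the projection of g_i onto g_j
                g_i -= dot * g_j / (g_j.norm() ** 2)
    merged = torch.stack(pc_grad)
    return merged.mean(dim=0) if reduction == 'mean' else merged.sum(dim=0)

# two conflicting toy gradients
print(project_conflicting_sketch([torch.tensor([1.0, 0.0]), torch.tensor([-1.0, 1.0])]))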
6 changes: 3 additions & 3 deletions pytorch_optimizer/radam.py
@@ -35,13 +35,13 @@ def __init__(
adamd_debias_term: bool = False,
eps: float = 1e-8,
):
"""
"""RAdam
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate.
:param lr: float. learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param weight_decay: float. weight decay (L2 penalty)
:param n_sma_threshold: int. (recommended is 5)
:param degenerated_to_sgd: float.
:param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
:param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
:param eps: float. term added to the denominator to improve numerical stability
"""
14 changes: 14 additions & 0 deletions pytorch_optimizer/ranger21.py
@@ -92,6 +92,8 @@ def __init__(
self.norm_loss_factor = norm_loss_factor
self.eps = eps

self.check_valid_parameters()

# lookahead
self.lookahead_step: int = 0

@@ -124,6 +126,18 @@ def __init__(
self.start_warm_down: int = num_iterations - self.num_warm_down_iterations
self.warm_down_lr_delta: float = self.starting_lr - self.min_lr

def check_valid_parameters(self):
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
if not 0.0 <= self.betas[1] < 1.0:
raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

def __setstate__(self, state: STATE):
super().__setstate__(state)

5 changes: 1 addition & 4 deletions tests/test_load_optimizers.py
@@ -36,8 +36,5 @@ def test_load_optimizers_valid(valid_optimizer_names):

@pytest.mark.parametrize('invalid_optimizer_names', INVALID_OPTIMIZER_NAMES)
def test_load_optimizers_invalid(invalid_optimizer_names):
try:
with pytest.raises(NotImplementedError):
load_optimizers(invalid_optimizer_names)
except NotImplementedError:
return True
return False
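The rewritten invalid-name test relies on pytest.raises as a context manager, which fails the test unless the expected exception is actually raised; the old try/except version only returned True or False, which pytest ignores, so it could never fail. A generic sketch of the pattern, using a hypothetical load_config helper:

import pytest

def load_config(name: str) -> dict:
    # hypothetical helper, stands in for load_optimizers receiving an unknown name
    raise NotImplementedError(f'unknown config : {name}')

def test_load_config_invalid():
    # passes only if the block raises NotImplementedError
    with pytest.raises(NotImplementedError):
        load_config('does-not-exist')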
65 changes: 65 additions & 0 deletions tests/test_optimizer_parameters.py
@@ -0,0 +1,65 @@
from typing import List

import pytest

from pytorch_optimizer import load_optimizers

OPTIMIZER_NAMES: List[str] = [
'adamp',
'sgdp',
'madgrad',
'ranger',
'ranger21',
'radam',
'adabound',
'adahessian',
'adabelief',
'diffgrad',
'diffrgrad',
'lamb',
]

BETA_OPTIMIZER_NAMES: List[str] = [
'adabelief',
'adabound',
'adahessian',
'adamp',
'diffgrad',
'diffrgrad',
'lamb',
'radam',
'ranger',
'ranger21',
]


@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_learning_rate(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, lr=-1e-2)


@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_epsilon(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, eps=-1e-6)


@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_weight_decay(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, weight_decay=-1e-3)


@pytest.mark.parametrize('optimizer_names', BETA_OPTIMIZER_NAMES)
def test_betas(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, betas=(-0.1, 0.1))

with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, betas=(0.1, -0.1))
47 changes: 47 additions & 0 deletions tests/test_optimizers.py
@@ -17,6 +17,7 @@
DiffRGrad,
Lamb,
Lookahead,
PCGrad,
RAdam,
Ranger,
Ranger21,
@@ -39,6 +40,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class MultiHeadLogisticRegression(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.head1 = nn.Linear(2, 1)
self.head2 = nn.Linear(2, 1)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.fc1(x)
x = F.relu(x)
return self.head1(x), self.head2(x)


def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
rng = np.random.RandomState(seed)

@@ -180,4 +194,37 @@ def test_sam_optimizers(optimizer_config):
loss_fn(y_data, model(x_data)).backward()
optimizer.second_step(zero_grad=True)

if init_loss == np.inf:
init_loss = loss

assert init_loss > 2.0 * loss


@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
def test_pc_grad_optimizers(optimizer_config):
torch.manual_seed(42)

x_data, y_data = make_dataset()

model: nn.Module = MultiHeadLogisticRegression()
loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
loss_fn_2: nn.Module = nn.L1Loss()

optimizer_class, config, iterations = optimizer_config
optimizer = PCGrad(optimizer_class(model.parameters(), **config))

loss: float = np.inf
init_loss: float = np.inf
for _ in range(iterations):
optimizer.zero_grad()
y_pred_1, y_pred_2 = model(x_data)
loss1, loss2 = loss_fn_1(y_pred_1, y_data), loss_fn_2(y_pred_2, y_data)

loss = (loss1 + loss2) / 2.0
if init_loss == np.inf:
init_loss = loss

optimizer.pc_backward([loss1, loss2])
optimizer.step()

assert init_loss > 2.0 * loss
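The new test drives PCGrad through pc_backward, which this diff does not touch but which ties together the helpers edited in pcgrad.py above. A rough sketch of that flow, assumed from the helpers shown in this PR rather than copied from the actual implementation:

from typing import Iterable

import torch.nn as nn

from pytorch_optimizer import PCGrad

def pc_backward_sketch(wrapper: PCGrad, objectives: Iterable[nn.Module]) -> None:
    grads, shapes, has_grads = wrapper.pack_grad(objectives)  # one flattened gradient per objective
    pc_grad = wrapper.project_conflicting(grads, has_grads)   # drop conflicting components
    pc_grad = wrapper.un_flatten_grad(pc_grad, shapes[0])     # restore per-parameter shapes
    wrapper.set_grad(pc_grad)                                 # write the merged gradient onto p.grad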