Merge pull request #45 from kozistr/test/cases
[Test] Add more test cases
kozistr committed Jan 29, 2022
2 parents 5f1ef59 + 033a842 commit dc3c356
Showing 8 changed files with 167 additions and 29 deletions.
18 changes: 16 additions & 2 deletions pytorch_optimizer/adabelief.py
@@ -38,14 +38,14 @@ def __init__(
adamd_debias_term: bool = False,
eps: float = 1e-16,
):
"""
"""AdaBelief
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param weight_decay: float. weight decay (L2 penalty)
:param n_sma_threshold: (recommended is 5)
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
:param fixed_decay: bool.
:param fixed_decay: bool. fix weight decay
:param rectify: bool. perform the rectified update similar to RAdam
:param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
:param amsgrad: bool. whether to use the AMSBound variant
@@ -63,6 +63,8 @@ def __init__(
self.adamd_debias_term = adamd_debias_term
self.eps = eps

self.check_valid_parameters()

buffer: BUFFER = [[None, None, None] for _ in range(10)]

if is_valid_parameters(params):
@@ -81,6 +83,18 @@ def __init__(
)
super().__init__(params, defaults)

def check_valid_parameters(self):
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
if not 0.0 <= self.betas[1] < 1.0:
raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

def __setstate__(self, state: STATE):
super().__setstate__(state)
for group in self.param_groups:
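The check_valid_parameters method added to AdaBelief above runs before super().__init__, so an out-of-range hyperparameter is rejected before the optimizer ever touches params; the new tests further down rely on this by passing None as the parameter list. A minimal sketch of that behaviour, assuming pytorch_optimizer is installed; load_optimizers and the 'adabelief' name come from the tests added in this commit:

from pytorch_optimizer import load_optimizers

adabelief_cls = load_optimizers('adabelief')  # resolve the optimizer class by name

try:
    adabelief_cls(None, lr=-1e-2)  # check_valid_parameters() runs before params is used
except ValueError as err:
    print(err)  # Invalid learning rate : -0.01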
6 changes: 4 additions & 2 deletions pytorch_optimizer/adabound.py
@@ -37,15 +37,15 @@ def __init__(
adamd_debias_term: bool = False,
eps: float = 1e-8,
):
"""
"""AdaBound
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param final_lr: float. final learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param gamma: float. convergence speed of the bound functions
:param weight_decay: float. weight decay (L2 penalty)
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
:param fixed_decay: bool.
:param fixed_decay: bool. fix weight decay
:param amsbound: bool. whether to use the AMSBound variant
:param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
:param eps: float. term added to the denominator to improve numerical stability
@@ -57,6 +57,8 @@ def __init__(
self.fixed_decay = fixed_decay
self.eps = eps

self.check_valid_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
35 changes: 17 additions & 18 deletions pytorch_optimizer/pcgrad.py
@@ -1,6 +1,6 @@
import random
from copy import deepcopy
from typing import Iterable, List
from typing import Iterable, List, Tuple

import numpy as np
import torch
@@ -35,12 +35,12 @@ def check_valid_parameters(self):
raise ValueError(f'invalid reduction : {self.reduction}')

@staticmethod
def flatten_grad(grads) -> torch.Tensor:
def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
return torch.cat([g.flatten() for g in grads])

@staticmethod
def un_flatten_grad(grads, shapes) -> List[torch.Tensor]:
un_flatten_grad = []
un_flatten_grad: List[torch.Tensor] = []
idx: int = 0
for shape in shapes:
length = np.prod(shape)
@@ -54,39 +54,40 @@ def zero_grad(self):
def step(self):
return self.optimizer.step()

def set_grad(self, grads):
def set_grad(self, grads: List[torch.Tensor]):
idx: int = 0
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = grads[idx]
idx += 1

def retrieve_grad(self):
def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
"""get the gradient of the parameters of the network with specific objective"""
grad, shape, has_grad = [], [], []
for group in self.optimizer.param_groups:
for p in group['params']:
if p.grad is None:
shape.append(p.shape)
grad.append(torch.zeros_like(p).to(p.device))
has_grad.append(torch.zeros_like(p).to(p.device))
grad.append(torch.zeros_like(p, device=p.device))
has_grad.append(torch.zeros_like(p, device=p.device))
continue

shape.append(p.grad.shape)
grad.append(p.grad.clone())
has_grad.append(torch.ones_like(p).to(p.device))
has_grad.append(torch.ones_like(p, device=p.device))

return grad, shape, has_grad

def pack_grad(self, objectives: Iterable[nn.Module]):
def pack_grad(
self, objectives: Iterable[nn.Module]
) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
"""pack the gradient of the parameters of the network for each objective
:param objectives: Iterable[float]. a list of objectives
:param objectives: Iterable[nn.Module]. a list of objectives
:return:
"""
grads, shapes, has_grads = [], [], []
for objective in objectives:
self.zero_grad()

self.optimizer.zero_grad(set_to_none=True)
objective.backward(retain_graph=True)

grad, shape, has_grad = self.retrieve_grad()
@@ -98,7 +99,7 @@ def pack_grad(self, objectives: Iterable[nn.Module]):
return grads, shapes, has_grads

def project_conflicting(self, grads, has_grads) -> torch.Tensor:
"""
"""project conflicting
:param grads: a list of the gradient of the parameters
:param has_grads: a list of mask represent whether the parameter has gradient
:return:
@@ -114,12 +115,10 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:
g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
merged_grad[shared] = torch.stack([g[shared] for g in pc_grad])

if self.reduction == 'mean':
merged_grad = merged_grad.mean(dim=0)
else: # self.reduction == 'sum'
merged_grad = merged_grad.sum(dim=0)
merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
else:
merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)

merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

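For reference, the loop in project_conflicting above implements the PCGrad rule: when two task gradients conflict (negative inner product), g_i is projected onto the normal plane of g_j before the shared components are reduced. A small self-contained sketch of just that arithmetic with plain tensors, not the library's API:

import torch

g_i = torch.tensor([1.0, 1.0])   # gradient of task i
g_j = torch.tensor([-1.0, 0.5])  # gradient of task j

g_i_g_j = torch.dot(g_i, g_j)    # -0.5 < 0, so the two gradients conflict
if g_i_g_j < 0:
    # drop the component of g_i that points against g_j
    g_i = g_i - g_i_g_j * g_j / (g_j.norm() ** 2)

print(g_i)                  # ~tensor([0.6000, 1.2000])
print(torch.dot(g_i, g_j))  # ~0: the conflict is gone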
6 changes: 3 additions & 3 deletions pytorch_optimizer/radam.py
@@ -35,13 +35,13 @@ def __init__(
adamd_debias_term: bool = False,
eps: float = 1e-8,
):
"""
"""RAdam
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate.
:param lr: float. learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param weight_decay: float. weight decay (L2 penalty)
:param n_sma_threshold: int. (recommended is 5)
:param degenerated_to_sgd: float.
:param degenerated_to_sgd: float. degenerated to SGD
:param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
:param eps: float. term added to the denominator to improve numerical stability
"""
14 changes: 14 additions & 0 deletions pytorch_optimizer/ranger21.py
@@ -92,6 +92,8 @@ def __init__(
self.norm_loss_factor = norm_loss_factor
self.eps = eps

self.check_valid_parameters()

# lookahead
self.lookahead_step: int = 0

@@ -124,6 +126,18 @@ def __init__(
self.start_warm_down: int = num_iterations - self.num_warm_down_iterations
self.warm_down_lr_delta: float = self.starting_lr - self.min_lr

def check_valid_parameters(self):
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
if not 0.0 <= self.betas[1] < 1.0:
raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

def __setstate__(self, state: STATE):
super().__setstate__(state)

5 changes: 1 addition & 4 deletions tests/test_load_optimizers.py
@@ -36,8 +36,5 @@ def test_load_optimizers_valid(valid_optimizer_names):

@pytest.mark.parametrize('invalid_optimizer_names', INVALID_OPTIMIZER_NAMES)
def test_load_optimizers_invalid(invalid_optimizer_names):
try:
with pytest.raises(NotImplementedError):
load_optimizers(invalid_optimizer_names)
except NotImplementedError:
return True
return False
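The rewrite above matters because the old version could never fail: it caught the exception itself and returned True or False, and pytest ignores a test's return value. With pytest.raises used as a context manager, the test fails unless the expected exception is actually raised inside the block. A minimal sketch of the pattern, with a hypothetical lookup function standing in for load_optimizers:

import pytest


def lookup(name: str) -> str:
    # hypothetical stand-in for load_optimizers
    registry = {'adamp': 'AdamP', 'radam': 'RAdam'}
    if name not in registry:
        raise NotImplementedError(f'unknown optimizer : {name}')
    return registry[name]


def test_lookup_invalid():
    # passes only if NotImplementedError is raised inside the block
    with pytest.raises(NotImplementedError):
        lookup('no_such_optimizer')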
65 changes: 65 additions & 0 deletions tests/test_optimizer_parameters.py
@@ -0,0 +1,65 @@
from typing import List

import pytest

from pytorch_optimizer import load_optimizers

OPTIMIZER_NAMES: List[str] = [
'adamp',
'sgdp',
'madgrad',
'ranger',
'ranger21',
'radam',
'adabound',
'adahessian',
'adabelief',
'diffgrad',
'diffrgrad',
'lamb',
]

BETA_OPTIMIZER_NAMES: List[str] = [
'adabelief',
'adabound',
'adahessian',
'adamp',
'diffgrad',
'diffrgrad',
'lamb',
'radam',
'ranger',
'ranger21',
]


@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_learning_rate(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, lr=-1e-2)


@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_epsilon(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, eps=-1e-6)


@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_weight_decay(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, weight_decay=-1e-3)


@pytest.mark.parametrize('optimizer_names', BETA_OPTIMIZER_NAMES)
def test_betas(optimizer_names):
with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, betas=(-0.1, 0.1))

with pytest.raises(ValueError):
optimizer = load_optimizers(optimizer_names)
optimizer(None, betas=(0.1, -0.1))
47 changes: 47 additions & 0 deletions tests/test_optimizers.py
@@ -17,6 +17,7 @@
DiffRGrad,
Lamb,
Lookahead,
PCGrad,
RAdam,
Ranger,
Ranger21,
@@ -39,6 +40,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class MultiHeadLogisticRegression(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.head1 = nn.Linear(2, 1)
self.head2 = nn.Linear(2, 1)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.fc1(x)
x = F.relu(x)
return self.head1(x), self.head2(x)


def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
rng = np.random.RandomState(seed)

@@ -180,4 +194,37 @@ def test_sam_optimizers(optimizer_config):
loss_fn(y_data, model(x_data)).backward()
optimizer.second_step(zero_grad=True)

if init_loss == np.inf:
init_loss = loss

assert init_loss > 2.0 * loss


@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
def test_pc_grad_optimizers(optimizer_config):
torch.manual_seed(42)

x_data, y_data = make_dataset()

model: nn.Module = MultiHeadLogisticRegression()
loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
loss_fn_2: nn.Module = nn.L1Loss()

optimizer_class, config, iterations = optimizer_config
optimizer = PCGrad(optimizer_class(model.parameters(), **config))

loss: float = np.inf
init_loss: float = np.inf
for _ in range(iterations):
optimizer.zero_grad()
y_pred_1, y_pred_2 = model(x_data)
loss1, loss2 = loss_fn_1(y_pred_1, y_data), loss_fn_2(y_pred_2, y_data)

loss = (loss1 + loss2) / 2.0
if init_loss == np.inf:
init_loss = loss

optimizer.pc_backward([loss1, loss2])
optimizer.step()

assert init_loss > 2.0 * loss
