-
Notifications
You must be signed in to change notification settings - Fork 18
/
test_optimizer_parameters.py
41 lines (32 loc) · 1008 Bytes
/
test_optimizer_parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import List
import pytest
from pytorch_optimizer import load_optimizers
# Registry keys accepted by pytorch_optimizer.load_optimizers; every entry is
# exercised by each parametrized parameter-validation test below.
# NOTE(review): 'diffrgrad' (DiffRGrad) is a distinct variant from 'diffgrad' —
# presumably both are registered upstream; verify against the loader's registry.
VALID_OPTIMIZER_NAMES: List[str] = [
    'adamp',
    'sgdp',
    'madgrad',
    'ranger',
    'ranger21',
    'radam',
    'adabound',
    'adahessian',
    'adabelief',
    'diffgrad',
    'diffrgrad',
    'lamb',
]
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
def test_learning_rate(optimizer_names):
    """A negative learning rate must raise ValueError at optimizer construction.

    :param optimizer_names: registry key of the optimizer class under test.
    """
    # Resolve the optimizer class OUTSIDE the raises block: if load_optimizers
    # itself raised ValueError (e.g. for an unknown name), the test would pass
    # without ever validating the lr check.
    optimizer = load_optimizers(optimizer_names)
    with pytest.raises(ValueError):
        optimizer(None, lr=-1e-2)
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
def test_epsilon(optimizer_names):
    """A negative epsilon must raise ValueError at optimizer construction.

    :param optimizer_names: registry key of the optimizer class under test.
    """
    # Resolve the class outside the raises block so only the constructor's
    # eps validation can satisfy the expected ValueError.
    optimizer = load_optimizers(optimizer_names)
    with pytest.raises(ValueError):
        optimizer(None, eps=-1e-6)
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
def test_weight_decay(optimizer_names):
    """A negative weight decay must raise ValueError at optimizer construction.

    :param optimizer_names: registry key of the optimizer class under test.
    """
    # Resolve the class outside the raises block so only the constructor's
    # weight_decay validation can satisfy the expected ValueError.
    optimizer = load_optimizers(optimizer_names)
    with pytest.raises(ValueError):
        optimizer(None, weight_decay=-1e-3)