[Feature] Implement more Shampoo features #99

Merged
merged 55 commits on Jan 31, 2023

Commits
bda6d74
update: set multiplier to 1.0 when grafting type is None
kozistr Jan 31, 2023
caba6d5
build(version): v2.3.1
kozistr Jan 31, 2023
8a5161e
update: default start_preconditioning_step to 5
kozistr Jan 31, 2023
1bade60
feature: moving_average_for_momentum
kozistr Jan 31, 2023
218e781
update: decoupled_weight_decay
kozistr Jan 31, 2023
106d2e8
update: Shampoo recipes
kozistr Jan 31, 2023
683be6d
feature: decoupled_learning_rate
kozistr Jan 31, 2023
1bb4463
update: Shampoo recipe
kozistr Jan 31, 2023
efcfc7c
update: default parameters
kozistr Jan 31, 2023
24ba91c
update: exclude Shampoo w/ decoupled_learning_rate
kozistr Jan 31, 2023
5847c49
update: max_error_ratio
kozistr Jan 31, 2023
3abbc24
update: Shampoo recipe
kozistr Jan 31, 2023
5ca8b4a
update: Shampoo recipe
kozistr Jan 31, 2023
6f88acc
update: Shampoo recipe
kozistr Jan 31, 2023
91dc676
feature: RMSPropGraft
kozistr Jan 31, 2023
7f253c2
style: fix Graft
kozistr Jan 31, 2023
2a27c7c
feature: supports RMSPropGraft
kozistr Jan 31, 2023
5734b7f
docs: Shampoo optimizer docstring
kozistr Jan 31, 2023
54f2d64
update: Shampoo recipes
kozistr Jan 31, 2023
4e5e3ff
fix: missing beta2
kozistr Jan 31, 2023
15d5ae7
update: test_pc_grad_optimizers
kozistr Jan 31, 2023
922e13d
update: test_shampoo_optimizer
kozistr Jan 31, 2023
7f8c3f8
update: Shampoo recipe
kozistr Jan 31, 2023
d29dda4
feature: SQRTNGraft
kozistr Jan 31, 2023
d54694e
feature: supports SQRTN grafting
kozistr Jan 31, 2023
ad96202
update: recipes
kozistr Jan 31, 2023
f841c0a
update: test_shampoo_optimizer_graft_types
kozistr Jan 31, 2023
8818c2a
update: recipes
kozistr Jan 31, 2023
3505b32
refactor: term, momentum to beta1
kozistr Jan 31, 2023
9fc43f1
update: reset
kozistr Jan 31, 2023
6abd00a
update: test_shampoo_optimizer_graft_types
kozistr Jan 31, 2023
f069f0b
refactor: betas
kozistr Jan 31, 2023
cdad7fe
update: BETA_OPTIMIZER_NAMES
kozistr Jan 31, 2023
5046889
update: test_momentum
kozistr Jan 31, 2023
d56640a
docs: Shampoo docstring
kozistr Jan 31, 2023
7741fc8
fix: pre-conditioner parameter, beta2
kozistr Jan 31, 2023
ea311b4
fix: reset
kozistr Jan 31, 2023
796b360
update: recipes
kozistr Jan 31, 2023
b6c167f
feature: PreConditionerType
kozistr Jan 31, 2023
854e324
docs: PreConditionerType docstring
kozistr Jan 31, 2023
f0beb9a
feature: supports PreConditionerType
kozistr Jan 31, 2023
1624b1b
refactor: precondition_block
kozistr Jan 31, 2023
b4652f5
refactor: preconditioned_grad
kozistr Jan 31, 2023
bff9177
update: should_precondition_dims
kozistr Jan 31, 2023
6e3f6ac
update: recipes
kozistr Jan 31, 2023
5693b1d
update: recipes
kozistr Jan 31, 2023
b9a9f7e
update: recipes
kozistr Jan 31, 2023
5271865
update: precondition_block
kozistr Jan 31, 2023
29ec7c0
update: BlockPartitioner
kozistr Jan 31, 2023
44609d1
update: merge_small_dims
kozistr Jan 31, 2023
8bffa33
update: test_merge_small_dims
kozistr Jan 31, 2023
4225940
update: test_shampoo_optimizer
kozistr Jan 31, 2023
abcf6f0
update: BlockPartitioner
kozistr Jan 31, 2023
a0af041
fix: add_statistics when PreConditionerType is 1 (INPUT)
kozistr Jan 31, 2023
5da4074
style: format
kozistr Jan 31, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.3.0"
+version = "2.3.1"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
86 changes: 62 additions & 24 deletions pytorch_optimizer/optimizer/shampoo.py
@@ -3,8 +3,17 @@
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft, Graft, LayerWiseGrafting, PreConditioner, SGDGraft
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.optimizer.shampoo_utils import (
+    AdagradGraft,
+    Graft,
+    LayerWiseGrafting,
+    PreConditioner,
+    PreConditionerType,
+    RMSPropGraft,
+    SGDGraft,
+    SQRTNGraft,
+)
 
 
 class Shampoo(Optimizer, BaseOptimizer):
@@ -14,9 +23,11 @@ class Shampoo(Optimizer, BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param momentum: float. momentum.
-    :param beta2: float. beta2.
+    :param betas: BETAS. beta1, beta2.
+    :param moving_average_for_momentum: bool. perform moving_average for momentum (beta1).
     :param weight_decay: float. weight decay (L2 penalty).
+    :param decoupled_weight_decay: bool. use decoupled weight_decay.
+    :param decoupled_learning_rate: bool. use decoupled lr, otherwise couple it w/ preconditioned gradient.
     :param inverse_exponent_override: int. fixed exponent for pre-conditioner, if > 0.
     :param start_preconditioning_step: int.
     :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
@@ -28,7 +39,8 @@ class Shampoo(Optimizer, BaseOptimizer):
     :param shape_interpretation: bool. Automatic shape interpretation (for eg: [4, 3, 1024, 512] would
         result in 12 x [1024, 512] L and R statistics. Disabled by default which results in Shampoo constructing
         statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
-    :param graft_type: bool. Type of grafting (SGD or AdaGrad).
+    :param graft_type: int. type of grafting (SGD or AdaGrad or RMSProp or SQRT_N or None).
+    :param pre_conditioner_type: int. type of pre-conditioner.
     :param nesterov: bool. Nesterov momentum.
     :param diagonal_eps: float. term added to the denominator to improve numerical stability.
     :param matrix_eps: float. term added to the denominator to improve numerical stability.
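A usage sketch that exercises the options documented above. The import paths mirror this diff, and the top-level package may re-export these names differently, so treat the exact imports and the chosen hyperparameters as assumptions rather than the documented API.

import torch
from torch import nn

from pytorch_optimizer.optimizer.shampoo import Shampoo
from pytorch_optimizer.optimizer.shampoo_utils import LayerWiseGrafting, PreConditionerType

model = nn.Linear(512, 10)
optimizer = Shampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),                    # beta1 replaces the old `momentum`, beta2 the old `beta2`
    moving_average_for_momentum=True,
    weight_decay=1e-2,
    decoupled_weight_decay=True,           # scale updates by 1 - lr * wd instead of adding wd * p
    decoupled_learning_rate=True,          # default: do not fold lr into the grafted gradient
    graft_type=LayerWiseGrafting.RMSPROP,  # SGD, ADAGRAD, RMSPROP, SQRTN, or NONE
    pre_conditioner_type=PreConditionerType.ALL,
    start_preconditioning_step=5,
)

x, y = torch.randn(32, 512), torch.randint(0, 10, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

As the docstring above notes, shape_interpretation=True (the default) would reshape a [4, 3, 1024, 512] weight into 12 [1024, 512] blocks, each with its own L and R statistics, before block partitioning by block_size.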
@@ -38,31 +50,37 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
-        momentum: float = 0.0,
-        beta2: float = 1.0,
+        betas: BETAS = (0.9, 0.999),
+        moving_average_for_momentum: bool = False,
         weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False,
+        decoupled_learning_rate: bool = True,
         inverse_exponent_override: int = 0,
-        start_preconditioning_step: int = 1,
+        start_preconditioning_step: int = 5,
         preconditioning_compute_steps: int = 1,
         statistics_compute_steps: int = 1,
         block_size: int = 128,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
+        pre_conditioner_type: int = PreConditionerType.ALL,
         nesterov: bool = True,
-        diagonal_eps: float = 1e-6,
-        matrix_eps: float = 1e-12,
+        diagonal_eps: float = 1e-10,
+        matrix_eps: float = 1e-6,
     ):
         self.lr = lr
-        self.momentum = momentum
-        self.beta2 = beta2
+        self.betas = betas
+        self.moving_average_for_momentum = moving_average_for_momentum
         self.weight_decay = weight_decay
+        self.decoupled_weight_decay = decoupled_weight_decay
+        self.decoupled_learning_rate = decoupled_learning_rate
         self.inverse_exponent_override = inverse_exponent_override
         self.start_preconditioning_step = start_preconditioning_step
         self.preconditioning_compute_steps = preconditioning_compute_steps
         self.statistics_compute_steps = statistics_compute_steps
         self.block_size = block_size
         self.shape_interpretation = shape_interpretation
         self.graft_type = graft_type
+        self.pre_conditioner_type = pre_conditioner_type
         self.nesterov = nesterov
         self.diagonal_eps = diagonal_eps
         self.matrix_eps = matrix_eps
@@ -71,14 +89,14 @@ def __init__(
 
         defaults: DEFAULTS = {
             'lr': lr,
-            'momentum': momentum,
+            'betas': betas,
             'weight_decay': weight_decay,
         }
         super().__init__(params, defaults)
 
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
-        self.validate_momentum(self.momentum)
+        self.validate_betas(self.betas)
         self.validate_weight_decay(self.weight_decay)
         self.validate_update_frequency(self.start_preconditioning_step)
         self.validate_update_frequency(self.statistics_compute_steps)
@@ -100,16 +118,21 @@ def reset(self):
                 state['momentum'] = torch.zeros_like(p)
                 state['pre_conditioner'] = PreConditioner(
                     p,
-                    self.beta2,
+                    group['betas'][1],  # beta2
                     self.inverse_exponent_override,
                     self.block_size,
                     self.shape_interpretation,
                     self.matrix_eps,
+                    self.pre_conditioner_type,
                 )
                 if self.graft_type == LayerWiseGrafting.ADAGRAD:
                     state['graft'] = AdagradGraft(p, self.diagonal_eps)
+                elif self.graft_type == LayerWiseGrafting.RMSPROP:
+                    state['graft'] = RMSPropGraft(p, self.diagonal_eps)
                 elif self.graft_type == LayerWiseGrafting.SGD:
                     state['graft'] = SGDGraft(p)
+                elif self.graft_type == LayerWiseGrafting.SQRTN:
+                    state['graft'] = SQRTNGraft(p)
                 else:
                     state['graft'] = Graft(p)
 
@@ -121,6 +144,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
+            beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
                     continue
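The hunk below performs the actual grafting: the Shampoo update keeps its own direction but is rescaled to the norm of the grafted optimizer's update, via graft_norm / (shampoo_norm + 1e-16). A toy sketch of that norm transfer, for illustration only (graft_norm_transfer is not a library function):

import torch

def graft_norm_transfer(shampoo_grad: torch.Tensor, graft_grad: torch.Tensor) -> torch.Tensor:
    # keep Shampoo's direction, adopt the grafted optimizer's per-layer step size
    return shampoo_grad * (torch.norm(graft_grad) / (torch.norm(shampoo_grad) + 1e-16))

shampoo_update = torch.tensor([3.0, 4.0])  # norm 5.0: the direction comes from here
graft_update = torch.tensor([1.0, 0.0])    # norm 1.0: the magnitude comes from here
update = graft_norm_transfer(shampoo_update, graft_update)
print(update)  # tensor([0.6000, 0.8000]): Shampoo's direction at the graft's norm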
@@ -135,48 +159,59 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['momentum'] = torch.zeros_like(p)
                     state['pre_conditioner'] = PreConditioner(
                         p,
-                        self.beta2,
+                        beta2,
                         self.inverse_exponent_override,
                         self.block_size,
                         self.shape_interpretation,
                         self.matrix_eps,
+                        self.pre_conditioner_type,
                     )
                     if self.graft_type == LayerWiseGrafting.ADAGRAD:
                         state['graft'] = AdagradGraft(p, self.diagonal_eps)
+                    elif self.graft_type == LayerWiseGrafting.RMSPROP:
+                        state['graft'] = RMSPropGraft(p, self.diagonal_eps)
                     elif self.graft_type == LayerWiseGrafting.SGD:
                         state['graft'] = SGDGraft(p)
+                    elif self.graft_type == LayerWiseGrafting.SQRTN:
+                        state['graft'] = SQRTNGraft(p)
                     else:
                         state['graft'] = Graft(p)
 
                 state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
                 # gather statistics, compute pre-conditioners
-                graft.add_statistics(grad)
+                graft.add_statistics(grad, beta2)
                 if state['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
                 if state['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
                 # pre-condition gradients
-                graft_grad: torch.Tensor = graft.precondition_gradient(grad)
+                pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
+                graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = grad
                 if state['step'] >= self.start_preconditioning_step:
                     shampoo_grad = pre_conditioner.preconditioned_grad(grad)
 
                 # grafting
                 graft_norm = torch.norm(graft_grad)
                 shampoo_norm = torch.norm(shampoo_grad)
-                shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
+                if self.graft_type != LayerWiseGrafting.NONE:
+                    shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
                 # apply weight decay (adam style)
                 if group['weight_decay'] > 0.0:
-                    shampoo_grad.add_(p, alpha=group['weight_decay'])
-                    graft_grad.add_(p, alpha=group['weight_decay'])
+                    if not self.decoupled_weight_decay:
+                        shampoo_grad.add_(p, alpha=group['weight_decay'])
+                        graft_grad.add_(p, alpha=group['weight_decay'])
+                    else:
+                        shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
 
                 # Momentum and Nesterov momentum, if needed
-                state['momentum'].mul_(group['momentum']).add_(shampoo_grad)
-                graft_momentum = graft.update_momentum(grad, group['momentum'])
+                state['momentum'].mul_(beta1).add_(shampoo_grad)
+                graft_momentum = graft.update_momentum(grad, beta1)
 
                 if state['step'] >= self.start_preconditioning_step:
                     momentum_update = state['momentum']
@@ -186,7 +221,10 @@
                     wd_update = graft_grad
 
                 if self.nesterov:
-                    momentum_update.mul_(group['momentum']).add_(wd_update)
+                    w: float = (1.0 - beta1) if self.moving_average_for_momentum else 1.0
+                    wd_update.mul_(w)
+
+                    momentum_update.mul_(beta1).add_(wd_update)
 
                 p.add_(momentum_update, alpha=-group['lr'])
 
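For the weight-decay branches added to step() above: with decoupled_weight_decay=False (the default) the term weight_decay * p is folded into both the Shampoo and the graft update, while decoupled_weight_decay=True instead scales both updates by 1 - lr * weight_decay. A toy numeric sketch of the two branches (illustrative values only, not library code):

import torch

lr, weight_decay = 1e-2, 1e-1
p = torch.tensor([1.0, -2.0])
shampoo_grad = torch.tensor([0.5, 0.5])

# coupled (decoupled_weight_decay=False): fold weight_decay * p into the update
coupled = shampoo_grad + weight_decay * p             # tensor([0.6000, 0.3000])

# decoupled (decoupled_weight_decay=True): shrink the update by (1 - lr * weight_decay)
decoupled = shampoo_grad * (1.0 - lr * weight_decay)  # tensor([0.4995, 0.4995])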