
[Feature] Implement more Shampoo features #99

Merged: 55 commits (Jan 31, 2023)
Changes from 1 commit
Commits (55)
bda6d74
update: set multiplier to 1.0 when grafting type is None
kozistr Jan 31, 2023
caba6d5
build(version): v2.3.1
kozistr Jan 31, 2023
8a5161e
update: default start_preconditioning_step to 5
kozistr Jan 31, 2023
1bade60
feature: moving_average_for_momentum
kozistr Jan 31, 2023
218e781
update: decoupled_weight_decay
kozistr Jan 31, 2023
106d2e8
update: Shampoo recipes
kozistr Jan 31, 2023
683be6d
feature: decoupled_learning_rate
kozistr Jan 31, 2023
1bb4463
update: Shampoo recipe
kozistr Jan 31, 2023
efcfc7c
update: default parameters
kozistr Jan 31, 2023
24ba91c
update: exclude Shampoo w/ decoupled_learning_rate
kozistr Jan 31, 2023
5847c49
update: max_error_ratio
kozistr Jan 31, 2023
3abbc24
update: Shampoo recipe
kozistr Jan 31, 2023
5ca8b4a
update: Shampoo recipe
kozistr Jan 31, 2023
6f88acc
update: Shampoo recipe
kozistr Jan 31, 2023
91dc676
feature: RMSPropGraft
kozistr Jan 31, 2023
7f253c2
style: fix Graft
kozistr Jan 31, 2023
2a27c7c
feature: supports RMSPropGraft
kozistr Jan 31, 2023
5734b7f
docs: Shampoo optimizer docstring
kozistr Jan 31, 2023
54f2d64
update: Shampoo recipes
kozistr Jan 31, 2023
4e5e3ff
fix: missing beta2
kozistr Jan 31, 2023
15d5ae7
update: test_pc_grad_optimizers
kozistr Jan 31, 2023
922e13d
update: test_shampoo_optimizer
kozistr Jan 31, 2023
7f8c3f8
update: Shampoo recipe
kozistr Jan 31, 2023
d29dda4
feature: SQRTNGraft
kozistr Jan 31, 2023
d54694e
feature: supports SQRTN grafting
kozistr Jan 31, 2023
ad96202
update: recipes
kozistr Jan 31, 2023
f841c0a
update: test_shampoo_optimizer_graft_types
kozistr Jan 31, 2023
8818c2a
update: recipes
kozistr Jan 31, 2023
3505b32
refactor: term, momentum to beta1
kozistr Jan 31, 2023
9fc43f1
update: reset
kozistr Jan 31, 2023
6abd00a
update: test_shampoo_optimizer_graft_types
kozistr Jan 31, 2023
f069f0b
refactor: betas
kozistr Jan 31, 2023
cdad7fe
update: BETA_OPTIMIZER_NAMES
kozistr Jan 31, 2023
5046889
update: test_momentum
kozistr Jan 31, 2023
d56640a
docs: Shampoo docstring
kozistr Jan 31, 2023
7741fc8
fix: pre-conditioner parameter, beta2
kozistr Jan 31, 2023
ea311b4
fix: reset
kozistr Jan 31, 2023
796b360
update: recipes
kozistr Jan 31, 2023
b6c167f
feature: PreConditionerType
kozistr Jan 31, 2023
854e324
docs: PreConditionerType docstring
kozistr Jan 31, 2023
f0beb9a
feature: supports PreConditionerType
kozistr Jan 31, 2023
1624b1b
refactor: precondition_block
kozistr Jan 31, 2023
b4652f5
refactor: preconditioned_grad
kozistr Jan 31, 2023
bff9177
update: should_precondition_dims
kozistr Jan 31, 2023
6e3f6ac
update: recipes
kozistr Jan 31, 2023
5693b1d
update: recipes
kozistr Jan 31, 2023
b9a9f7e
update: recipes
kozistr Jan 31, 2023
5271865
update: precondition_block
kozistr Jan 31, 2023
29ec7c0
update: BlockPartitioner
kozistr Jan 31, 2023
44609d1
update: merge_small_dims
kozistr Jan 31, 2023
8bffa33
update: test_merge_small_dims
kozistr Jan 31, 2023
4225940
update: test_shampoo_optimizer
kozistr Jan 31, 2023
abcf6f0
update: BlockPartitioner
kozistr Jan 31, 2023
a0af041
fix: add_statistics when PreConditionerType is 1 (INPUT)
kozistr Jan 31, 2023
5da4074
style: format
kozistr Jan 31, 2023
feature: PreConditionerType
kozistr committed Jan 31, 2023
commit b6c167f5a7b50fbdab632748173e60cf85e92fed
68 changes: 59 additions & 9 deletions pytorch_optimizer/optimizer/shampoo_utils.py
@@ -178,6 +178,18 @@ def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor:
return partitions[0]


class PreConditionerType(IntEnum):
r"""Type of PreConditioner.

By default (ALL), a pre-conditioner is computed for each dim.
INPUT is one-sided Shampoo, which pre-conditions only the input dims.
The last dim is assumed to be the output dim; every other dim is treated as an input dim.
"""

ALL = 0
INPUT = 1
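
A minimal standalone sketch (not part of this diff) of what the two modes mean for a rank-3 gradient: ALL pre-conditions every dim, while INPUT skips the trailing output dim. The helper and enum below are hypothetical and only mirror the should_precondition_dims logic added further down in this file.

from enum import IntEnum
from typing import List


class PreConditionerTypeSketch(IntEnum):
    ALL = 0
    INPUT = 1


def dims_to_precondition(rank: int, kind: PreConditionerTypeSketch) -> List[bool]:
    # ALL (or an effectively 1-D tensor): precondition every dim.
    # INPUT: precondition everything except the last (output) dim.
    if kind == PreConditionerTypeSketch.ALL or rank <= 1:
        return [True] * rank
    return [True] * (rank - 1) + [False]


print(dims_to_precondition(3, PreConditionerTypeSketch.ALL))    # [True, True, True]
print(dims_to_precondition(3, PreConditionerTypeSketch.INPUT))  # [True, True, False]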


class PreConditioner:
r"""Compute statistics/shape from gradients for preconditioning.

@@ -187,6 +199,7 @@ class PreConditioner:
:param block_size: int.
:param shape_interpretation: bool.
:param matrix_eps: float.
:param pre_conditioner_type: int. type of pre-conditioner.
"""

def __init__(
@@ -197,10 +210,12 @@ def __init__(
block_size: int,
shape_interpretation: bool,
matrix_eps: float,
pre_conditioner_type: int = PreConditionerType.ALL,
):
self.beta2 = beta2
self.inverse_exponent_override = inverse_exponent_override
self.matrix_eps = matrix_eps
self.pre_conditioner_type = pre_conditioner_type

self.original_shape: List[int] = var.shape
self.transformed_shape: List[int] = var.shape
@@ -236,18 +251,50 @@ def add_statistics(self, grad: torch.Tensor):
stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, [axes, axes])
self.statistics[j * rank + i].mul_(self.beta2).add_(stat, alpha=w2)

def exponent_for_pre_conditioner(self) -> int:
r"""Return exponent to use for inverse-pth root M^{-1/p}."""
def should_precondition_dims(self) -> List[bool]:
r"""A vector containing indicator indicating if the dim is preconditioned."""
rank: int = len(self.partitioner.split_sizes)
return (
self.inverse_exponent_override if self.inverse_exponent_override > 0 else 2 * len(self.transformed_shape)
[True] * rank
if self.pre_conditioner_type == PreConditionerType.ALL or rank <= 1
else [True] * (rank - 1) + [False]
)

def exponent_for_pre_conditioner(self) -> int:
r"""Return exponent to use for inverse-pth root M^{-1/p}."""
if self.inverse_exponent_override > 0:
return self.inverse_exponent_override

num_pre_conditioners: int = sum(self.should_precondition_dims())
return 2 * num_pre_conditioners
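
Quick sanity check (a sketch, not library code) of the exponent rule just above: with no override, the exponent is twice the number of preconditioned dims, which recovers the familiar L^{-1/4} G R^{-1/4} update for a 2-D weight under ALL, and an exponent of 2 (i.e. P^{-1/2} G) for the same weight under INPUT.

def inverse_root_exponent(num_preconditioned_dims: int, override: int = 0) -> int:
    # Mirrors exponent_for_pre_conditioner: a positive override wins,
    # otherwise use 2 * number of dims that actually get a pre-conditioner.
    return override if override > 0 else 2 * num_preconditioned_dims


assert inverse_root_exponent(2) == 4             # ALL on a 2-D weight -> L^{-1/4}, R^{-1/4}
assert inverse_root_exponent(1) == 2             # INPUT on a 2-D weight -> P^{-1/2}
assert inverse_root_exponent(2, override=6) == 6 # explicit override takes precedence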

def compute_pre_conditioners(self):
r"""Compute L^{-1/exp} for each stats matrix L."""
exp: int = self.exponent_for_pre_conditioner()
for i, stat in enumerate(self.statistics):
self.pre_conditioners[i] = compute_power(stat, exp, ridge_epsilon=self.matrix_eps)

@staticmethod
def precondition_block(
partitioned_grad: torch.Tensor,
should_preconditioned_dims: List[bool],
pre_conditioners_for_grad: List[torch.Tensor],
) -> torch.Tensor:
r"""Perform a preconditioning op on a single gradient block.

Loop invariant: the dimension to be preconditioned is always first.
All axes are kept in the same cyclic order they were originally in.
"""

for j, should_precondition in enumerate(should_preconditioned_dims):
rank: int = len(partitioned_grad.shape)
if not should_precondition:
roll = tuple(range(1, rank)) + (0,)
partitioned_grad = torch.permute(partitioned_grad, roll)
continue
partitioned_grad = torch.tensordot(partitioned_grad, pre_conditioners_for_grad[j], dims=[[0], [0]])
return partitioned_grad
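
The cyclic-order invariant above can be checked on a small rank-2 block (a self-contained sketch, with a random symmetric matrix standing in for the real pre-conditioner): contracting dim 0 with tensordot pushes the preconditioned axis to the back, the skipped output dim is rolled to the back explicitly, so the block comes out in its original (input, output) layout and, in INPUT mode, equals P @ G.

import torch

g = torch.randn(4, 3)   # gradient block, laid out as (input dim, output dim)
a = torch.randn(4, 4)
p = a @ a.t()           # symmetric stand-in for the input-side pre-conditioner

# INPUT mode on a rank-2 block -> should_precondition_dims == [True, False]
out = torch.tensordot(g, p, dims=[[0], [0]])  # precondition dim 0; result shape (3, 4)
out = torch.permute(out, (1, 0))              # roll the untouched output dim; back to (4, 3)

assert torch.allclose(out, p @ g, atol=1e-6)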

def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
r"""Precondition the gradient.

@@ -259,15 +306,18 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
reshaped_grad = torch.reshape(grad, self.transformed_shape)
partitioned_grads = self.partitioner.partition(reshaped_grad)

num_splits: int = self.partitioner.num_splits
should_precondition_dims: List[bool] = self.should_precondition_dims()
num_pre_conditioners: int = sum(should_precondition_dims)

pre_cond_partitioned_grads: List[torch.Tensor] = []
for i, partitioned_grad in enumerate(partitioned_grads):
pre_conditioners_for_grad = self.pre_conditioners[i * num_splits:(i + 1) * num_splits] # fmt: skip
rank: int = len(partitioned_grad.shape)
pre_conditioners_for_grad = self.pre_conditioners[
i * num_pre_conditioners:(i + 1) * num_pre_conditioners
] # fmt: skip

pre_cond_grad = partitioned_grad
for j in range(rank):
pre_cond_grad = torch.tensordot(pre_cond_grad, pre_conditioners_for_grad[j], [[0], [0]])
pre_cond_grad = self.precondition_block(
partitioned_grad, should_precondition_dims, pre_conditioners_for_grad
)

pre_cond_partitioned_grads.append(pre_cond_grad)

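End to end, the new option would be exercised roughly like this (a usage sketch: argument and method names are taken from this diff, but defaults and the full constructor signature may differ in the released package, and at this commit add_statistics still accumulates statistics for every dim; a later commit in this PR, a0af041, adjusts it for INPUT).

import torch

from pytorch_optimizer.optimizer.shampoo_utils import PreConditioner, PreConditionerType

param = torch.nn.Parameter(torch.randn(128, 64))  # (input, output) weight
grad = torch.randn_like(param)

pre_conditioner = PreConditioner(
    var=param,
    beta2=0.95,
    inverse_exponent_override=0,
    block_size=128,
    shape_interpretation=True,
    matrix_eps=1e-6,
    pre_conditioner_type=PreConditionerType.INPUT,  # one-sided Shampoo
)

pre_conditioner.add_statistics(grad)        # accumulate second-moment statistics
pre_conditioner.compute_pre_conditioners()  # L^{-1/exp} for each statistics matrix
shampoo_grad = pre_conditioner.preconditioned_grad(grad)  # one-sided preconditioned gradient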