[Feature] Improve overall performance of the optimizers #51

Merged Feb 19, 2022 · 34 commits
Changes from 1 commit
Commits (all by kozistr, Feb 19, 2022):

f6fa0db  update: version to 0.4.0
73f8f5b  feature: improve performance
f7479bc  build(package): upgrade dev packages
b968b78  feature: improve perf
1c1d97c  fix: typo
431b8df  refactor: remove unused import
e293ac5  feature: improve perf
e0ce9d9  feature: improve perf
92afc46  feature: improve perf
dd145f8  feature: improve perf
daf4c6d  feature: improve perf
c37ddaa  feature: improve perf
b17a8db  feature: improve perf
49bef33  feature: improve perf
0cef0a6  feature: improve perf
6f2e92d  feature: improve perf
3e80d6f  feature: improve perf
6b5ea56  feature: improve perf
f706a25  feature: improve perf
eedc0c1  feature: improve perf
c9f237e  feature: improve perf
5d09f37  feature: improve perf
5ce1aad  fix: update
5adafbd  refactor: disable pylint warning W0212
2558782  update: step
01a984b  update: test_wd_ratio
4d21dcf  update: OPTIMIZERS
1445e85  update: test_closure
9f55ec6  update: ranger21
9d4a1a4  update: ranger21
534c200  update: OPTIMIZERS
a2067a5  refactor: grad
8657304  refactor: except Ranger21
658d3a9  refactor: test_optimizers
feature: improve perf
kozistr committed Feb 19, 2022
commit c37ddaa62733e39a6eb575dc4c99c6ef2fcd1506
16 changes: 8 additions & 8 deletions — pytorch_optimizer/lars.py

```diff
@@ -69,37 +69,37 @@ def check_valid_parameters(self):
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for g in self.param_groups:
             for p in g['params']:
                 if p.grad is None:
                     continue
 
-                if p.grad.data.is_sparse:
+                grad = p.grad
+                if grad.is_sparse:
                     raise RuntimeError('LARS does not support sparse gradients')
 
-                dp = p.grad
-
                 if p.ndim > 1:  # if not normalization gamma/beta or bias
-                    dp = dp.add(p, alpha=g['weight_decay'])
+                    grad = grad.add(p, alpha=g['weight_decay'])
                     param_norm = torch.norm(p)
-                    update_norm = torch.norm(dp)
+                    update_norm = torch.norm(grad)
                     one = torch.ones_like(param_norm)
 
                     q = torch.where(
                         param_norm > 0.0,
                         torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one),
                         one,
                     )
-                    dp = dp.mul(q)
+                    grad = grad.mul(q)
 
                 param_state = self.state[p]
                 if 'mu' not in param_state:
                     param_state['mu'] = torch.zeros_like(p)
 
                 mu = param_state['mu']
-                mu.mul_(g['momentum']).add_(dp)
+                mu.mul_(g['momentum']).add_(grad)
 
                 p.add_(mu, alpha=-g['lr'])
```
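The hunk makes two changes. First, the closure is now re-evaluated inside `with torch.enable_grad():`, the same pattern `torch.optim`'s built-in optimizers use, so the forward pass still builds an autograd graph for `loss.backward()` even when `step()` is called while gradient tracking is disabled. Second, the repeated `p.grad` lookups (including the legacy `p.grad.data` access) are collapsed into a single `grad` local that is reused through the weight-decay and trust-ratio computation. Below is a minimal usage sketch of the patched `step(closure)` path; the import path and constructor arguments are assumptions that mirror the param-group keys `step()` reads (`lr`, `weight_decay`, `momentum`, `trust_coefficient`), and the hyperparameter values are illustrative only.

```python
import torch
from torch import nn

# Assumed import: LARS lives in pytorch_optimizer/lars.py in this PR, so a
# top-level export is presumed here; adjust to the package's actual API.
from pytorch_optimizer import LARS

model = nn.Linear(10, 1)
optimizer = LARS(
    model.parameters(),
    lr=1e-2,                 # illustrative values, not project defaults
    weight_decay=1e-4,
    momentum=0.9,
    trust_coefficient=1e-3,
)

x, y = torch.randn(32, 10), torch.randn(32, 1)

def closure() -> torch.Tensor:
    # step() wraps this call in torch.enable_grad(), so the forward pass is
    # recorded on the autograd graph even if step() is invoked from a
    # context where gradient tracking is globally disabled.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # step() returns the closure's loss
```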