update: TRAC optimizer
kozistr committed Aug 4, 2024
1 parent b5df165 commit 6742620
Showing 2 changed files with 39 additions and 40 deletions.
61 changes: 23 additions & 38 deletions pytorch_optimizer/optimizer/trac.py
@@ -92,22 +92,21 @@ class TRAC(BaseOptimizer):
Here's an example::
model = YourModel()
base_optimizer = AdamW
optimizer = TRAC(model.parameters(), base_optimizer)
optimizer = TRAC(AdamW(model.parameters()))
for input, output in data:
loss = loss_function(output, model(input))
optimizer.zero_grad()
loss = loss_fn(model(input), output)
loss.backward()
optimizer.step()
optimizer.zero_grad()
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param base_optimizer: Optimizer. base optimizer.
:param optimizer: Optimizer. base optimizer.
:param betas: List[float]. list of beta values.
:param num_coefs: int. the number of polynomial coefficients to use in the approximation.
:param s_prev: float. initial scale value.
:param eps: float. term added to the denominator to improve numerical stability.
:param kwargs: Dict. parameters for optimizer.
"""

def __init__(
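
For reference (not part of this commit), a minimal end-to-end sketch of the new calling convention; it assumes TRAC is importable from the package top level and that any torch.optim-style optimizer can be wrapped:

import torch
from torch import nn
from torch.optim import AdamW
from pytorch_optimizer import TRAC

model = nn.Linear(4, 1)
optimizer = TRAC(AdamW(model.parameters(), lr=1e-3))

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(3):
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()       # wrapped optimizer step, followed by the TRAC rescaling
    optimizer.zero_grad()  # delegated to the wrapped optimizer
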
@@ -117,9 +116,8 @@ def __init__(
num_coefs: int = 128,
s_prev: float = 1e-8,
eps: float = 1e-8,
**kwargs,
):
self.validate_non_negative(num_coefs, 'num_coefs')
self.validate_positive(num_coefs, 'num_coefs')
self.validate_non_negative(s_prev, 's_prev')
self.validate_non_negative(eps, 'eps')

@@ -133,15 +131,7 @@ def __init__(

self.optimizer = optimizer
self.state: STATE = defaultdict(dict)

self.defaults: DEFAULTS = {
'trac_betas': betas,
'trac_num_coefs': num_coefs,
'trac_s_prev': s_prev,
'trac_eps': eps,
**optimizer.defaults,
**kwargs,
}
self.defaults: DEFAULTS = optimizer.defaults

def __str__(self) -> str:
return 'TRAC'
@@ -152,28 +142,26 @@ def param_groups(self):

@torch.no_grad()
def reset(self):
device = next(iter(self.param_groups['params'][0])).device
device = self.param_groups[0]['params'][0].device

self.state = {
'betas': torch.tensor(self.betas, device=device),
's_prev': torch.tensor(self.s_prev, device=device),
'eps': self.eps,
's': torch.zeros(len(self.betas), device=device),
'theta_ref': {},
'variance': torch.zeros(len(self.betas), device=device),
'sigma': torch.full((len(self.betas),), 1e-8, device=device),
'step': 0,
}

for group in self.param_groups:
for p in group['params']:
self.state[p] = {'ref': p.clone()}
self.state[p] = p.clone()

@torch.no_grad()
def zero_grad(self) -> None:
self.optimizer.zero_grad(set_to_none=True)

def erfi(self, x: torch.Tensor) -> torch.Tensor:
@torch.no_grad()
def erf_imag(self, x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x):
x = x.to(torch.float32)
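
For context (not part of this commit): erfi is the imaginary error function, erfi(x) = -i * erf(i * x), which the optimizer approximates with num_coefs polynomial coefficients. A standalone Maclaurin-series sketch, intended only as an illustration for modest |x|:

import math
import torch

def erfi_series(x: torch.Tensor, num_terms: int = 32) -> torch.Tensor:
    # erfi(x) = 2 / sqrt(pi) * sum_{n >= 0} x^(2n+1) / (n! * (2n+1))
    term = x.clone()      # holds x^(2n+1) / n!, starting at n = 0
    total = term.clone()  # first summand, since 2n+1 = 1
    for n in range(1, num_terms):
        term = term * x * x / n            # advance to x^(2n+1) / n!
        total = total + term / (2 * n + 1)
    return total * (2.0 / math.sqrt(math.pi))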

@@ -198,18 +186,18 @@ def trac_step(self, updates: Dict, grads: Dict) -> None:

deltas = {}

device = updates[next(iter(updates.keys()))].device
device = self.param_groups[0]['params'][0].device

h = torch.zeros((1,), device=device)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue

theta_ref = self.state[p]['ref']
theta_ref = self.state[p]
update = updates[p]

deltas[p] = (update - theta_ref) / (torch.sum(self.state['s']) + self.state['eps'])
deltas[p] = (update - theta_ref) / (torch.sum(self.state['s']) + self.eps)
update.neg_().add_(p)

grad, delta = grads[p], deltas[p]
@@ -220,48 +208,45 @@ def trac_step(self, updates: Dict, grads: Dict) -> None:
delta.add_(update)

s = self.state['s']
s_prev = self.state['s_prev']
betas = self.state['betas']
eps = self.state['eps']
variance = self.state['variance']
sigma = self.state['sigma']

variance.mul_(betas.pow(2)).add_(h.pow(2))
sigma.mul_(betas).sub_(h)

f_term = s_prev / self.erfi(1.0 / torch.sqrt(torch.tensor(2.0)))
s_term = self.erfi(sigma / (torch.sqrt(torch.tensor(2.0)) * variance.sqrt() + eps))
f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))
s_term = self.erf_imag(sigma / (torch.sqrt(torch.tensor(2.0)) * variance.sqrt() + self.eps))
s.copy_(f_term * s_term)

for group in self.param_groups:
for p in group['params']:
if grads[p] is None:
continue

p.copy_(self.state[p]['ref'] + deltas[p] * max(torch.sum(s), 0.0))
p.copy_(self.state[p] + deltas[p] * max(torch.sum(s), 0.0))

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
updates, grads = self.backup_params_and_grads()
with torch.enable_grad():
loss = self.optimizer.step(closure)

loss = self.optimizer.step(closure)
updates, grads = self.backup_params_and_grads()

if len(self.state) == 0:
device = updates[next(iter(updates.keys()))].device

self.state = {
'betas': torch.tensor(self.betas, device=device),
's_prev': torch.tensor(self.s_prev, device=device),
'eps': self.eps,
's': torch.zeros(len(self.betas), device=device),
'theta_ref': {},
'variance': torch.zeros(len(self.betas), device=device),
'sigma': torch.full((len(self.betas),), 1e-8, device=device),
'step': 0,
}

for group in self.param_groups:
for p in group['params']:
self.state[p] = {'ref': updates[p].clone()}
self.state[p] = updates[p].clone()

self.trac_step(updates, grads)
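
A rough sketch (not from the repository) of the scale update that trac_step performs over the beta grid, mirroring the variance/sigma/erfi lines in the diff above; erfi_series is the illustrative helper from the earlier sketch, and h stands for the inner product of the current gradients with the accumulated displacement, summed over all parameters:

import torch

def trac_scale(sigma, variance, betas, h, s_prev: float = 1e-8, eps: float = 1e-8):
    # decay the accumulators over the beta grid, then fold in the new correlation term h
    variance = variance * betas.pow(2) + h.pow(2)
    sigma = sigma * betas - h
    # map sigma through erfi to obtain the per-beta trust scale s
    sqrt2 = torch.sqrt(torch.tensor(2.0))
    f_term = s_prev / erfi_series(1.0 / sqrt2)
    s_term = erfi_series(sigma / (sqrt2 * variance.sqrt() + eps))
    return f_term * s_term, variance, sigma

The parameters are then reset to theta_ref + delta * max(sum(s), 0.0), so a non-positive aggregate scale snaps them back to the reference point.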

18 changes: 16 additions & 2 deletions tests/test_optimizers.py
@@ -676,7 +676,7 @@ def test_trac_optimizer(environment):
optimizer = TRAC(load_optimizer('adamw')(model.parameters(), lr=1e0))

init_loss, loss = np.inf, np.inf
for _ in range(10):
for _ in range(5):
optimizer.zero_grad()

y_pred = model(x_data)
@@ -689,4 +689,18 @@ def test_trac_optimizer(environment):

optimizer.step()

assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss)
assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss)


def test_trac_optimizer_erf_imag():
model = Example()

optimizer = TRAC(load_optimizer('adamw')(model.parameters()))

optimizer.reset()
optimizer.zero_grad()

complex_tensor = torch.complex(torch.tensor(0.0), torch.tensor(1.0))
optimizer.erf_imag(complex_tensor)

assert str(optimizer).lower() == 'trac'
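
If one wanted a quick numerical check for the illustrative erfi_series helper sketched earlier (again, not part of this commit), erfi(1) ≈ 1.6504257588 is a convenient reference value:

import torch

# assumes the erfi_series helper from the earlier sketch on this page
value = erfi_series(torch.tensor(1.0, dtype=torch.float64))
assert abs(value.item() - 1.6504257588) < 1e-8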
