fix gradient accumulation logic
moskomule committed Dec 20, 2021
1 parent 31cbee0 commit d5c914c
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions homura/trainers.py
```diff
@@ -552,12 +552,16 @@ def iteration(self,
             input, target = data
 
         if self.is_train:
-            for input, target in zip(input.chunk(self.grad_accum_steps), target.chunk(self.grad_accum_steps)):
-                with torch.cuda.amp.autocast(self._use_amp):
-                    output = self.model(input)
-                    loss = self.loss_f(output, target)
-                # this code supports both AMP and non AMP
-                self.scaler.scale(loss).backward()
+            loss = 0
+            context = self.model.no_sync if is_distributed() else contextlib.nullcontext
+            with context():
+                for input, target in zip(input.chunk(self.grad_accum_steps), target.chunk(self.grad_accum_steps)):
+                    with torch.cuda.amp.autocast(self._use_amp):
+                        output = self.model(input)
+                        _loss = self.loss_f(output, target) / self.grad_accum_steps
+                    loss += _loss.detach()
+                    # this code supports both AMP and non AMP
+                    self.scaler.scale(_loss).backward()
             self.scaler.step(self.optimizer)
             self.scaler.update()
             self.optimizer.zero_grad()
```
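The pattern this commit introduces — split the batch into `grad_accum_steps` micro-batches, scale each micro-batch loss by `1 / grad_accum_steps`, call `backward()` per micro-batch, and step the scaler once — can be sketched outside the trainer roughly as below. This is a minimal illustration, not homura's API: `model`, `loss_f`, the optimizer, and the data are stand-ins, AMP is left disabled so the snippet runs on CPU, and an `isinstance` check on `DistributedDataParallel` stands in for the diff's `is_distributed()` helper.

```python
import contextlib

import torch
from torch import nn

# stand-ins for the trainer's attributes (hypothetical, for illustration only)
model = nn.Linear(8, 2)
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_steps = 4
use_amp = False  # set True on a CUDA device to exercise the AMP path
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

input = torch.randn(16, 8)
target = torch.randint(0, 2, (16,))

loss = 0
# under DDP, no_sync() suppresses gradient all-reduce inside the context,
# mirroring the diff's is_distributed() branch
context = (model.no_sync
           if isinstance(model, nn.parallel.DistributedDataParallel)
           else contextlib.nullcontext)
with context():
    for x, y in zip(input.chunk(grad_accum_steps), target.chunk(grad_accum_steps)):
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(x)
            # divide so the accumulated gradient matches a single full-batch step
            _loss = loss_f(output, y) / grad_accum_steps
        loss += _loss.detach()          # accumulate the value only, for logging
        scaler.scale(_loss).backward()  # works with AMP enabled or disabled
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
print(loss)  # mean loss over the full batch
```

The `detach()` on each micro-batch loss keeps the running total out of the autograd graph, so only `_loss` contributes gradients while `loss` remains a plain value suitable for reporting.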
