
Commit

fix gradient accumulation logic
moskomule committed Dec 20, 2021
1 parent 70658e3 commit 4434d63
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions homura/trainers.py
@@ -553,15 +553,17 @@ def iteration(self,

         if self.is_train:
             loss = 0
-            context = self.model.join if is_distributed() else contextlib.nullcontext
-            with context():
-                for input, target in zip(input.chunk(self.grad_accum_steps), target.chunk(self.grad_accum_steps)):
-                    with torch.cuda.amp.autocast(self._use_amp):
-                        output = self.model(input)
-                        _loss = self.loss_f(output, target) / self.grad_accum_steps
-                        loss += _loss.detach()
-                        # this code supports both AMP and non AMP
-                        self.scaler.scale(_loss).backward()
+            for i, (input, target) in enumerate(zip(input.chunk(self.grad_accum_steps),
+                                                    target.chunk(self.grad_accum_steps))):
+                context = self.model.no_sync if is_distributed() and i < self.grad_accum_steps - 1 \
+                    else contextlib.nullcontext
+                with torch.cuda.amp.autocast(self._use_amp), context():
+                    output = self.model(input)
+                    _loss = self.loss_f(output, target) / self.grad_accum_steps
+                    loss += _loss.detach()
+                    # this code supports both AMP and non AMP
+                    self.scaler.scale(_loss).backward()
             self.scaler.step(self.optimizer)
             self.scaler.update()
             self.optimizer.zero_grad()
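The change replaces the model.join context wrapped around the whole accumulation loop with a per-micro-batch no_sync context, so that under DDP the gradient all-reduce runs only on the final backward of each accumulation window. Below is a minimal, self-contained sketch of that pattern, not homura's actual API: the names accumulation_step, loss_fn, and grad_accum_steps are illustrative assumptions.

# A minimal sketch (assumed names, not homura's API) of the pattern the new
# code follows: synchronize DDP gradients only on the last micro-batch of
# each accumulation window, and rely on GradScaler so the same loop works
# with and without AMP.
import contextlib

import torch
import torch.distributed as dist


def is_distributed() -> bool:
    return dist.is_available() and dist.is_initialized()


def accumulation_step(model, optimizer, scaler, loss_fn,
                      input, target, grad_accum_steps, use_amp=True):
    # One optimizer step built from `grad_accum_steps` micro-batches.
    loss = 0.0
    for i, (inp, tgt) in enumerate(zip(input.chunk(grad_accum_steps),
                                       target.chunk(grad_accum_steps))):
        # Suppress DDP's gradient all-reduce on every micro-batch except the
        # last one, so gradients are synchronized exactly once per step.
        sync_now = (i == grad_accum_steps - 1) or not is_distributed()
        context = contextlib.nullcontext if sync_now else model.no_sync
        with context():
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = model(inp)
                micro_loss = loss_fn(output, tgt) / grad_accum_steps
            loss += micro_loss.detach()
            # GradScaler is a no-op when AMP is disabled, so this line covers
            # both the AMP and the full-precision case.
            scaler.scale(micro_loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss

In this sketch, model is assumed to be wrapped in torch.nn.parallel.DistributedDataParallel whenever is_distributed() is true, and scaler to be a torch.cuda.amp.GradScaler(enabled=use_amp); with AMP disabled the scaler reduces to a plain backward() followed by optimizer.step().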
