
Commit

fix scaler
moskomule committed Dec 6, 2021
1 parent cd975f9 commit 3156c44
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions homura/trainers.py
@@ -512,8 +512,8 @@ def __init__(self,
             self.logger.info("model converted to DataParallel")
 
         self._use_amp = use_amp
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self._use_amp)
         if self._use_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
             self.logger.info("AMP is activated")
         self._use_channel_last = use_channel_last
         if self._use_channel_last:
@@ -539,14 +539,11 @@ def iteration(self,
         loss = self.loss_f(output, target)
 
         if self.is_train:
+            # this code supports both AMP and non AMP
+            self.scaler.scale(loss).backward()
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
             self.optimizer.zero_grad()
-            if self._use_amp:
-                self.scaler.scale(loss).backward()
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-            else:
-                loss.backward()
-                self.optimizer.step()
             if self.update_scheduler_iter:
                 self.scheduler.step()
             if self._is_debug and torch.isnan(loss):
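For context on why dropping the branch in iteration() is safe: torch.cuda.amp.GradScaler accepts an enabled flag, and when it is False the scaler is a pass-through (scale() returns the loss unchanged, step() simply calls optimizer.step(), and update() does nothing), so the single scale/step/update path covers both the AMP and non-AMP cases. Below is a minimal standalone sketch of that pattern, not taken from homura; the names (model, optimizer, use_amp) are hypothetical and a CUDA device is assumed.

    import torch

    # Hypothetical setup, requires a CUDA-capable GPU.
    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    use_amp = False  # mirrors self._use_amp in the trainer
    # With enabled=False the scaler degenerates to a no-op wrapper.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    x = torch.randn(8, 4, device="cuda")
    target = torch.randn(8, 2, device="cuda")

    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), target)

    scaler.scale(loss).backward()  # unscaled backward when disabled
    scaler.step(optimizer)         # plain optimizer.step() when disabled
    scaler.update()                # no-op when disabled
    optimizer.zero_grad()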
