Merge pull request SeanNaren#20 from dpressel/master
Using the PyTorch builtin avoids params with no grad
Sean Naren committed Mar 29, 2017
2 parents 159743c + c8e8268 commit d9c7bcf
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions train.py
@@ -150,16 +150,7 @@ def main():
         optimizer.zero_grad()
         loss.backward()

-        # rescale gradients if necessary
-        total_norm = torch.FloatTensor([0])
-        for param in model.parameters():
-            param = param.norm().pow(2).data.cpu()
-            total_norm.add_(param)
-        total_norm = total_norm.sqrt()
-        if total_norm[0] > args.max_norm:
-            for param in model.parameters():
-                param.grad.mul_(args.max_norm / total_norm[0])
-
+        torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
         # SGD step
         optimizer.step()

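For context, a minimal sketch of the training step after this change. It is not the repository's actual code: model, optimizer, max_norm, and the dummy data below are stand-ins for the corresponding objects in train.py, and it uses clip_grad_norm_, the name the function carries in newer PyTorch releases (the clip_grad_norm spelling in the diff is the 2017-era API).

import torch
import torch.nn as nn

# Stand-ins (assumed, for illustration only) for the model, optimizer and
# args.max_norm that train.py builds elsewhere.
model = nn.Linear(20, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
max_norm = 400

inputs = torch.randn(8, 20)
targets = torch.randn(8, 5)
loss = nn.functional.mse_loss(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
# One call replaces the removed hand-rolled rescaling: it computes the global
# 2-norm over all gradients, scales them in place if that norm exceeds
# max_norm, and skips parameters whose grad is None (the "no-grad params"
# the commit message refers to).
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()

Note that the builtin clips based on gradient norms, whereas the removed loop, as shown in the diff, accumulated the norms of the parameters themselves.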
