Commit 1f0982e

Bugfix: parameters of both the model as well as the criterion (Adaptive Softmax) must be optimized

Smerity committed Apr 3, 2018
1 parent 1e24cc5, commit 1f0982e
Showing 1 changed file with 5 additions and 5 deletions.

main.py: 5 additions & 5 deletions
@@ -141,12 +141,11 @@ def model_load(fn)
     print('Using', splits)
     criterion = SplitCrossEntropyLoss(args.emsize, splits=splits, verbose=False)
 ###
-params = list(model.parameters()) + list(criterion.parameters())
 if args.cuda:
     model = model.cuda()
     criterion = criterion.cuda()
-    params = list(model.parameters()) + list(criterion.parameters())
 ###
+params = list(model.parameters()) + list(criterion.parameters())
 total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in params if x.size())
 print('Args:', args)
 print('Model total parameters:', total_params)
@@ -206,7 +205,7 @@ def train():
         loss.backward()

         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
-        if args.clip: torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
+        if args.clip: torch.nn.utils.clip_grad_norm(params, args.clip)
         optimizer.step()

         total_loss += raw_loss.data
@@ -232,10 +231,11 @@ def train():
 # At any point you can hit Ctrl + C to break out of training early.
 try:
     optimizer = None
+    # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax)
     if args.optimizer == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
+        optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
     if args.optimizer == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
+        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay)
     for epoch in range(1, args.epochs+1):
         epoch_start_time = time.time()
         train()
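For readers following the change, here is a minimal sketch of the pattern the commit enforces: SplitCrossEntropyLoss (the adaptive softmax criterion) is itself an nn.Module with its own trainable weights, so the optimizer and gradient clipping must be given the combined parameter list rather than model.parameters() alone. The stand-in model, vocabulary size, and hyperparameter values below are placeholders for illustration, not values taken from the repository.

    import torch
    import torch.nn as nn

    # Placeholder standing in for the AWD-LSTM model; any nn.Module with
    # trainable weights illustrates the point.
    model = nn.LSTM(input_size=400, hidden_size=400, num_layers=3)

    # Placeholder standing in for SplitCrossEntropyLoss: like the real
    # criterion, it carries trainable parameters of its own.
    criterion = nn.Linear(400, 10000)

    # The fix: gather parameters from BOTH modules and hand the combined
    # list to the optimizer (and later to gradient clipping).
    params = list(model.parameters()) + list(criterion.parameters())
    optimizer = torch.optim.SGD(params, lr=0.1, weight_decay=1.2e-6)  # placeholder hyperparameters

    # Inside the training loop, clipping should cover the same combined list:
    #     if args.clip: torch.nn.utils.clip_grad_norm(params, args.clip)
    # (clip_grad_norm is the older spelling used in this repository; newer
    #  PyTorch versions rename it to clip_grad_norm_.)

With this list in hand, the SGD/Adam construction and the clip_grad_norm call in the diff above all reference params, so the adaptive softmax weights receive gradient updates along with the model's weights.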
