Skip to content

Commit

Permalink
Fix "RuntimeError: Expected object of type torch.FloatTensor but foun…
Browse files Browse the repository at this point in the history
…d type torch.cuda.FloatTensor for argument SeanNaren#4 'other'" for File "train.py", line 270, in <module> optimizer.step()
  • Loading branch information
shuieryin committed Jul 23, 2018
1 parent 60ef8ba commit bcad275
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def update(self, val, n=1):
optimizer = torch.optim.SGD(parameters, lr=args.lr,
momentum=args.momentum, nesterov=True)
if not args.finetune: # Don't want to restart training
if args.cuda:
model.cuda()
optimizer.load_state_dict(package['optim_dict'])
start_epoch = int(package.get('epoch', 1)) - 1 # Index start at 0 for training
start_iter = package.get('iteration', None)
Expand Down

0 comments on commit bcad275

Please sign in to comment.