Commit

fixed cuda bug in optimizer
plkmo committed Sep 16, 2019
1 parent 78540e7 commit d2a7bcc
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ner.py
@@ -23,7 +23,7 @@
parser.add_argument("--tokens_length", type=int, default=200, help="Max tokens length for BERT")
parser.add_argument("--gradient_acc_steps", type=int, default=1, help="No. of steps of gradient accumulation")
parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
parser.add_argument("--num_epochs", type=int, default=25, help="No of epochs")
parser.add_argument("--num_epochs", type=int, default=125, help="No of epochs")
parser.add_argument("--lr", type=float, default=0.0007, help="learning rate")
parser.add_argument("--model_no", type=int, default=0, help="Model ID: (0: BERT, 1: XLNet)")

8 changes: 5 additions & 3 deletions ner/train_funcs.py
@@ -28,7 +28,10 @@ def load_model_and_optimizer(args, cuda=False):
     for p in net.parameters():
         if p.dim() > 1:
             nn.init.xavier_uniform_(p)
-
+
+    if cuda:
+        net.cuda()
+
     criterion = nn.CrossEntropyLoss(ignore_index=0) # ignore padding tokens
     optimizer = optim.Adam([{"params":net.bert.parameters(),"lr": args.lr/5},\
                             {"params":net.classifier.parameters(), "lr": args.lr}])
@@ -37,8 +40,7 @@ def load_model_and_optimizer(args, cuda=False):

     start_epoch, acc = load_state(net, optimizer, scheduler, args, load_best=False)
 
-    if cuda:
-        net.cuda()
+
 
 
     return net, criterion, optimizer, scheduler, start_epoch, acc

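Note on the fix (an inference from the diff, not stated beyond the commit title): load_model_and_optimizer previously built the Adam optimizer from the model's parameters while they still lived on the CPU, and only called net.cuda() after load_state. PyTorch's optimizer documentation advises moving a model to the GPU before constructing its optimizer, since the parameters after .cuda() can be different objects from the ones the optimizer captured. The sketch below illustrates the corrected ordering; it uses a hypothetical stand-in nn.Linear model and a single parameter group instead of the repo's BERT tagger with its two learning rates, so treat it as illustrative only.

# Minimal sketch of the ordering this commit adopts, using a hypothetical
# stand-in model rather than the repo's BERT-based NER network.
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(768, 9)  # stand-in for the classifier head (hypothetical sizes)

# Move the model to the GPU *before* creating the optimizer, so the optimizer
# holds references to the CUDA parameters (mirrors the change in this commit).
if torch.cuda.is_available():
    net.cuda()

optimizer = optim.Adam(net.parameters(), lr=0.0007)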
