Skip to content

Commit

Permalink
Change early stopping time
Browse files Browse the repository at this point in the history
  • Loading branch information
heiheiyoyo committed Jul 11, 2022
1 parent 359ac43 commit d8c918c
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,33 @@ def trace_func(*args, **kwargs):
tokenizer.save_pretrained(save_dir)
tic_train = time.time()

if args.early_stopping:
# Early Stopping
early_stopping(dev_loss_avg, model)
if early_stopping.early_stop:
if show_bar:
with logging_redirect_tqdm([logger.logger]):
logger.info("Early stopping")
else:
if args.early_stopping:
dev_loss_avg, precision, recall, f1 = evaluate(
model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion)

if show_bar:
train_postfix_info.update({
'F1': f'{f1:.3f}',
'dev loss': f'{dev_loss_avg:.5f}'
})
train_data_iterator.set_postfix(train_postfix_info)
with logging_redirect_tqdm([logger.logger]):
logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
% (precision, recall, f1, dev_loss_avg))
else:
logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
% (precision, recall, f1, dev_loss_avg))

# Early Stopping
early_stopping(dev_loss_avg, model)
if early_stopping.early_stop:
if show_bar:
with logging_redirect_tqdm([logger.logger]):
logger.info("Early stopping")
tokenizer.save_pretrained(early_stopping_save_dir)
sys.exit(0)
else:
logger.info("Early stopping")
tokenizer.save_pretrained(early_stopping_save_dir)
sys.exit(0)


if __name__ == "__main__":
Expand Down

0 comments on commit d8c918c

Please sign in to comment.