diff --git a/README.md b/README.md index fb1f1263..d7332aac 100644 --- a/README.md +++ b/README.md @@ -62,3 +62,9 @@ python train.py --train_manifest data/train_manifest.csv --val_manifest data/val ``` Use `python train.py --help` for more parameters and options. + +There is also [Visdom](https://github.com/facebookresearch/visdom) support to visualise training. Once a Visdom server has been started, enable it with: + +``` +python train.py --visdom true +``` \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 888983eb..e9f686d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ python-levenshtein librosa -torch \ No newline at end of file +torch +visdom \ No newline at end of file diff --git a/train.py b/train.py index af71ef59..30805fb4 100644 --- a/train.py +++ b/train.py @@ -34,6 +34,7 @@ parser.add_argument('--learning_anneal', default=1.1, type=float, help='Annealing applied to learning rate every epoch') parser.add_argument('--silent', default=False, type=bool, help='Turn off progress tracking per iteration') parser.add_argument('--epoch_save', default=False, type=bool, help='Save model every epoch') +parser.add_argument('--visdom', default=False, type=bool, help='Turn on visdom graphing') parser.add_argument('--save_folder', default='models/', help='Location to save epoch models') parser.add_argument('--final_model_path', default='models/deepspeech_final.pth.tar', help='Location to save final model') @@ -72,6 +73,20 @@ def checkpoint(model, args, nout, epoch=None): def main(): args = parser.parse_args() save_folder = args.save_folder + + if args.visdom: + from visdom import Visdom + viz = Visdom() + + opts = [dict(title='Loss', ylabel='Loss', xlabel='Epoch'), + dict(title='WER', ylabel='WER', xlabel='Epoch'), + dict(title='CER', ylabel='CER', xlabel='Epoch')] + + viz_windows = [None, None, None] + loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor( + args.epochs) + 
epochs = torch.range(1, args.epochs) + try: os.makedirs(save_folder) except OSError as e: @@ -168,7 +183,7 @@ def main(): avg_loss /= len(train_loader) print('Training Summary Epoch: [{0}]\t' 'Average Loss {loss:.3f}\t'.format( - (epoch + 1), loss=avg_loss)) + epoch + 1, loss=avg_loss)) total_cer, total_wer = 0, 0 for i, (data) in enumerate(test_loader): # test @@ -202,11 +217,35 @@ def main(): wer = total_wer / len(test_loader.dataset) cer = total_cer / len(test_loader.dataset) + wer *= 100 + cer *= 100 print('Validation Summary Epoch: [{0}]\t' 'Average WER {wer:.0f}\t' 'Average CER {cer:.0f}\t'.format( - (epoch + 1), wer=wer * 100, cer=cer * 100)) + epoch + 1, wer=wer, cer=cer)) + + if args.visdom: + loss_results[epoch] = avg_loss + wer_results[epoch] = wer + cer_results[epoch] = cer + epoch += 1 + x_axis = epochs[0:epoch] + y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]] + for x in range(len(viz_windows)): + if viz_windows[x] is None: + viz_windows[x] = viz.line( + X=x_axis, + Y=y_axis[x], + opts=opts[x], + ) + else: + viz.line( + X=x_axis, + Y=y_axis[x], + win=viz_windows[x], + update='replace', + ) if args.epoch_save: file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch) torch.save(checkpoint(model, args, len(labels), epoch), file_path)