From 23a9e73b9fa2059c237b557d3007ad02240266fe Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Mon, 20 Mar 2017 10:58:29 +0000
Subject: [PATCH] Added visdom support

---
 README.md        |  6 ++++++
 data/an4.py      |  4 ++--
 data/utils.py    |  2 +-
 requirements.txt |  3 ++-
 train.py         | 43 +++++++++++++++++++++++++++++++++++++++++--
 5 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index fb1f1263..d7332aac 100644
--- a/README.md
+++ b/README.md
@@ -62,3 +62,9 @@ python train.py --train_manifest data/train_manifest.csv --val_manifest data/val
 ```
 
 Use `python train.py --help` for more parameters and options.
+
+There is also [Visdom](https://github.com/facebookresearch/visdom) support to visualise training. Once a server has been started, enable it with:
+
+```
+python train.py --visdom true
+```
\ No newline at end of file
diff --git a/data/an4.py b/data/an4.py
index 3795d5c4..8ff0ac9a 100644
--- a/data/an4.py
+++ b/data/an4.py
@@ -5,7 +5,7 @@
 
 import subprocess
 
-from data.utils import create_manifest
+from utils import create_manifest
 
 parser = argparse.ArgumentParser(description='Processes and downloads an4.')
 parser.add_argument('--an4_path', default='an4_dataset/', help='Path to save dataset')
@@ -53,7 +53,7 @@ def _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_
         new_path = new_wav_path + filename
         text_path = new_transcript_path + filename.replace('.wav', '.txt')
         with io.FileIO(text_path, "w") as file:
-            file.write(extracted_transcript)
+            file.write(extracted_transcript.encode('utf-8'))
         os.rename(current_path, new_path)
 
 
diff --git a/data/utils.py b/data/utils.py
index 8ea1bd0e..f6e3b693 100644
--- a/data/utils.py
+++ b/data/utils.py
@@ -32,7 +32,7 @@ def create_manifest(data_path, tag, ordered=True):
     for wav_path in file_paths:
         transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt')
         sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n'
-        file.write(sample)
+        file.write(sample.encode('utf-8'))
         counter += 1
         _update_progress(counter / float(size))
     print('\n')
diff --git a/requirements.txt b/requirements.txt
index 888983eb..e9f686d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 python-levenshtein
 librosa
-torch
\ No newline at end of file
+torch
+visdom
\ No newline at end of file
diff --git a/train.py b/train.py
index 81e0488a..2035bedf 100644
--- a/train.py
+++ b/train.py
@@ -34,6 +34,7 @@
 parser.add_argument('--learning_anneal', default=1.1, type=float, help='Annealing applied to learning rate every epoch')
 parser.add_argument('--silent', default=False, type=bool, help='Turn off progress tracking per iteration')
 parser.add_argument('--epoch_save', default=False, type=bool, help='Save model every epoch')
+parser.add_argument('--visdom', default=False, type=bool, help='Turn on visdom graphing')
 parser.add_argument('--save_folder', default='models/', help='Location to save epoch models')
 parser.add_argument('--final_model_path', default='models/deepspeech_final.pth.tar',
                     help='Location to save final model')
@@ -72,6 +73,20 @@ def checkpoint(model, args, nout, epoch=None):
 def main():
     args = parser.parse_args()
     save_folder = args.save_folder
+
+    if args.visdom:
+        from visdom import Visdom
+        viz = Visdom()
+
+        opts = [dict(title='Loss', ylabel='Loss', xlabel='Epoch'),
+                dict(title='WER', ylabel='WER', xlabel='Epoch'),
+                dict(title='CER', ylabel='CER', xlabel='Epoch')]
+
+        viz_windows = [None, None, None]
+        loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
+            args.epochs)
+        epochs = torch.range(1, args.epochs)
+
     try:
         os.makedirs(save_folder)
     except OSError as e:
@@ -177,7 +192,7 @@ def main():
         avg_loss /= len(train_loader)
         print('Training Summary Epoch: [{0}]\t'
               'Average Loss {loss:.3f}\t'.format(
-            (epoch + 1), loss=avg_loss))
+            epoch + 1, loss=avg_loss))
 
         total_cer, total_wer = 0, 0
         for i, (data) in enumerate(test_loader):  # test
@@ -211,11 +226,35 @@ def main():
         wer = total_wer / len(test_loader.dataset)
         cer = total_cer / len(test_loader.dataset)
+        wer *= 100
+        cer *= 100
 
         print('Validation Summary Epoch: [{0}]\t'
               'Average WER {wer:.0f}\t'
               'Average CER {cer:.0f}\t'.format(
-            (epoch + 1), wer=wer * 100, cer=cer * 100))
+            epoch + 1, wer=wer, cer=cer))
+
+        if args.visdom:
+            loss_results[epoch] = avg_loss
+            wer_results[epoch] = wer
+            cer_results[epoch] = cer
+            epoch += 1
+            x_axis = epochs[0:epoch]
+            y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]]
+            for x in range(len(viz_windows)):
+                if viz_windows[x] is None:
+                    viz_windows[x] = viz.line(
+                        X=x_axis,
+                        Y=y_axis[x],
+                        opts=opts[x],
+                    )
+                else:
+                    viz.line(
+                        X=x_axis,
+                        Y=y_axis[x],
+                        win=viz_windows[x],
+                        update='replace',
+                    )
 
         if args.epoch_save:
             file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch)
             torch.save(checkpoint(model, args, len(labels), epoch), file_path)
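
A note on the plotting logic in train.py above: it keeps one window handle per chart and either creates the window on the first epoch or redraws it with the full history on later epochs via `update='replace'`. Below is a minimal, self-contained sketch of that create-then-update pattern, assuming a Visdom server is already running (typically started with `python -m visdom.server`) and using random stand-in loss values:

```python
import torch
from visdom import Visdom

viz = Visdom()  # connects to http://localhost:8097 by default

epochs = torch.arange(1, 11)  # x-axis: epoch numbers 1..10
losses = torch.rand(10)       # stand-in for real per-epoch loss values
win = None                    # window handle, populated on the first plot

for epoch in range(10):
    x, y = epochs[:epoch + 1], losses[:epoch + 1]
    if win is None:
        # First epoch: create the window and keep its handle.
        win = viz.line(X=x, Y=y, opts=dict(title='Loss', xlabel='Epoch', ylabel='Loss'))
    else:
        # Later epochs: redraw the same window with the full history so far.
        viz.line(X=x, Y=y, win=win, update='replace')
```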
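One caveat worth flagging: `type=bool` in argparse does not behave the way the README's `--visdom true` example suggests, because argparse passes the raw string to `bool()` and any non-empty string, including `'false'`, is truthy, so `--visdom false` would still enable plotting. A common workaround is a small converter function; the `str2bool` helper below is a hypothetical sketch, not part of the patch:

```python
import argparse

def str2bool(v):
    """Convert a command-line string to an actual boolean."""
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % v)

parser = argparse.ArgumentParser()
parser.add_argument('--visdom', default=False, type=str2bool, help='Turn on visdom graphing')

print(parser.parse_args(['--visdom', 'false']).visdom)  # False, as intended
print(parser.parse_args(['--visdom', 'true']).visdom)   # True
```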
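The `.encode('utf-8')` changes in data/an4.py and data/utils.py are needed because `io.FileIO` is a raw binary stream: on Python 3, writing a `str` to it raises `TypeError`, so text must be encoded to bytes first. A quick illustration (the file path is just an example):

```python
import io

with io.FileIO('example.txt', 'w') as f:
    f.write(u'hello\n'.encode('utf-8'))  # OK: raw streams accept bytes
    # f.write(u'hello\n')                # TypeError on Python 3: str is not bytes-like
```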