Skip to content

Commit

Permalink
Merge pull request SeanNaren#21 from SeanNaren/visdom-rebase
Browse files Browse the repository at this point in the history
Added visdom support
  • Loading branch information
Sean Naren committed Mar 30, 2017
2 parents d9c7bcf + 23a9e73 commit 77aa36a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,9 @@ python train.py --train_manifest data/train_manifest.csv --val_manifest data/val
```

Use `python train.py --help` for more parameters and options.

There is also [Visdom](https://github.com/facebookresearch/visdom) support to visualise training. Once a server has been started, to use:

```
python train.py --visdom true
```
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
python-levenshtein
librosa
torch
torch
visdom
43 changes: 41 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
parser.add_argument('--learning_anneal', default=1.1, type=float, help='Annealing applied to learning rate every epoch')
parser.add_argument('--silent', default=False, type=bool, help='Turn off progress tracking per iteration')
parser.add_argument('--epoch_save', default=False, type=bool, help='Save model every epoch')
parser.add_argument('--visdom', default=False, type=bool, help='Turn on visdom graphing')
parser.add_argument('--save_folder', default='models/', help='Location to save epoch models')
parser.add_argument('--final_model_path', default='models/deepspeech_final.pth.tar',
help='Location to save final model')
Expand Down Expand Up @@ -72,6 +73,20 @@ def checkpoint(model, args, nout, epoch=None):
def main():
args = parser.parse_args()
save_folder = args.save_folder

if args.visdom:
from visdom import Visdom
viz = Visdom()

opts = [dict(title='Loss', ylabel='Loss', xlabel='Epoch'),
dict(title='WER', ylabel='WER', xlabel='Epoch'),
dict(title='CER', ylabel='CER', xlabel='Epoch')]

viz_windows = [None, None, None]
loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
args.epochs)
epochs = torch.range(1, args.epochs)

try:
os.makedirs(save_folder)
except OSError as e:
Expand Down Expand Up @@ -168,7 +183,7 @@ def main():
avg_loss /= len(train_loader)
print('Training Summary Epoch: [{0}]\t'
'Average Loss {loss:.3f}\t'.format(
(epoch + 1), loss=avg_loss))
epoch + 1, loss=avg_loss))

total_cer, total_wer = 0, 0
for i, (data) in enumerate(test_loader): # test
Expand Down Expand Up @@ -202,11 +217,35 @@ def main():

wer = total_wer / len(test_loader.dataset)
cer = total_cer / len(test_loader.dataset)
wer *= 100
cer *= 100

print('Validation Summary Epoch: [{0}]\t'
'Average WER {wer:.0f}\t'
'Average CER {cer:.0f}\t'.format(
(epoch + 1), wer=wer * 100, cer=cer * 100))
epoch + 1, wer=wer, cer=cer))

if args.visdom:
loss_results[epoch] = avg_loss
wer_results[epoch] = wer
cer_results[epoch] = cer
epoch += 1
x_axis = epochs[0:epoch]
y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]]
for x in range(len(viz_windows)):
if viz_windows[x] is None:
viz_windows[x] = viz.line(
X=x_axis,
Y=y_axis[x],
opts=opts[x],
)
else:
viz.line(
X=x_axis,
Y=y_axis[x],
win=viz_windows[x],
update='replace',
)
if args.epoch_save:
file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch)
torch.save(checkpoint(model, args, len(labels), epoch), file_path)
Expand Down

0 comments on commit 77aa36a

Please sign in to comment.