Skip to content

Commit

Permalink
Merge pull request SeanNaren#92 from SiddGururani/master
Browse files Browse the repository at this point in the history
Fixed a bug and cleaned up the code in continue_from section of train.py
  • Loading branch information
Sean Naren committed Jun 16, 2017
2 parents cf20b8d + 528a60e commit dba0e06
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,12 @@ def main():
else:
start_iter += 1
avg_loss = int(package.get('avg_loss', 0))
loss_results, cer_results, wer_results = package['loss_results'], package[
'cer_results'], package['wer_results']
if args.visdom and \
package['loss_results'] is not None and start_epoch > 0: # Add previous scores to visdom graph
epoch = start_epoch
loss_results[0:epoch], cer_results[0:epoch], wer_results[0:epoch] = package['loss_results'], package[
'cer_results'], package['wer_results']
x_axis = epochs[0:epoch]
y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]]
x_axis = epochs[0:start_epoch]
y_axis = [loss_results[0:start_epoch], wer_results[0:start_epoch], cer_results[0:start_epoch]]
for x in range(len(viz_windows)):
viz_windows[x] = viz.line(
X=x_axis,
Expand All @@ -183,9 +182,7 @@ def main():
)
if args.tensorboard and \
package['loss_results'] is not None and start_epoch > 0: # Previous scores to tensorboard logs
loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], package[
'wer_results']
for i in range(len(loss_results)):
for i in range(start_epoch):
info = {
'Avg Train Loss': loss_results[i],
'Avg WER': wer_results[i],
Expand Down Expand Up @@ -327,7 +324,6 @@ def main():
epoch + 1, wer=wer, cer=cer))

if args.visdom:

# epoch += 1
x_axis = epochs[0:epoch + 1]
y_axis = [loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]]
Expand Down

0 comments on commit dba0e06

Please sign in to comment.