fix (memory): disable cudnn backend
frascuchon committed Jun 12, 2017
1 parent 858e928 commit d56b34b
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions train.py
@@ -6,6 +6,7 @@

import torch
from torch.autograd import Variable
+from torch.backends import cudnn
from warpctc_pytorch import CTCLoss

from data.data_loader import AudioDataLoader, SpectrogramDataset
@@ -53,11 +54,14 @@
parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help='Turn on tensorboard graphing')
parser.add_argument('--log_dir', default='visualize/deepspeech_final', help='Location of tensorboard log')
parser.add_argument('--log_params', dest='log_params', action='store_true', help='Log parameter values and gradients')
-parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, visdom_server='http://localhost', augment=False, tensorboard=False, log_params=False)
+parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, visdom_server='http://localhost',
+augment=False, tensorboard=False, log_params=False)


def to_np(x):
return x.data.cpu().numpy()


class AverageMeter(object):
"""Computes and stores the average and current value"""

@@ -76,6 +80,7 @@ def update(self, val, n=1):
self.count += n
self.avg = self.sum / self.count


def main():
args = parser.parse_args()
save_folder = args.save_folder
@@ -171,8 +176,9 @@ def main():
if args.visdom and \
package['loss_results'] is not None and start_epoch > 0: # Add previous scores to visdom graph
epoch = start_epoch
-loss_results[0:epoch], cer_results[0:epoch], wer_results[0:epoch] = package['loss_results'], package['cer_results'], package[
-'wer_results']
+loss_results[0:epoch], cer_results[0:epoch], wer_results[0:epoch] = package['loss_results'], package[
+'cer_results'], package[
+'wer_results']
x_axis = epochs[0:epoch]
y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]]
for x in range(len(viz_windows)):
@@ -181,7 +187,8 @@
Y=y_axis[x],
opts=opts[x],
)
-if args.tensorboard and package['loss_results'] is not None and start_epoch > 0:  # Add previous scores to tensorboard logs
+if args.tensorboard and package[
+'loss_results'] is not None and start_epoch > 0:  # Add previous scores to tensorboard logs
epoch = start_epoch
loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], package[
'wer_results']
@@ -192,7 +199,7 @@
'Avg CER': cer_results[i]
}
for tag, val in info.items():
-logger.scalar_summary(tag, val, i+1)
+logger.scalar_summary(tag, val, i + 1)
else:
avg_loss = 0
start_epoch = 0
@@ -265,7 +272,8 @@ def main():
if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0:
file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (save_folder, epoch + 1, i + 1)
print("Saving checkpoint model to %s" % file_path)
-torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, iteration=i, loss_results=loss_results,
+torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, iteration=i,
+loss_results=loss_results,
wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss),
file_path)
avg_loss /= len(train_loader)
@@ -275,7 +283,7 @@ def main():
epoch + 1, loss=avg_loss))

start_iter = 0 # Reset start iteration for next epoch
-total_cer, total_wer= 0, 0
+total_cer, total_wer = 0, 0
model.eval()
for i, (data) in enumerate(test_loader): # test
inputs, targets, input_percentages, target_sizes = data
@@ -323,8 +331,8 @@ def main():
wer_results[epoch] = wer
cer_results[epoch] = cer
# epoch += 1
-x_axis = epochs[0:epoch+1]
-y_axis = [loss_results[0:epoch+1], wer_results[0:epoch+1], cer_results[0:epoch+1]]
+x_axis = epochs[0:epoch + 1]
+y_axis = [loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]]
for x in range(len(viz_windows)):
if viz_windows[x] is None:
viz_windows[x] = viz.line(
@@ -349,13 +357,13 @@ def main():
'Avg CER': cer
}
for tag, val in info.items():
-logger.scalar_summary(tag, val, epoch+1)
+logger.scalar_summary(tag, val, epoch + 1)
if args.log_params:
for tag, value in model.named_parameters():
tag = tag.replace('.', '/')
-logger.histo_summary(tag, to_np(value), epoch+1)
-if value.grad is not None: # Condition inserted because batch_norm RNN_0 weights.grad and bias.grad are None. Check why
-logger.histo_summary(tag+'/grad', to_np(value.grad), epoch+1)
+logger.histo_summary(tag, to_np(value), epoch + 1)
+if value.grad is not None:  # Condition inserted because batch_norm RNN_0 weights.grad and bias.grad are None. Check why
+logger.histo_summary(tag + '/grad', to_np(value.grad), epoch + 1)
if args.checkpoint:
file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
@@ -372,4 +380,5 @@


if __name__ == '__main__':
+cudnn.enabled = False
main()
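
The functional change in this commit is the new import of torch.backends.cudnn and the cudnn.enabled = False line in the __main__ block: every CUDA operation launched afterwards uses PyTorch's native kernels instead of cuDNN, which can lower GPU memory use (the motivation in the commit message) at some cost in speed. The remaining hunks are whitespace and line-wrapping cleanups. A minimal, self-contained sketch of the same pattern; the convolution and tensor sizes below are illustrative only, not taken from this repository:

import torch
from torch.backends import cudnn

# Turn the cuDNN backend off for the whole process; this is the
# memory-oriented change the commit applies before calling main().
cudnn.enabled = False

# Illustrative check: CUDA ops still run, just without cuDNN kernels.
if torch.cuda.is_available():
    conv = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()
    x = torch.randn(4, 3, 32, 32, device='cuda')
    print(conv(x).shape)  # torch.Size([4, 8, 30, 30])

Because the flag is process-global state, setting it once before training affects every module built later, including the recurrent layers that cuDNN would otherwise accelerate.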
