From 6441435007c0ca1a2f041fc2521673bf28378a39 Mon Sep 17 00:00:00 2001
From: EgorLakomkin
Date: Sun, 11 Jun 2017 18:45:46 +0200
Subject: [PATCH] +bucketing

---
 data/bucketing_sampler.py | 33 +++++++++++++++++++++++++++++++++
 train.py                  | 12 +++++++++++-
 2 files changed, 44 insertions(+), 1 deletion(-)
 create mode 100644 data/bucketing_sampler.py

diff --git a/data/bucketing_sampler.py b/data/bucketing_sampler.py
new file mode 100644
index 00000000..5fe79b7a
--- /dev/null
+++ b/data/bucketing_sampler.py
@@ -0,0 +1,33 @@
+from torch.utils.data.sampler import Sampler
+import numpy as np
+from data.data_loader import SpectrogramDataset, load_audio
+from collections import defaultdict
+
+class SpectrogramDatasetWithLength(SpectrogramDataset):
+    def __init__(self, *args, **kwargs):
+        super(SpectrogramDatasetWithLength, self).__init__(*args, **kwargs)
+        audio_paths = [path for (path, _) in self.ids]
+        audio_lengths = [len(load_audio(path)) for path in audio_paths]
+        _, bin_edges = np.histogram(audio_lengths, bins="auto")
+        audio_samples_indices = np.digitize(audio_lengths, bins=bin_edges)
+        self.bins_to_samples = defaultdict(list)
+        for idx, bin_id in enumerate(audio_samples_indices):
+            self.bins_to_samples[bin_id].append(idx)
+
+class BucketingSampler(Sampler):
+    """Yields dataset indices bucket by bucket (buckets group utterances of
+    similar audio length), shuffling the indices within each bucket."""
+
+    def __init__(self, data_source):
+        self.data_source = data_source
+        assert hasattr(self.data_source, 'bins_to_samples')
+
+    def __iter__(self):
+        for bin_id, sample_indices in self.data_source.bins_to_samples.items():
+            np.random.shuffle(sample_indices)
+            for sample_index in sample_indices:
+                yield sample_index
+
+    def __len__(self):
+        return len(self.data_source)
+
diff --git a/train.py b/train.py
index 66b8c219..a53f50f4 100644
--- a/train.py
+++ b/train.py
@@ -8,6 +8,7 @@
 from torch.autograd import Variable
 from warpctc_pytorch import CTCLoss
 
+from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength
 from data.data_loader import AudioDataLoader, SpectrogramDataset
 from decoder import ArgMaxDecoder
 from model import DeepSpeech, supported_rnns
@@ -52,6 +53,7 @@
 parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help='Turn on tensorboard graphing')
 parser.add_argument('--log_dir', default='visualize/deepspeech_final', help='Location of tensorboard log')
 parser.add_argument('--log_params', dest='log_params', action='store_true', help='Log parameter values and gradients')
+parser.add_argument('--bucketing', dest='bucketing', action='store_true', help='Split utterances into buckets and sample from them')
 parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False, tensorboard=False, log_params=False)
 
 def to_np(x):
@@ -215,7 +217,7 @@ def main():
         data_time.update(time.time() - end)
         inputs = Variable(inputs, requires_grad=False)
         target_sizes = Variable(target_sizes, requires_grad=False)
-        targets = Variable(targets. reqiores_grad=False)
+        targets = Variable(targets, requires_grad=False)
 
         if args.cuda:
             inputs = inputs.cuda()
@@ -367,6 +369,14 @@ def main():
             print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))
 
         avg_loss = 0
+        if args.bucketing and epoch == 0:
+            print("Switching to bucketing")
+            train_dataset = SpectrogramDatasetWithLength(audio_conf=audio_conf, manifest_filepath=args.train_manifest,
+                                                         labels=labels,
+                                                         normalize=True, augment=args.augment)
+            sampler = BucketingSampler(train_dataset)
+            train_loader.sampler = sampler
+
 
     torch.save(DeepSpeech.serialize(model, optimizer=optimizer), args.final_model_path)