
Commit

+bucketing
EgorLakomkin committed Jun 11, 2017
1 parent e048430 commit 6441435
Showing 2 changed files with 44 additions and 1 deletion.
33 changes: 33 additions & 0 deletions data/bucketing_sampler.py
@@ -0,0 +1,33 @@
from collections import defaultdict

import numpy as np
from torch.utils.data.sampler import Sampler

from data.data_loader import SpectrogramDataset, load_audio


class SpectrogramDatasetWithLength(SpectrogramDataset):
    """SpectrogramDataset that buckets utterances by raw audio length."""

    def __init__(self, *args, **kwargs):
        super(SpectrogramDatasetWithLength, self).__init__(*args, **kwargs)
        audio_paths = [path for (path, _) in self.ids]
        audio_lengths = [len(load_audio(path)) for path in audio_paths]
        # Bucket the lengths with numpy's automatic binning, then record which
        # sample indices fall into each bucket.
        hist, bin_edges = np.histogram(audio_lengths, bins="auto")
        audio_samples_indices = np.digitize(audio_lengths, bins=bin_edges)
        self.bins_to_samples = defaultdict(list)
        for idx, bin_id in enumerate(audio_samples_indices):
            self.bins_to_samples[bin_id].append(idx)


class BucketingSampler(Sampler):
    """Yields sample indices bucket by bucket, shuffling within each bucket,
    so that every batch contains utterances of similar duration."""

    def __init__(self, data_source):
        self.data_source = data_source
        assert hasattr(self.data_source, 'bins_to_samples'), \
            "data_source must expose bins_to_samples (see SpectrogramDatasetWithLength)"

    def __iter__(self):
        for bin_id, sample_indices in self.data_source.bins_to_samples.items():
            np.random.shuffle(sample_indices)
            for sample_idx in sample_indices:
                yield sample_idx

    def __len__(self):
        return len(self.data_source)
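
For reference, a minimal sketch of how the two classes above are intended to be wired into a loader. The audio_conf, args, and labels variables are assumed to come from the training script, args.batch_size and args.num_workers are assumed flags, and AudioDataLoader is assumed to forward standard DataLoader keyword arguments such as sampler; the actual integration in this commit is the train.py change below, which assigns the sampler to an existing loader after the first epoch.

from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength
from data.data_loader import AudioDataLoader

# Build the length-aware dataset, then sample bucket by bucket so each batch
# holds utterances of similar duration (context variables are assumed).
train_dataset = SpectrogramDatasetWithLength(audio_conf=audio_conf,
                                             manifest_filepath=args.train_manifest,
                                             labels=labels,
                                             normalize=True,
                                             augment=args.augment)
sampler = BucketingSampler(train_dataset)
train_loader = AudioDataLoader(train_dataset,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers,
                               sampler=sampler)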

12 changes: 11 additions & 1 deletion train.py
@@ -8,6 +8,7 @@
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss

from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength
from data.data_loader import AudioDataLoader, SpectrogramDataset
from decoder import ArgMaxDecoder
from model import DeepSpeech, supported_rnns
@@ -52,6 +53,7 @@
parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help='Turn on tensorboard graphing')
parser.add_argument('--log_dir', default='visualize/deepspeech_final', help='Location of tensorboard log')
parser.add_argument('--log_params', dest='log_params', action='store_true', help='Log parameter values and gradients')
parser.add_argument('--bucketing', dest='bucketing', action='store_true', help='Split utterances into buckets and sample from them')
parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False, tensorboard=False, log_params=False)

def to_np(x):
@@ -215,7 +217,7 @@ def main():
            data_time.update(time.time() - end)
            inputs = Variable(inputs, requires_grad=False)
            target_sizes = Variable(target_sizes, requires_grad=False)
-           targets = Variable(targets. reqiores_grad=False)
+           targets = Variable(targets, requires_grad=False)

            if args.cuda:
                inputs = inputs.cuda()
@@ -367,6 +369,14 @@ def main():
        print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

        avg_loss = 0
        if args.bucketing and epoch == 0:
            print("Switching to bucketing")
            train_dataset = SpectrogramDatasetWithLength(audio_conf=audio_conf, manifest_filepath=args.train_manifest,
                                                         labels=labels,
                                                         normalize=True, augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler

    torch.save(DeepSpeech.serialize(model, optimizer=optimizer), args.final_model_path)


