Merge pull request SeanNaren#83 from EgorLakomkin/bucketing
Sampler that splits utterances into buckets
Sean Naren committed Jun 15, 2017
2 parents 5ccac99 + 09d06d7 commit cea3b01
Showing 4 changed files with 59 additions and 6 deletions.
42 changes: 42 additions & 0 deletions data/bucketing_sampler.py
@@ -0,0 +1,42 @@
from torch.utils.data.sampler import Sampler
import numpy as np
from data.data_loader import SpectrogramDataset, load_audio
from collections import defaultdict


class SpectrogramDatasetWithLength(SpectrogramDataset):
def __init__(self, *args, **kwargs):
"""
SpectrogramDataset that splits utterances into buckets based on their length.
Bucketing is done via numpy's histogram method.
Used by BucketingSampler to sample utterances from the same bin.
"""
super(SpectrogramDatasetWithLength, self).__init__(*args, **kwargs)
audio_paths = [path for (path, _) in self.ids]
audio_lengths = [len(load_audio(path)) for path in audio_paths]
hist, bin_edges = np.histogram(audio_lengths, bins="auto")
audio_samples_indices = np.digitize(audio_lengths, bins=bin_edges)
self.bins_to_samples = defaultdict(list)
for idx, bin_id in enumerate(audio_samples_indices):
self.bins_to_samples[bin_id].append(idx)


class BucketingSampler(Sampler):
def __init__(self, data_source):
"""
        Samples from a dataset that has been bucketed into bins of similarly sized sequences to reduce
        memory overhead.
:param data_source: The dataset to be sampled from
"""
super().__init__(data_source)
self.data_source = data_source
assert hasattr(self.data_source, 'bins_to_samples')

def __iter__(self):
        for bin_id, sample_indices in self.data_source.bins_to_samples.items():
            np.random.shuffle(sample_indices)
            for s in sample_indices:
                yield s

def __len__(self):
return len(self.data_source)
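
For reference, a minimal standalone sketch (not part of the commit) of the bucketing logic above, using hypothetical utterance lengths; the exact bucket boundaries depend on numpy's "auto" histogram heuristic:

import numpy as np
from collections import defaultdict

# Hypothetical utterance lengths in samples (roughly 1 s, 2 s and 3 s at 16 kHz).
audio_lengths = [15800, 16200, 31900, 32400, 47700, 48100]

# Same recipe as SpectrogramDatasetWithLength: histogram picks the bin edges,
# digitize maps every utterance index to a bin id.
hist, bin_edges = np.histogram(audio_lengths, bins="auto")
bin_ids = np.digitize(audio_lengths, bins=bin_edges)

bins_to_samples = defaultdict(list)
for idx, bin_id in enumerate(bin_ids):
    bins_to_samples[bin_id].append(idx)

# Each bucket now holds indices of similarly sized utterances,
# e.g. {1: [0, 1], 3: [2, 3], 5: [4], 6: [5]} (exact grouping varies with "auto").
print(dict(bins_to_samples))

BucketingSampler then shuffles within each bucket and yields indices bucket by bucket, so consecutive batches see sequences of similar length and padding waste stays low.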
2 changes: 1 addition & 1 deletion data/merge_manifests.py
@@ -37,7 +37,7 @@
file_path = files[x]
file_path = file_path.split(',')[0]
output = subprocess.check_output(
-        ['soxi -D %s' % file_path.strip()],
+        ['soxi -D \"%s\"' % file_path.strip()],
shell=True
)
duration = float(output)
2 changes: 1 addition & 1 deletion data/utils.py
@@ -43,7 +43,7 @@ def _order_files(file_paths):

def func(element):
output = subprocess.check_output(
-        ['soxi -D %s' % element.strip()],
+        ['soxi -D \"%s\"' % element.strip()],
shell=True
)
return float(output)
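
Both soxi changes fix the same defect: with shell=True the command line is parsed by the shell, so an unquoted path containing spaces splits into several words and soxi fails. A minimal illustration with a hypothetical filename (assumes sox/soxi is installed):

import subprocess

file_path = 'data/an utterance.wav'  # hypothetical path containing a space
# Before: soxi -D data/an utterance.wav   -> shell passes two arguments, soxi errors out.
# After:  soxi -D "data/an utterance.wav" -> quoting keeps the path as one argument.
output = subprocess.check_output(['soxi -D \"%s\"' % file_path.strip()], shell=True)
duration = float(output)  # soxi -D prints the file duration in seconds

An alternative worth noting: passing an argument list without shell=True, e.g. subprocess.check_output(['soxi', '-D', file_path]), sidesteps shell quoting entirely.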
19 changes: 15 additions & 4 deletions train.py
@@ -8,6 +8,7 @@
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss

from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength
from data.data_loader import AudioDataLoader, SpectrogramDataset
from decoder import ArgMaxDecoder
from model import DeepSpeech, supported_rnns
@@ -52,10 +53,10 @@
parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help='Turn on tensorboard graphing')
parser.add_argument('--log_dir', default='visualize/deepspeech_final', help='Location of tensorboard log')
parser.add_argument('--log_params', dest='log_params', action='store_true', help='Log parameter values and gradients')
parser.add_argument('--no_bucketing', dest='no_bucketing', action='store_true',
                    help='Turn off bucketing and sample from dataset based on sequence length (smallest to largest)')
-parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False, tensorboard=False,
-                    log_params=False)
+parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False, tensorboard=False,
+                    log_params=False, no_bucketing=False)
def to_np(x):
return x.data.cpu().numpy()

@@ -124,7 +125,6 @@ def main():

with open(args.labels_path) as label_file:
labels = str(''.join(json.load(label_file)))

audio_conf = dict(sample_rate=args.sample_rate,
window_size=args.window_size,
window_stride=args.window_stride,
@@ -271,6 +271,8 @@ def main():
loss_results=loss_results,
wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss),
file_path)
del loss
del out
avg_loss /= len(train_loader)

print('Training Summary Epoch: [{0}]\t'
@@ -311,6 +313,7 @@ def main():

if args.cuda:
torch.cuda.synchronize()
del out
wer = total_wer / len(test_loader.dataset)
cer = total_cer / len(test_loader.dataset)
wer *= 100
@@ -370,6 +373,14 @@ def main():
print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

avg_loss = 0
if not args.no_bucketing and epoch == 0:
print("Switching to bucketing sampler for following epochs")
train_dataset = SpectrogramDatasetWithLength(audio_conf=audio_conf, manifest_filepath=args.train_manifest,
labels=labels,
normalize=True, augment=args.augment)
sampler = BucketingSampler(train_dataset)
train_loader.sampler = sampler

torch.save(DeepSpeech.serialize(model, optimizer=optimizer), args.final_model_path)


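Putting the pieces together, a minimal sketch (not part of the commit; the manifest path, label set and hyper-parameters are placeholders) of how the new classes combine with the existing loader:

from data.bucketing_sampler import BucketingSampler, SpectrogramDatasetWithLength
from data.data_loader import AudioDataLoader

# Placeholder configuration mirroring the repo defaults.
audio_conf = dict(sample_rate=16000, window_size=.02, window_stride=.01, window='hamming')
labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "  # hypothetical label set

train_dataset = SpectrogramDatasetWithLength(audio_conf=audio_conf,
                                             manifest_filepath='data/train_manifest.csv',  # placeholder
                                             labels=labels, normalize=True, augment=False)
sampler = BucketingSampler(train_dataset)
# train.py swaps the sampler onto its existing loader after the first epoch;
# handing it to the loader at construction time is the equivalent one-shot form.
train_loader = AudioDataLoader(train_dataset, batch_size=20, num_workers=4, sampler=sampler)

With the flag defined above, bucketing is enabled by default from the second epoch onward, and passing --no_bucketing keeps the original smallest-to-largest ordering for the whole run.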
