
Dataloader (#2)
* Custom data-loader, changes to use dataloader

* Add requirements, changed folder name to reflect data
Sean Naren committed Jan 27, 2017
1 parent e23cd9a commit 5fe10e8
Showing 9 changed files with 151 additions and 133 deletions.
33 changes: 5 additions & 28 deletions README.md
@@ -1,10 +1,9 @@
# deepspeech.pytorch

# TODO
* Create a data loader separate from Aeon, due to its dependency requirements and some strange implementations that could be handled with more pre-processing
* WER/CER are not in line with what is expected
* Fix WER/CER measurements
* Add tests for dataloading
* Script to download an4 and create manifests (streamlined into one script rather than two separate scripts)
* Support LibriSpeech via multi-processed scripts
* Cleaner Warp-CTC bindings that do not rely on numpy

Implementation of [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc) using PyTorch.
Creates a network based on the [DeepSpeech2](http:https://arxiv.org/pdf/1512.02595v1.pdf) architecture, trained with the CTC loss function.
@@ -19,15 +18,6 @@ Install pytorch if you haven't already:
conda install pytorch -c https://conda.anaconda.org/t/6N-MsQ4WZ7jo/soumith
```

Install the Nervana Aeon dataloader. Installation instructions are available [here](https://aeon.nervanasys.com/index.html/getting_started.html); instructions for Anaconda/Ubuntu are below:

```
sudo apt-get install libcurl4-openssl-dev clang libopencv-dev libsox-dev
git clone https://github.com/NervanaSystems/aeon.git
cd aeon
python setup.py install
```

Install this fork for Warp-CTC bindings:
```
git clone https://github.com/SeanNaren/warp-ctc.git
@@ -43,7 +33,7 @@ python setup.py install
Finally:

```
pip install python-levenshtein
pip install -r requirements.txt
```
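
The contents of requirements.txt are not shown in this diff; judging by the imports used in the new data loader and the `python-levenshtein` install it replaces, it plausibly lists at least:

```
python-levenshtein
librosa
scipy
numpy
```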

# Usage
@@ -59,23 +49,10 @@ python create_dataset_manifest.py --root_path dataset/

This will generate CSV manifest files used to load the data for training.
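
Each manifest line simply pairs an audio file with its transcript file, comma-separated, which is the format the new `data/data_loader.py` expects when it splits each line on a comma (the paths below are illustrative):

```
/path/to/an4_dataset/train/wav/utterance1.wav,/path/to/an4_dataset/train/txt/utterance1.txt
```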

Optionally, a noise dataset can be used to inject noise artificially into the training data. Just fill a folder with the noise wav files you want to inject (one source of noise files is the [musan dataset](http:https://www.openslr.org/17/)) and run the command below:
```
python create_noise_manifest.py --root_path noise/ # or whatever you've named your noise folder
```

## Training

You need to find the maximum duration of the training and testing samples. The command below will iterate through the current
folder and find the longest duration:

```
find . -type f -name "*.wav" | xargs soxi -D | sort -n | tail -n 1
```

Afterwards you can run the training script.

```
python main.py --max_duration 6.4 # This is the default max duration (for an4)
python main.py --train_manifest train_manifest.csv --test_manifest test_manifest.csv
```

21 changes: 0 additions & 21 deletions create_noise_manifest.py

This file was deleted.

Empty file added data/__init__.py
@@ -3,7 +3,7 @@
import os

parser = argparse.ArgumentParser(description='Creates training and testing manifests')
parser.add_argument('--root_path', default='dataset', help='Path to the dataset')
parser.add_argument('--root_path', default='an4_dataset', help='Path to the dataset')


def create_manifest(data_path, tag):
87 changes: 87 additions & 0 deletions data/data_loader.py
@@ -0,0 +1,87 @@
import scipy.signal
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import librosa
import numpy as np


class AudioDataset(Dataset):
    def __init__(self, conf):
        """
        Loads (spectrogram, transcript) pairs from a CSV manifest whose lines
        are comma-separated audio_path,transcript_path entries.
        """
        super(AudioDataset, self).__init__()
        with open(conf['manifest_filename']) as f:
            ids = f.readlines()
        ids = [x.strip().split(',') for x in ids]
        self.ids = ids
        self.size = len(ids)
        self.conf = conf
        self.audio_conf = conf['audio']
        # Map each character of the alphabet to its integer label index.
        self.alphabet_map = dict([(conf['alphabet'][i], i) for i in range(len(conf['alphabet']))])
        self.normalize = conf.get('normalize', False)

def __getitem__(self, index):
sample = self.ids[index]
audio_path, transcript_path = sample[0], sample[1]
spect = self.spectrogram(audio_path)
transcript = self.parse_transcript(transcript_path)
return spect, transcript

    def parse_transcript(self, transcript_path):
        with open(transcript_path, 'r') as transcript_file:
            transcript = transcript_file.read().replace('\n', '')
        # Convert the transcript string into a list of integer labels.
        transcript = [self.alphabet_map[x] for x in list(transcript)]
        return transcript

def spectrogram(self, audio_path):
y, _ = librosa.core.load(audio_path, sr=self.audio_conf['sample_rate'])
n_fft = int(self.audio_conf['sample_rate'] * self.audio_conf['window_size'])
win_length = n_fft
hop_length = int(self.audio_conf['sample_rate'] * self.audio_conf['window_stride'])
window = scipy.signal.hamming # TODO if statement to select window based on conf
# STFT
D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window)
spect, phase = librosa.magphase(D)
# S = log(S+1)
spect = np.log1p(spect)
spect = torch.FloatTensor(spect)
        if self.normalize:
            # Normalize the spectrogram to zero mean and unit variance.
            mean = spect.mean()
            std = spect.std()
            spect.add_(-mean)
            spect.div_(std)

return spect

def __len__(self):
return self.size


def collate_fn(batch):
    # Batch elements are (spectrogram, transcript) pairs; the sort key is the
    # number of time frames (dimension 1) of the spectrogram.
    def func(p):
        return p[0].size(1)

    # Zero-pad every spectrogram in the batch to the length of the longest one.
    longest_sample = max(batch, key=func)[0]
    freq_size = longest_sample.size(0)
    minibatch_size = len(batch)
    max_seqlength = longest_sample.size(1)
    inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
    input_percentages = torch.FloatTensor(minibatch_size)
    target_sizes = torch.IntTensor(minibatch_size)
    targets = []
    for x in range(minibatch_size):
        sample = batch[x]
        tensor = sample[0]
        target = sample[1]
        seq_length = tensor.size(1)
        inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
        # Fraction of the padded length that is real signal.
        input_percentages[x] = seq_length / float(max_seqlength)
        target_sizes[x] = len(target)
        targets.extend(target)
    # All transcripts are concatenated into one flat tensor; target_sizes
    # records where each transcript ends.
    targets = torch.IntTensor(targets)
    return inputs, targets, input_percentages, target_sizes


class AudioDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        # Same as a standard DataLoader, but uses the padding collate_fn above.
        super(AudioDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = collate_fn
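
A minimal usage sketch of the new loader (the configuration values here are illustrative assumptions; only the keys are dictated by the code above):

```
from data.data_loader import AudioDataset, AudioDataLoader

conf = {
    'manifest_filename': 'train_manifest.csv',
    'alphabet': "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ ",  # assumed label set; index 0 is the CTC blank
    'normalize': True,
    'audio': {
        'sample_rate': 16000,   # Hz
        'window_size': 0.02,    # STFT window length in seconds
        'window_stride': 0.01,  # STFT hop in seconds
    },
}

dataset = AudioDataset(conf)
loader = AudioDataLoader(dataset, batch_size=20, num_workers=4)

for inputs, targets, input_percentages, target_sizes in loader:
    pass  # inputs: (batch, 1, freq, time); targets: flattened transcripts
```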
File renamed without changes.
