
Commit

Various fixes and changes
SeanNaren committed Jan 28, 2017
1 parent ae62eaf commit ce07d8b
Showing 6 changed files with 60 additions and 43 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -1,8 +1,7 @@
# deepspeech.pytorch

* Fix WER/CER measurements
* Add tests for dataloading
* Script to download an4 and create manifests (streamline rather than 2 separate scripts)
* Fix validation; we assume the problems lie in the SequenceWise module and in using batch norm in 3D mode.
* Support LibriSpeech via multi-processed scripts

Implementation of [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc) using pytorch.
29 changes: 0 additions & 29 deletions data/create_dataset_manifest.py

This file was deleted.

15 changes: 12 additions & 3 deletions data/get_an4.py
@@ -3,8 +3,12 @@
import io
import shutil

import subprocess

from data.utils import create_manifest

parser = argparse.ArgumentParser(description='Processes and downloads an4.')
parser.add_argument('--an4_path', default='dataset/', help='Path to save dataset')
parser.add_argument('--an4_path', default='an4_dataset/', help='Path to save dataset')
parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate')


@@ -62,13 +66,18 @@ def main():
args = parser.parse_args()
root_path = 'an4/'
name = 'an4'
os.system('wget http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz')
os.system('tar -xzvf an4_raw.bigendian.tar.gz')
subprocess.call(['wget http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz'], shell=True)
subprocess.call(['tar -xzvf an4_raw.bigendian.tar.gz'], stdout=open(os.devnull, 'wb'), shell=True)
os.makedirs(args.an4_path)
format_data('train', name, 'an4_clstk')
format_data('test', name, 'an4test_clstk')
shutil.rmtree(root_path)
os.remove('an4_raw.bigendian.tar.gz')
train_path = args.an4_path + '/train/'
test_path = args.an4_path + '/test/'
print ('Creating manifests...')
create_manifest(train_path, 'train')
create_manifest(test_path, 'test')


if __name__ == '__main__':
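The commit replaces the os.system calls with subprocess.call, but both variants still hand a single command string to a shell (shell=True). As a hedged aside, not part of the commit, the same two steps can be written with argument lists so nothing is re-parsed by a shell; this sketch assumes wget and tar are on PATH, exactly as the script above already does, and fetch_and_extract is a hypothetical helper name.

import os
import subprocess

AN4_URL = 'http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz'
ARCHIVE = 'an4_raw.bigendian.tar.gz'


def fetch_and_extract(an4_path='an4_dataset/'):
    # check_call raises CalledProcessError if the download or extraction fails,
    # instead of silently continuing as an unchecked call would.
    subprocess.check_call(['wget', AN4_URL])
    with open(os.devnull, 'wb') as devnull:
        subprocess.check_call(['tar', '-xzvf', ARCHIVE], stdout=devnull)
    if not os.path.exists(an4_path):
        os.makedirs(an4_path)

The behaviour mirrors the wget / tar / makedirs sequence in main() above.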
37 changes: 37 additions & 0 deletions data/utils.py
@@ -0,0 +1,37 @@
import argparse
import io
import os

import subprocess

parser = argparse.ArgumentParser(description='Creates training and testing manifests')
parser.add_argument('--root_path', default='an4_dataset', help='Path to the dataset')

"""
We need to add progress bars like will did in gen audio. Just copy that code.
We also need a call (basically the same in the dataloader, a find that gives us the total number of wav files)
"""


def create_manifest(data_path, tag, ordered=True):
manifest_path = '%s_manifest.csv' % tag
file_paths = []
with os.popen('find %s -type f -name "*.wav"' % data_path) as pipe:
for file_path in pipe:
file_paths.append(file_path.strip())
if ordered:
print("Sorting files by length...")

def func(element):
output = subprocess.check_output(
['soxi -D %s' % element.strip()],
shell=True
)
return float(output)

file_paths.sort(key=func)
with io.FileIO(manifest_path, "w") as file:
for wav_path in file_paths:
transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt')
sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n'
file.write(sample)
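create_manifest above sorts the wav files by duration by shelling out to soxi -D once per file. A hypothetical alternative sort key, not in the commit, reads the duration straight from the WAV header with the standard-library wave module, so no subprocess is spawned per file; wav_duration is an illustrative name.

import contextlib
import wave


def wav_duration(wav_path):
    # Duration in seconds = number of frames / sample rate, read from the header.
    with contextlib.closing(wave.open(wav_path, 'rb')) as f:
        return f.getnframes() / float(f.getframerate())

# file_paths.sort(key=func) in create_manifest could then become:
# file_paths.sort(key=wav_duration)

Either way, each manifest line pairs a wav path with its transcript path, which is all the downstream loaders need.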
12 changes: 6 additions & 6 deletions main.py
@@ -11,9 +11,9 @@

parser = argparse.ArgumentParser(description='DeepSpeech pytorch params')
parser.add_argument('--train_manifest', metavar='DIR',
help='path to train manifest csv', default='train_manifest.csv')
help='path to train manifest csv', default='data/train_manifest.csv')
parser.add_argument('--test_manifest', metavar='DIR',
help='path to test manifest csv', default='test_manifest.csv')
help='path to test manifest csv', default='data/test_manifest.csv')
parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate')
parser.add_argument('--batch_size', default=20, type=int, help='Batch size for training')
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
@@ -64,12 +64,12 @@ def main():

train_dataloader_config = dict(type="audio,transcription",
audio=audio_config,
manifest_filename='train_manifest.csv',
manifest_filename=args.train_manifest,
alphabet=alphabet,
normalize=True)
test_dataloader_config = dict(type="audio,transcription",
audio=audio_config,
manifest_filename='test_manifest.csv',
manifest_filename=args.test_manifest,
alphabet=alphabet,
normalize=True)
train_loader = AudioDataLoader(AudioDataset(train_dataloader_config), args.batch_size,
@@ -183,8 +183,8 @@ def main():
target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets))
wer, cer = 0, 0
for x in range(len(target_strings)):
wer += decoder.wer(decoded_output[x], target_strings[x])
cer += decoder.cer(decoded_output[x], target_strings[x])
wer += decoder.wer(decoded_output[x], target_strings[x]) / float(len(target_strings[x].split()))
cer += decoder.cer(decoded_output[x], target_strings[x]) / float(len(target_strings[x]))
total_cer += cer
total_wer += wer

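The WER/CER change divides each sample's score by the length of its reference (words for WER, characters for CER), so total_wer and total_cer now accumulate rates rather than raw counts. The sketch below assumes decoder.wer and decoder.cer return plain edit distances, which is what this normalization implies; the Levenshtein helper is included only to make the arithmetic concrete and is not the project's decoder.

def levenshtein(a, b):
    # Classic dynamic-programming edit distance between two sequences.
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        cur = [i]
        for j, y in enumerate(b, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (x != y)))  # substitution
        prev = cur
    return prev[-1]

hyp, ref = 'the cat sat', 'the cat sat down'
wer = levenshtein(hyp.split(), ref.split()) / float(len(ref.split()))  # 1 / 4  = 0.25
cer = levenshtein(hyp, ref) / float(len(ref))                          # 5 / 16 = 0.3125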
7 changes: 4 additions & 3 deletions model.py
@@ -34,19 +34,20 @@ def __init__(self, input_size, hidden_size, bidirectional=False, batch_norm=True
self.batch_norm_activate = batch_norm
self.bidirectional = bidirectional
self.batch_norm = nn.BatchNorm1d(input_size)
self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size,
bidirectional=bidirectional, bias=False)
self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
bidirectional=bidirectional, bias=False)
self.num_directions = 2 if bidirectional else 1

def forward(self, x):
c0 = Variable(torch.zeros(self.num_directions, x.size(1), self.hidden_size).type_as(x.data))
h0 = Variable(torch.zeros(self.num_directions, x.size(1), self.hidden_size).type_as(x.data))
if self.batch_norm_activate:
t, n = x.size(0), x.size(1)
x = x.view(n, -1, t)
x = self.batch_norm(x)
x = x.transpose(1, 2).transpose(0, 1)
x = x.contiguous()
x, _ = self.rnn(x, h0)
x, _ = self.rnn(x, (c0, h0))
if self.bidirectional:
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxD*2) -> (TxNxD) by sum
return x
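model.py swaps the GRU for an LSTM, which carries both a hidden state and a cell state, so the forward pass now builds c0 alongside h0. One hedged note: PyTorch documents the LSTM state tuple as (h_0, c_0); because both tensors are zero-initialised here, passing (c0, h0) as above gives the same result, but a condensed sketch with the documented ordering follows. It is an illustration, not the project's exact BatchRNN module.

import torch
import torch.nn as nn
from torch.autograd import Variable  # 2017-era PyTorch API, matching model.py


def lstm_forward(rnn, x, hidden_size, bidirectional=True):
    # x is sequence-first: (T x N x input_size), as elsewhere in model.py.
    num_directions = 2 if bidirectional else 1
    h0 = Variable(torch.zeros(num_directions, x.size(1), hidden_size).type_as(x.data))
    c0 = Variable(torch.zeros(num_directions, x.size(1), hidden_size).type_as(x.data))
    out, _ = rnn(x, (h0, c0))  # documented ordering is (h_0, c_0)
    if bidirectional:
        # (T x N x 2*H) -> (T x N x H) by summing the two directions, as above.
        out = out.view(out.size(0), out.size(1), 2, -1).sum(2).view(out.size(0), out.size(1), -1)
    return out

For example, a layer built as rnn = nn.LSTM(input_size=161, hidden_size=400, bidirectional=True, bias=False) (sizes chosen arbitrarily for illustration) could be driven the same way as the self.rnn call in BatchRNN.forward.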
