Merge pull request SeanNaren#87 from SeanNaren/fixes
Updated Test variable and README information
Sean Naren committed Jun 12, 2017
2 parents 30f9feb + 1f3a238 commit a125278
Showing 4 changed files with 35 additions and 17 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -206,6 +206,20 @@ python model.py --model_path models/deepspeech.pth.tar

 Note that the model has no final softmax layer: warp-ctc applies the softmax internally during training. Any decoder built on top of the model must therefore apply this softmax itself, so take this into consideration!

+## Testing/Inference
+
+To evaluate a trained model on a test set (has to be in the same format as the training set):
+
+```
+python test.py --model_path models/deepspeech.pth.tar --test_manifest /path/to/test_manifest.csv --cuda
+```
+
+An example script to output a prediction has been provided:
+
+```
+python predict.py --model_path models/deepspeech.pth.tar --audio_path /path/to/audio.wav
+```
+
 ## Acknowledgements

 Thanks to [Egor](https://github.com/EgorLakomkin) and [Ryan](https://github.com/ryanleary) for their contributions!
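
A note on the README's softmax caveat above: since the network emits raw logits, anything decoding its output must normalise it first. A minimal sketch under that assumption, with toy tensors standing in for a checkpoint loaded via `DeepSpeech.load_model` and a plain argmax standing in for the repo's `ArgMaxDecoder`:

```
import torch
import torch.nn.functional as F

# Toy stand-in for the acoustic model's output: T timesteps, N batch items,
# H characters. There is no softmax layer at the end of the model, so these
# are raw, un-normalised scores.
logits = torch.randn(100, 1, 29)    # TxNxH

probs = F.softmax(logits, dim=-1)   # the softmax warp-ctc applies internally in training
best_path = probs.argmax(dim=-1)    # greedy character choice per timestep
```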
6 changes: 3 additions & 3 deletions benchmark.py
@@ -56,14 +56,14 @@ def iteration(input_data):
     input_percentages = torch.IntTensor(batch_size).fill_(1)

     inputs = Variable(input_data, requires_grad=False)
-    target_sizes = Variable(target_size requires_grad=False)
-    targets = Variable(target requires_grad=False)
+    target_sizes = Variable(target_size, requires_grad=False)
+    targets = Variable(target, requires_grad=False)
     start = time.time()
     out = model(inputs)
     out = out.transpose(0, 1) # TxNxH

     seq_length = out.size(0)
-    sizes = Variable(input_percentages.mul_(int(seq_length)).int() requires_grad=False)
+    sizes = Variable(input_percentages.mul_(int(seq_length)).int(), requires_grad=False)
     loss = criterion(out, targets, sizes, target_sizes)
     loss = loss / inputs.size(0) # average the loss by minibatch
     # compute gradient
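
The three fixed lines above all wrap arguments for the warp-ctc criterion, which takes raw activations as TxNxH plus concatenated integer targets and two length tensors. A minimal sketch of that contract using PyTorch's later built-in `nn.CTCLoss` (an illustrative substitute; this benchmark itself uses the warp-ctc binding):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

T, N, H = 50, 4, 29                                       # timesteps, batch, characters
log_probs = F.log_softmax(torch.randn(T, N, H), dim=-1)   # TxNxH, like `out` above

targets = torch.randint(1, H, (N, 10), dtype=torch.long)  # label 0 is reserved for the CTC blank
input_lengths = torch.full((N,), T, dtype=torch.long)     # plays the role of `sizes`
target_lengths = torch.full((N,), 10, dtype=torch.long)   # plays the role of `target_sizes`

criterion = nn.CTCLoss(blank=0, reduction='sum')
loss = criterion(log_probs, targets, input_lengths, target_lengths)
loss = loss / N                                           # average the loss by minibatch, as above
```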
24 changes: 14 additions & 10 deletions model.py
@@ -9,7 +9,7 @@
     'rnn': nn.RNN,
     'gru': nn.GRU
 }
-supported_rnns_inv = dict((v,k) for k,v in supported_rnns.items())
+supported_rnns_inv = dict((v, k) for k, v in supported_rnns.items())


 class SequenceWise(nn.Module):
@@ -41,15 +41,14 @@ def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=Fals
         super(BatchRNN, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-        self.batch_norm_activate = batch_norm
         self.bidirectional = bidirectional
-        self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size))
+        self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
         self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
                             bidirectional=bidirectional, bias=False)
         self.num_directions = 2 if bidirectional else 1

     def forward(self, x):
-        if self.batch_norm_activate:
+        if self.batch_norm is not None:
             x = self.batch_norm(x)
         x, _ = self.rnn(x)
         if self.bidirectional:
@@ -58,10 +57,13 @@


 class DeepSpeech(nn.Module):
-    def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layers=5, audio_conf={}, bidirectional=True):
+    def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layers=5, audio_conf=None,
+                 bidirectional=True):
         super(DeepSpeech, self).__init__()

         # model metadata needed for serialization/deserialization
+        if audio_conf is None:
+            audio_conf = {}
         self._version = '0.0.1'
         self._hidden_size = rnn_hidden_size
         self._hidden_layers = nb_layers
@@ -121,7 +123,8 @@ def forward(self, x):
     def load_model(cls, path, cuda=False):
         package = torch.load(path, map_location=lambda storage, loc: storage)
         model = cls(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'],
-                    labels=package['labels'], audio_conf=package['audio_conf'], rnn_type=supported_rnns[package['rnn_type']])
+                    labels=package['labels'], audio_conf=package['audio_conf'],
+                    rnn_type=supported_rnns[package['rnn_type']])
         model.load_state_dict(package['state_dict'])
         if cuda:
             model = torch.nn.DataParallel(model).cuda()
@@ -167,10 +170,11 @@ def get_audio_conf(model):
     model_is_cuda = next(model.parameters()).is_cuda
     return model.module._audio_conf if model_is_cuda else model._audio_conf

+
 if __name__ == '__main__':
     import os.path
     import argparse
-    import json
+
     parser = argparse.ArgumentParser(description='DeepSpeech model information')
     parser.add_argument('--model_path', default='models/deepspeech_final.pth.tar',
                         help='Path to model file created by training')
@@ -199,9 +203,9 @@ def get_audio_conf(model):
         print("Training Information")
         epochs = package['epoch']
         print(" Epochs: ", epochs)
-        print(" Current Loss: {0:.3f}".format(package['loss_results'][epochs-1]))
-        print(" Current CER: {0:.3f}".format(package['cer_results'][epochs-1]))
-        print(" Current WER: {0:.3f}".format(package['wer_results'][epochs-1]))
+        print(" Current Loss: {0:.3f}".format(package['loss_results'][epochs - 1]))
+        print(" Current CER: {0:.3f}".format(package['cer_results'][epochs - 1]))
+        print(" Current WER: {0:.3f}".format(package['wer_results'][epochs - 1]))

     if package.get('meta', None) is not None:
         print("")
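
Two of the model.py fixes above are standard Python hygiene rather than behaviour changes: batch norm becomes a `None` sentinel checked in `forward`, and the default `audio_conf={}` becomes `audio_conf=None`. The second fix matters because a default `{}` is built once, when the `def` statement runs, and is then shared by every call that omits the argument. A standalone illustration of the pitfall and the fixed pattern:

```
def bad(conf={}):               # one dict, created at definition time, shared by all calls
    conf['calls'] = conf.get('calls', 0) + 1
    return conf


def good(conf=None):            # the pattern DeepSpeech.__init__ switches to
    if conf is None:
        conf = {}               # fresh dict on every call
    conf['calls'] = conf.get('calls', 0) + 1
    return conf


print(bad())    # {'calls': 1}
print(bad())    # {'calls': 2} -- state leaked from the first call
print(good())   # {'calls': 1}
print(good())   # {'calls': 1} -- independent every time
```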
8 changes: 4 additions & 4 deletions test.py
@@ -12,8 +12,8 @@
 parser.add_argument('--model_path', default='models/deepspeech_final.pth.tar',
                     help='Path to model file created by training')
 parser.add_argument('--cuda', action="store_true", help='Use cuda to test model')
-parser.add_argument('--val_manifest', metavar='DIR',
-                    help='path to validation manifest csv', default='data/val_manifest.csv')
+parser.add_argument('--test_manifest', metavar='DIR',
+                    help='path to test manifest csv', default='data/test_manifest.csv')
 parser.add_argument('--batch_size', default=20, type=int, help='Batch size for training')
 parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
 args = parser.parse_args()
@@ -26,7 +26,7 @@
 audio_conf = DeepSpeech.get_audio_conf(model)
 decoder = ArgMaxDecoder(labels)

-test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels,
+test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest, labels=labels,
                                   normalize=True)
 test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                               num_workers=args.num_workers)
@@ -63,6 +63,6 @@
 wer = total_wer / len(test_loader.dataset)
 cer = total_cer / len(test_loader.dataset)

-print('Validation Summary \t'
+print('Test Summary \t'
       'Average WER {wer:.3f}\t'
       'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))
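
For reference, the averages printed by that summary line are built from per-utterance error rates: WER is edit distance over words and CER over characters, each normalised by the reference length. A minimal self-contained sketch using a plain dynamic-programming edit distance (the repo's decoder delegates this computation to a library):

```
def edit_distance(ref, hyp):
    # Levenshtein distance between two token sequences.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1]


ref, hyp = "the cat sat", "the bat sat down"
wer = edit_distance(ref.split(), hyp.split()) / float(len(ref.split()))
cer = edit_distance(list(ref), list(hyp)) / float(len(ref))
print('Average WER {wer:.3f}\t'
      'Average CER {cer:.3f}'.format(wer=wer * 100, cer=cer * 100))   # 66.667 / 54.545
```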
