Skip to content

Commit

Permalink
Fix benchmark script (SeanNaren#62)
Browse files Browse the repository at this point in the history
* Added fixes to benchmark script

* Fixed default
  • Loading branch information
Sean Naren committed May 16, 2017
1 parent ce96424 commit 743fa66
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions benchmark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json
import time
import torch
from torch.autograd import Variable
Expand All @@ -13,7 +14,8 @@
help='The size of the fake input in seconds using default stride of 0.01, '
'15s is usually the maximum duration')
parser.add_argument('--dry_runs', type=int, default=20, help='Dry runs before measuring performance')
parser.add_argument('--runs', type=int, default=20, help='Hidden size of RNNs')
parser.add_argument('--runs', type=int, default=20, help='How many benchmark runs to measure performance')
parser.add_argument('--labels_path', default='labels.json', help='Path to the labels to infer over in the model')
parser.add_argument('--hidden_size', default=400, type=int, help='Hidden size of RNNs')
parser.add_argument('--hidden_layers', default=4, type=int, help='Number of RNN layers')
parser.add_argument('--rnn_type', default='lstm', help='Type of the RNN. rnn|gru|lstm are supported')
Expand All @@ -25,10 +27,18 @@

rnn_type = args.rnn_type.lower()
assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"

with open(args.labels_path) as label_file:
labels = str(''.join(json.load(label_file)))

audio_conf = dict(sample_rate=args.sample_rate,
window_size=args.window_size)

model = DeepSpeech(rnn_hidden_size=args.hidden_size,
nb_layers=args.hidden_layers, num_classes=29,
rnn_type=supported_rnns[rnn_type],
sample_rate=args.sample_rate, window_size=args.window_size)
nb_layers=args.hidden_layers,
audio_conf=audio_conf,
labels=labels,
rnn_type=supported_rnns[rnn_type])

parameters = model.parameters()
optimizer = torch.optim.SGD(parameters, lr=3e-4,
Expand Down

0 comments on commit 743fa66

Please sign in to comment.