Removed variables, added new masking layer, refactored training/loading of model
sean.narenthiran committed Jul 12, 2018
1 parent 80a060f commit e7b459d
Showing 5 changed files with 39 additions and 67 deletions.
43 changes: 21 additions & 22 deletions model.py
@@ -42,7 +42,8 @@ def __repr__(self):
class MaskConv(nn.Module):
def __init__(self, seq_module):
"""
Adds padding to the output of the module based on the given lengths
Adds padding to the output of the module based on the given lengths. This is to ensure that the
results of the model do not change when batch sizes change during inference.
Input needs to be in the shape of (BxCxDxT)
:param seq_module: The sequential module containing the conv stack.
"""
@@ -57,10 +58,12 @@ def forward(self, x, lengths):
"""
for module in self.seq_module:
x = module(x)
mask = torch.ByteTensor(x.size()).fill_(0).cuda() # TODO don't hard-code cuda
for i, length in enumerate(lengths):
length = length.item()
if (x[i].size(2) - length) > 0:
x[i].narrow(2, length, x[i].size(2) - length).fill_(0)
if (mask[i].size(2) - length) > 0:
mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
x = x.masked_fill(mask, 0)
return x, lengths
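
The rewritten forward pass builds an explicit mask over the padded tail of each sample and zeroes it with masked_fill, instead of overwriting slices of x in place. Below is a minimal, self-contained sketch of that pattern, assuming input of shape (N, C, D, T) and per-sample valid lengths along the time axis; it creates a bool mask on x's device rather than the hard-coded ByteTensor(...).cuda() flagged by the TODO above, and it illustrates the idea rather than reproducing the repository's exact class.

import torch
import torch.nn as nn

class MaskedConvStack(nn.Module):
    """Sketch: zero every position past a sample's valid length after each conv."""

    def __init__(self, seq_module):
        super(MaskedConvStack, self).__init__()
        self.seq_module = seq_module

    def forward(self, x, lengths):
        # x: (N, C, D, T); lengths: valid time steps per sample
        for module in self.seq_module:
            x = module(x)
            # mask is True over the padded tail of the time axis for every sample
            mask = torch.zeros(x.size(), dtype=torch.bool, device=x.device)
            for i, length in enumerate(lengths):
                length = int(length)
                if mask[i].size(2) - length > 0:
                    mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
            # zero the padded positions so they cannot leak into later layers
            x = x.masked_fill(mask, 0)
        return x, lengths

This is the mechanism behind the new docstring's guarantee: padded positions are forced to zero before the next layer sees them, so an utterance produces the same activations whether it is batched alone or padded out next to longer ones.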


@@ -80,7 +83,7 @@ def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=Fals
self.bidirectional = bidirectional
self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
bidirectional=bidirectional, bias=False)
bidirectional=bidirectional, bias=True)
self.num_directions = 2 if bidirectional else 1

def flatten_parameters(self):
@@ -91,10 +94,9 @@ def forward(self, x, output_lengths):
x = self.batch_norm(x)
x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
x, h = self.rnn(x)
if self.bidirectional:
x = x._replace(
data=x.data[:, :self.hidden_size] + x.data[:, self.hidden_size:]) # sum bidirectional outputs
x, _ = nn.utils.rnn.pad_packed_sequence(x)
if self.bidirectional:
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum
return x
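
Summing the two directions now happens after pad_packed_sequence rather than on the packed data: the padded (T x N x 2H) output is viewed as (T x N x 2 x H) and reduced over the direction axis, which leaves valid positions unchanged and padded positions at zero. A shape-level sketch with illustrative sizes:

import torch

T, N, H = 50, 8, 256                 # time steps, batch size, hidden size (illustrative)
x = torch.randn(T, N, 2 * H)         # padded bidirectional output, (T, N, 2H)
summed = x.view(T, N, 2, H).sum(2)   # split the direction axis, then sum it away -> (T, N, H)
assert summed.shape == (T, N, H)

The trailing view(...) in the diff only restates the (T, N, H) shape that sum(2) already produces, so the two forms are equivalent.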


@@ -206,7 +208,7 @@ def forward(self, x, lengths):
x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH

for rnn in self.rnns:
x = rnn(x, output_lengths.numpy())
x = rnn(x, output_lengths)

if not self._bidirectional: # no need for lookahead layer in bidirectional
x = self.lookahead(x)
@@ -231,7 +233,7 @@ def get_seq_lens(self, input_length):
return seq_len.int()

@classmethod
def load_model(cls, path, cuda=False):
def load_model(cls, path):
package = torch.load(path, map_location=lambda storage, loc: storage)
model = cls(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'],
labels=package['labels'], audio_conf=package['audio_conf'],
@@ -247,25 +249,20 @@ def load_model(cls, path, cuda=False):
model.load_state_dict(package['state_dict'])
for x in model.rnns:
x.flatten_parameters()
if cuda:
model = torch.nn.DataParallel(model).cuda()
return model

@classmethod
def load_model_package(cls, package, cuda=False):
def load_model_package(cls, package):
model = cls(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'],
labels=package['labels'], audio_conf=package['audio_conf'],
rnn_type=supported_rnns[package['rnn_type']], bidirectional=package.get('bidirectional', True))
model.load_state_dict(package['state_dict'])
if cuda:
model = torch.nn.DataParallel(model).cuda()
return model

@staticmethod
def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=None,
cer_results=None, wer_results=None, avg_loss=None, meta=None):
model_is_cuda = next(model.parameters()).is_cuda
model = model.module if model_is_cuda else model
model = model.module if DeepSpeech.is_parallel(model) else model
package = {
'version': model._version,
'hidden_size': model._hidden_size,
@@ -294,8 +291,7 @@ def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=No

@staticmethod
def get_labels(model):
model_is_cuda = next(model.parameters()).is_cuda
return model.module._labels if model_is_cuda else model._labels
return model.module._labels if model.is_parallel(model) else model._labels

@staticmethod
def get_param_size(model):
@@ -309,13 +305,11 @@ def get_param_size(model):

@staticmethod
def get_audio_conf(model):
model_is_cuda = next(model.parameters()).is_cuda
return model.module._audio_conf if model_is_cuda else model._audio_conf
return model.module._audio_conf if DeepSpeech.is_parallel(model) else model._audio_conf

@staticmethod
def get_meta(model):
model_is_cuda = next(model.parameters()).is_cuda
m = model.module if model_is_cuda else model
m = model.module if DeepSpeech.is_parallel(model) else model
meta = {
"version": m._version,
"hidden_size": m._hidden_size,
@@ -324,6 +318,11 @@ def get_meta(model):
}
return meta

@staticmethod
def is_parallel(model):
return isinstance(model, torch.nn.parallel.DataParallel) or \
isinstance(model, torch.nn.parallel.DistributedDataParallel)


if __name__ == '__main__':
import os.path
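
The new is_parallel helper replaces the next(model.parameters()).is_cuda probe that serialize and the get_* accessors used before: a model is unwrapped via .module only when it really is a DataParallel or DistributedDataParallel wrapper, which matters now that a CUDA model is no longer automatically wrapped (see the train.py changes below). A stand-alone sketch of the check, using a generic module rather than DeepSpeech:

import torch.nn as nn

def is_parallel(model):
    # the same type test the new static method performs
    return isinstance(model, (nn.parallel.DataParallel,
                              nn.parallel.DistributedDataParallel))

plain = nn.Linear(10, 10)           # never wrapped: attributes live on the model itself
wrapped = nn.DataParallel(plain)    # wrapped: the real module lives at wrapped.module

print(is_parallel(plain))           # False
print(is_parallel(wrapped))         # True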
8 changes: 5 additions & 3 deletions test.py
@@ -38,7 +38,9 @@

if __name__ == '__main__':
torch.set_grad_enabled(False)
model = DeepSpeech.load_model(args.model_path, cuda=args.cuda)
model = DeepSpeech.load_model(args.model_path)
if args.cuda:
model.cuda()
model.eval()

labels = DeepSpeech.get_labels(model)
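
load_model no longer takes a cuda flag or wraps the model in DataParallel, so each entry point makes device placement explicit, as above. A sketch of the new loading contract; use_cuda stands in for the script's --cuda flag and the checkpoint path is a placeholder:

import torch
from model import DeepSpeech   # model.py from this repository

use_cuda = torch.cuda.is_available()                           # stand-in for args.cuda
model = DeepSpeech.load_model('models/deepspeech_final.pth')   # placeholder path; loads onto the CPU
if use_cuda:
    model = model.cuda()                     # GPU placement is now the caller's decision
    # model = torch.nn.DataParallel(model)   # multi-GPU wrapping is opt-in, not implicit
model.eval()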
@@ -63,7 +65,7 @@
output_data = []
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
inputs, targets, input_percentages, target_sizes = data
input_sizes = Variable(input_percentages.mul_(int(inputs.size(3))).int(), requires_grad=False)
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

# unflatten targets
split_targets = []
@@ -79,7 +81,7 @@

if decoder is None:
# add output to data array, and continue
output_data.append((out.data.cpu().numpy(), output_sizes.data.cpu().numpy()))
output_data.append((out.numpy(), output_sizes.numpy()))
continue

decoded_output, _ = decoder.decode(out.data, output_sizes.data)
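
With torch.set_grad_enabled(False) at the top of the script and the tensor/Variable merge in PyTorch 0.4, the Variable wrapper around the length computation is unnecessary. input_percentages holds, for each utterance, the fraction of the padded time axis that contains real audio, so multiplying by the padded width recovers per-sample lengths; a small numeric sketch with illustrative shapes:

import torch

inputs = torch.randn(4, 1, 161, 200)                       # (batch, channel, freq bins, padded time frames)
input_percentages = torch.tensor([1.0, 0.75, 0.5, 0.25])   # fraction of the 200 frames that is real audio
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
print(input_sizes)                                          # tensor([200, 150, 100,  50], dtype=torch.int32)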
47 changes: 7 additions & 40 deletions train.py
@@ -1,5 +1,4 @@
import argparse
import errno
import json
import os
import time
@@ -10,7 +9,6 @@
from warpctc_pytorch import CTCLoss

from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler, DistributedBucketingSampler
from data.distributed import DistributedDataParallel
from decoder import GreedyDecoder
from model import DeepSpeech, supported_rnns

@@ -125,32 +123,11 @@ def update(self, val, n=1):
viz_window = None
epochs = torch.arange(1, args.epochs + 1)
if args.tensorboard and main_proc:
try:
os.makedirs(args.log_dir)
except OSError as e:
if e.errno == errno.EEXIST:
print('Tensorboard log directory already exists.')
for file in os.listdir(args.log_dir):
file_path = os.path.join(args.log_dir, file)
try:
if os.path.isfile(file_path):
os.unlink(file_path)
except Exception:
raise
else:
raise
os.makedirs(args.log_dir, exist_ok=True)
from tensorboardX import SummaryWriter

tensorboard_writer = SummaryWriter(args.log_dir)

try:
os.makedirs(save_folder)
except OSError as e:
if e.errno == errno.EEXIST:
print('Model Save directory already exists.')
else:
raise
criterion = CTCLoss()
os.makedirs(save_folder, exist_ok=True)

avg_loss, start_epoch, start_iter = 0, 0, 0
if args.continue_from: # Starting from previous model
@@ -164,15 +141,6 @@ def update(self, val, n=1):
momentum=args.momentum, nesterov=True)
if not args.finetune: # Don't want to restart training
optimizer.load_state_dict(package['optim_dict'])

# Temporary fix for pytorch #2830 & #1442 while pull request #3658 in not incorporated in a release
# TODO : remove when a new release of pytorch include pull request #3658
if args.cuda:
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()

start_epoch = int(package.get('epoch', 1)) - 1 # Index start at 0 for training
start_iter = package.get('iteration', None)
if start_iter is None:
@@ -228,7 +196,7 @@ def update(self, val, n=1):
parameters = model.parameters()
optimizer = torch.optim.SGD(parameters, lr=args.lr,
momentum=args.momentum, nesterov=True)

criterion = CTCLoss()
decoder = GreedyDecoder(labels)
train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels,
normalize=True, augment=args.augment)
@@ -248,12 +216,11 @@ def update(self, val, n=1):
print("Shuffling batches for the following epochs")
train_sampler.shuffle(start_epoch)

if args.cuda and not args.distributed:
model = torch.nn.DataParallel(model, device_ids=args.device_ids).cuda()
elif args.cuda and args.distributed:
if args.cuda:
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=(int(args.gpu_rank),) if args.rank else None)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=(int(args.gpu_rank),) if args.rank else None)

print(model)
print("Number of parameters: %d" % DeepSpeech.get_param_size(model))
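
The single-process DataParallel branch is gone from training: the model is moved to the GPU whenever --cuda is set, and only wrapped when --distributed is set as well. A sketch of the resulting placement logic, with stand-in values for the argparse flags and a placeholder module in place of DeepSpeech; the wrapper also assumes torch.distributed.init_process_group has been called earlier in the script:

import torch
import torch.nn as nn

use_cuda, distributed, gpu_rank = True, True, 0   # stand-ins for args.cuda / args.distributed / args.gpu_rank
model = nn.Linear(161, 29)                        # placeholder for the DeepSpeech model

if use_cuda:
    model = model.cuda()                          # place parameters on the GPU first
if distributed:
    # synchronises gradients across processes; needs an initialised process group
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_rank])

Because serialize and the get_* accessors now unwrap through is_parallel, the same checkpointing code works whether or not this wrapper is applied.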
4 changes: 3 additions & 1 deletion transcribe.py
@@ -69,7 +69,9 @@ def decode_results(model, decoded_output, decoded_offsets):

if __name__ == '__main__':
torch.set_grad_enabled(False)
model = DeepSpeech.load_model(args.model_path, cuda=args.cuda)
model = DeepSpeech.load_model(args.model_path)
if args.cuda:
model.cuda()
model.eval()

labels = DeepSpeech.get_labels(model)
4 changes: 3 additions & 1 deletion tune_decoder.py
@@ -83,7 +83,9 @@ def decode_dataset(logits, test_dataset, batch_size, lm_alpha, lm_beta, mesh_x,
print("error: LM must be provided for tuning")
sys.exit(1)

model = DeepSpeech.load_model(args.model_path, cuda=False)
model = DeepSpeech.load_model(args.model_path)
if args.cuda:
model.cuda()
model.eval()

labels = DeepSpeech.get_labels(model)
