Changes to support variable length
SeanNaren authored and sean.narenthiran committed Jul 11, 2018
1 parent 655cd58 commit 80a060f
Showing 5 changed files with 129 additions and 96 deletions.
1 change: 1 addition & 0 deletions data/data_loader.py
@@ -175,6 +175,7 @@ def _collate_fn(batch):
def func(p):
return p[0].size(1)

batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
longest_sample = max(batch, key=func)[0]
freq_size = longest_sample.size(0)
minibatch_size = len(batch)
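The new sort is what makes the rest of this commit work: nn.utils.rnn.pack_padded_sequence (used in model.py below) requires the batch ordered longest-first. A toy sketch of the effect, with invented spectrogram shapes (freq x time):

import torch

# Two (spectrogram, transcript) pairs; only the time axis differs.
batch = [(torch.randn(161, 90), "short"), (torch.randn(161, 210), "long")]
batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
longest_sample = max(batch, key=lambda sample: sample[0].size(1))[0]
assert batch[0][0] is longest_sample  # after the sort, index 0 is longest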
70 changes: 57 additions & 13 deletions model.py
@@ -39,6 +39,31 @@ def __repr__(self):
return tmpstr


class MaskConv(nn.Module):
def __init__(self, seq_module):
"""
Masks (zeroes out) the padded region of each inner module's output, based on the given lengths.
Input needs to be in the shape of (BxCxDxT).
:param seq_module: The sequential module containing the conv stack.
"""
super(MaskConv, self).__init__()
self.seq_module = seq_module

def forward(self, x, lengths):
"""
:param x: The input of size BxCxDxT
:param lengths: The actual length of each sequence in the batch
:return: Masked output from the module
"""
for module in self.seq_module:
x = module(x)
for i, length in enumerate(lengths):
length = length.item()
if (x[i].size(2) - length) > 0:
x[i].narrow(2, length, x[i].size(2) - length).fill_(0)
return x, lengths
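For intuition, the masking step in isolation: narrow(2, length, T - length) returns an in-place view of the time steps past each sequence's true length, and fill_(0) zeroes them, so the garbage that padded convolutions produce there cannot leak into later layers. Shapes below are invented:

import torch

x = torch.ones(2, 1, 4, 10)        # B x C x D x T, ones for visibility
lengths = torch.tensor([10, 6])    # true time length per batch element
for i, length in enumerate(lengths):
    length = length.item()
    if (x[i].size(2) - length) > 0:
        x[i].narrow(2, length, x[i].size(2) - length).fill_(0)
assert x[1, :, :, 6:].sum().item() == 0  # steps past length 6 are zeroed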


class InferenceBatchSoftmax(nn.Module):
def forward(self, input_):
if not self.training:
@@ -61,12 +86,15 @@ def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=Fals
def flatten_parameters(self):
self.rnn.flatten_parameters()

def forward(self, x):
def forward(self, x, output_lengths):
if self.batch_norm is not None:
x = self.batch_norm(x)
x, _ = self.rnn(x)
x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
x, h = self.rnn(x)
if self.bidirectional:
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum
x = x._replace(
data=x.data[:, :self.hidden_size] + x.data[:, self.hidden_size:]) # sum bidirectional outputs
x, _ = nn.utils.rnn.pad_packed_sequence(x)
return x
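A rough standalone equivalent of the new forward path (PyTorch 0.4-era API assumed, sizes invented): packing makes the LSTM skip padded steps entirely, and the _replace call sums the forward and backward halves directly on the packed data, so the layer's output width stays hidden_size:

import torch
import torch.nn as nn

hidden_size = 8
rnn = nn.LSTM(input_size=4, hidden_size=hidden_size, bidirectional=True)
x = torch.randn(12, 3, 4)                    # T x N x H_in
lengths = [12, 9, 5]                         # sorted descending, as collated
x = nn.utils.rnn.pack_padded_sequence(x, lengths)
x, h = rnn(x)                                # x.data: sum(lengths) x 2*hidden
x = x._replace(data=x.data[:, :hidden_size] + x.data[:, hidden_size:])
x, _ = nn.utils.rnn.pad_packed_sequence(x)   # back to T x N x hidden_size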


@@ -130,18 +158,18 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
window_size = self._audio_conf.get("window_size", 0.02)
num_classes = len(self._labels)

self.conv = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(0, 10)),
self.conv = MaskConv(nn.Sequential(
nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
nn.BatchNorm2d(32),
nn.Hardtanh(0, 20, inplace=True),
nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), ),
nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
nn.BatchNorm2d(32),
nn.Hardtanh(0, 20, inplace=True)
)
))
# Based on the above convolutions and spectrogram size, using the conv formula (W - F + 2P) / S + 1
rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
rnn_input_size = int(math.floor(rnn_input_size - 41) / 2 + 1)
rnn_input_size = int(math.floor(rnn_input_size - 21) / 2 + 1)
rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
rnn_input_size *= 32
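As a worked check of that arithmetic (assuming the repo's usual defaults of sample_rate=16000 and window_size=0.02), the new padding terms give the RNN stack 1312 input features:

rnn_input_size = int(16000 * 0.02 / 2) + 1                 # 161 freq bins
rnn_input_size = (rnn_input_size + 2 * 20 - 41) // 2 + 1   # conv1 -> 81
rnn_input_size = (rnn_input_size + 2 * 10 - 21) // 2 + 1   # conv2 -> 41
rnn_input_size *= 32                                       # x32 channels -> 1312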

rnns = []
@@ -168,14 +196,17 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
)
self.inference_softmax = InferenceBatchSoftmax()

def forward(self, x):
x = self.conv(x)
def forward(self, x, lengths):
lengths = lengths.cpu().int()
output_lengths = self.get_seq_lens(lengths)
x, _ = self.conv(x, output_lengths)

sizes = x.size()
x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension
x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH

x = self.rnns(x)
for rnn in self.rnns:
x = rnn(x, output_lengths.numpy())

if not self._bidirectional: # no need for lookahead layer in bidirectional
x = self.lookahead(x)
@@ -184,7 +215,20 @@ def forward(self, x):
x = x.transpose(0, 1)
# identity in training mode, softmax in eval mode
x = self.inference_softmax(x)
return x
return x, output_lengths
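The reshape in the middle of forward collapses channels and frequency into a single feature axis, then moves time to the front for the RNNs; with invented sizes:

import torch

x = torch.randn(3, 32, 41, 50)                        # B x C x D x T
sizes = x.size()
x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])   # B x (C*D) x T
x = x.transpose(1, 2).transpose(0, 1).contiguous()    # T x B x (C*D)
assert x.size() == (50, 3, 1312)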

def get_seq_lens(self, input_length):
"""
Given a 1D Tensor or Variable of integer input sequence lengths, return a 1D Tensor of the
corresponding sequence lengths that the network will output.
:param input_length: 1D Tensor
:return: 1D Tensor scaled by model
"""
seq_len = input_length
for m in self.conv.modules():
if type(m) == nn.modules.conv.Conv2d:
seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) / m.stride[1] + 1)
return seq_len.int()
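A worked instance with the time-axis parameters of the two convolutions defined above (kernel 11, padding 5, dilation 1; strides 2 then 1): a 200-frame utterance yields 100 valid output time steps for CTC:

def conv_out_len(t, kernel, stride, padding, dilation=1):
    # Same arithmetic as get_seq_lens, for one Conv2d time axis.
    return (t + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

t = conv_out_len(200, kernel=11, stride=2, padding=5)   # conv1: 200 -> 100
t = conv_out_len(t, kernel=11, stride=1, padding=5)     # conv2: 100 -> 100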

@classmethod
def load_model(cls, path, cuda=False):
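Putting the new signature together, a hypothetical call (sizes invented; `model` is assumed to be a constructed DeepSpeech instance):

import torch

inputs = torch.randn(4, 1, 161, 300)                  # B x 1 x freq x time
input_sizes = torch.IntTensor([300, 260, 220, 180])   # sorted descending
out, output_sizes = model(inputs, input_sizes)        # out: B x T' x classes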
9 changes: 4 additions & 5 deletions test.py
@@ -63,6 +63,7 @@
output_data = []
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
inputs, targets, input_percentages, target_sizes = data
input_sizes = Variable(input_percentages.mul_(int(inputs.size(3))).int(), requires_grad=False)

# unflatten targets
split_targets = []
@@ -74,16 +75,14 @@
if args.cuda:
inputs = inputs.cuda()

out = model(inputs) # NxTxH
seq_length = out.size(1)
sizes = input_percentages.mul_(int(seq_length)).int()
out, output_sizes = model(inputs, input_sizes)

if decoder is None:
# add output to data array, and continue
output_data.append((out.data.cpu().numpy(), sizes.numpy()))
output_data.append((out.data.cpu().numpy(), output_sizes.data.cpu().numpy()))
continue

decoded_output, _, = decoder.decode(out.data, sizes)
decoded_output, _ = decoder.decode(out.data, output_sizes.data)
target_strings = target_decoder.convert_to_strings(split_targets)
for x in range(len(target_strings)):
transcript, reference = decoded_output[x][0], target_strings[x][0]
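input_percentages come out of the collate function as each utterance's share of the padded width, so multiplying by inputs.size(3) recovers absolute frame counts before the model is called (rather than after, as the removed lines did); a toy instance:

import torch

padded_width = 300                                         # inputs.size(3)
input_percentages = torch.FloatTensor([1.00, 0.75, 0.50])
input_sizes = input_percentages.mul_(padded_width).int()   # 300, 225, 150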
137 changes: 61 additions & 76 deletions train.py
@@ -6,7 +6,6 @@

import torch.distributed as dist
import torch.utils.data.distributed
from torch.autograd import Variable
from tqdm import tqdm
from warpctc_pytorch import CTCLoss

@@ -253,7 +252,8 @@ def update(self, val, n=1):
model = torch.nn.DataParallel(model, device_ids=args.device_ids).cuda()
elif args.cuda and args.distributed:
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=(int(args.gpu_rank),) if args.rank else None)
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=(int(args.gpu_rank),) if args.rank else None)

print(model)
print("Number of parameters: %d" % DeepSpeech.get_param_size(model))
@@ -270,22 +270,17 @@ def update(self, val, n=1):
if i == len(train_sampler):
break
inputs, targets, input_percentages, target_sizes = data
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
# measure data loading time
data_time.update(time.time() - end)
inputs = Variable(inputs, requires_grad=False)
target_sizes = Variable(target_sizes, requires_grad=False)
targets = Variable(targets, requires_grad=False)

if args.cuda:
inputs = inputs.cuda()

out = model(inputs)
out, output_sizes = model(inputs, input_sizes)
out = out.transpose(0, 1) # TxNxH

seq_length = out.size(0)
sizes = Variable(input_percentages.mul_(int(seq_length)).int(), requires_grad=False)

loss = criterion(out, targets, sizes, target_sizes)
loss = criterion(out, targets, output_sizes, target_sizes)
loss = loss / inputs.size(0) # average the loss by minibatch

loss_sum = loss.data.sum()
@@ -307,9 +302,6 @@ def update(self, val, n=1):
# SGD step
optimizer.step()

if args.cuda:
torch.cuda.synchronize()

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
@@ -318,8 +310,7 @@ def update(self, val, n=1):
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
(epoch + 1), (i + 1), len(train_sampler), batch_time=batch_time,
data_time=data_time, loss=losses))
(epoch + 1), (i + 1), len(train_sampler), batch_time=batch_time, data_time=data_time, loss=losses))
if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc:
file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (save_folder, epoch + 1, i + 1)
print("Saving checkpoint model to %s" % file_path)
@@ -334,15 +325,15 @@ def update(self, val, n=1):
epoch_time = time.time() - start_epoch_time
print('Training Summary Epoch: [{0}]\t'
'Time taken (s): {epoch_time:.0f}\t'
'Average Loss {loss:.3f}\t'.format(
epoch + 1, epoch_time=epoch_time, loss=avg_loss))
'Average Loss {loss:.3f}\t'.format(epoch + 1, epoch_time=epoch_time, loss=avg_loss))

start_iter = 0 # Reset start iteration for next epoch
total_cer, total_wer = 0, 0
model.eval()
with torch.no_grad():
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
inputs, targets, input_percentages, target_sizes = data
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

# unflatten targets
split_targets = []
@@ -354,11 +345,9 @@ def update(self, val, n=1):
if args.cuda:
inputs = inputs.cuda()

out = model(inputs) # NxTxH
seq_length = out.size(1)
sizes = input_percentages.mul_(int(seq_length)).int()
out, output_sizes = model(inputs, input_sizes)

decoded_output, _ = decoder.decode(out.data, sizes)
decoded_output, _ = decoder.decode(out.data, output_sizes)
target_strings = decoder.convert_to_strings(split_targets)
wer, cer = 0, 0
for x in range(len(target_strings)):
@@ -367,9 +356,6 @@ def update(self, val, n=1):
cer += decoder.cer(transcript, reference) / float(len(reference))
total_cer += cer
total_wer += wer

if args.cuda:
torch.cuda.synchronize()
del out
wer = total_wer / len(test_loader.dataset)
cer = total_cer / len(test_loader.dataset)
@@ -380,56 +366,55 @@ def update(self, val, n=1):
cer_results[epoch] = cer
print('Validation Summary Epoch: [{0}]\t'
'Average WER {wer:.3f}\t'
'Average CER {cer:.3f}\t'.format(
epoch + 1, wer=wer, cer=cer))
'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

if args.visdom and main_proc:
x_axis = epochs[0:epoch + 1]
y_axis = torch.stack((loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
if viz_window is None:
viz_window = viz.line(
X=x_axis,
Y=y_axis,
opts=opts,
)
else:
viz.line(
X=x_axis.unsqueeze(0).expand(y_axis.size(1), x_axis.size(0)).transpose(0, 1), # Visdom fix
Y=y_axis,
win=viz_window,
update='replace',
)
if args.tensorboard and main_proc:
values = {
'Avg Train Loss': avg_loss,
'Avg WER': wer,
'Avg CER': cer
}
tensorboard_writer.add_scalars(args.id, values, epoch + 1)
if args.log_params:
for tag, value in model.named_parameters():
tag = tag.replace('.', '/')
tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1)
tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1)
if args.checkpoint and main_proc:
file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
wer_results=wer_results, cer_results=cer_results),
file_path)
# anneal lr
optim_state = optimizer.state_dict()
optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
optimizer.load_state_dict(optim_state)
print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

if (best_wer is None or best_wer > wer) and main_proc:
print("Found better validated model, saving to %s" % args.model_path)
torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
wer_results=wer_results, cer_results=cer_results)
, args.model_path)
best_wer = wer

avg_loss = 0
if not args.no_shuffle:
print("Shuffling batches...")
train_sampler.shuffle(epoch)
if args.visdom and main_proc:
x_axis = epochs[0:epoch + 1]
y_axis = torch.stack(
(loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
if viz_window is None:
viz_window = viz.line(
X=x_axis,
Y=y_axis,
opts=opts,
)
else:
viz.line(
X=x_axis.unsqueeze(0).expand(y_axis.size(1), x_axis.size(0)).transpose(0, 1), # Visdom fix
Y=y_axis,
win=viz_window,
update='replace',
)
if args.tensorboard and main_proc:
values = {
'Avg Train Loss': avg_loss,
'Avg WER': wer,
'Avg CER': cer
}
tensorboard_writer.add_scalars(args.id, values, epoch + 1)
if args.log_params:
for tag, value in model.named_parameters():
tag = tag.replace('.', '/')
tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1)
tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1)
if args.checkpoint and main_proc:
file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
wer_results=wer_results, cer_results=cer_results),
file_path)
# anneal lr
optim_state = optimizer.state_dict()
optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
optimizer.load_state_dict(optim_state)
print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

if (best_wer is None or best_wer > wer) and main_proc:
print("Found better validated model, saving to %s" % args.model_path)
torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
wer_results=wer_results, cer_results=cer_results), args.model_path)
best_wer = wer

avg_loss = 0
if not args.no_shuffle:
print("Shuffling batches...")
train_sampler.shuffle(epoch)
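The loss call now hands the model's own output_sizes straight to warp-ctc instead of rescaling percentages by the padded output width. A minimal sketch of that call (warpctc_pytorch assumed, as imported above; shapes invented):

import torch
from warpctc_pytorch import CTCLoss

criterion = CTCLoss()
out = torch.randn(100, 2, 29)               # T x N x num_classes, pre-softmax
targets = torch.IntTensor([1, 2, 3, 4, 5])  # both transcripts, concatenated
target_sizes = torch.IntTensor([3, 2])      # i.e. [1, 2, 3] and [4, 5]
output_sizes = torch.IntTensor([100, 80])   # valid time steps per utterance
loss = criterion(out, targets, output_sizes, target_sizes)
loss = loss / out.size(1)                   # average over the minibatch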
8 changes: 6 additions & 2 deletions transcribe.py
@@ -1,6 +1,8 @@
import argparse
import warnings

import torch

warnings.simplefilter('ignore')

from decoder import GreedyDecoder
@@ -86,6 +88,8 @@ def decode_results(model, decoded_output, decoded_offsets):

spect = parser.parse_audio(args.audio_path).contiguous()
spect = spect.view(1, 1, spect.size(0), spect.size(1))
out = model(spect)
decoded_output, decoded_offsets = decoder.decode(out.data)
input_sizes = torch.IntTensor([spect.size(3)]).int()
out, output_sizes = model(spect, input_sizes)
decoded_output, decoded_offsets = decoder.decode(out, output_sizes)
print(json.dumps(decode_results(model, decoded_output, decoded_offsets)))
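With a single utterance there is no padding, so the lengths tensor is simply the full spectrogram width, and output_sizes tells the decoder how many returned time steps are valid. Note that decoder.decode expects the same N x T x H layout that test.py and train.py pass, so the model output needs no transpose before decoding. A toy restatement (the model call is commented out since it needs a loaded model):

import torch

spect = torch.randn(1, 1, 161, 240)               # 1 x 1 x freq x time
input_sizes = torch.IntTensor([spect.size(3)])    # [240]: the full width
# out, output_sizes = model(spect, input_sizes)   # out: 1 x T' x num_classes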
