Skip to content

Commit

Permalink
Added result for default (char) setup
Browse files Browse the repository at this point in the history
  • Loading branch information
iankur committed Jul 24, 2020
1 parent cae3bef commit 61d9fab
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 14 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
## ContextNet: Improving Convolutional Neural Networks for Automatic Speech Recognition with Global Context

:construction: This repository contains TF2.x based implementation for [this paper](https://arxiv.org/pdf/2005.03191.pdf).
This repository contains TF2.x based implementation for [this paper](https://arxiv.org/pdf/2005.03191.pdf). The default setup, which is a character-based model, achieves **11.66%** and **28.31%** WERs on LibriSpeech test-clean and test-other sets respectively. These WERs can easily be improved by using:
* large vocabulary ([subword unit](https://arxiv.org/abs/1508.07909) is one way to achieve this)
* data augmentation ([SpecAugment](https://arxiv.org/abs/1904.08779) is one such technique)
* regularization or limiting model capacity

### Dependencies:
* Pysoundfile
Expand Down
2 changes: 1 addition & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _get_output_sequence(vocab, transcript):
labels = [vocab[char] if char in vocab else vocab['<unk>'] for char in transcript]
return np.array(labels, dtype=np.int32), np.array(len(labels), dtype=np.int32)

def create_dataset(librispeech_dir, data_key, vocab, mean=None, std_dev=None, batch_size=1, num_feats=40):
def create_dataset(librispeech_dir, data_key, vocab, mean=None, std_dev=None, num_feats=40):
""" librispeech_dir (str): path to directory containing librispeech data
data_key (str) : train / dev / test
mean (str|None) : path to file containing mean of librispeech training data
Expand Down
15 changes: 7 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create_conv_blocks():

# C1-2 : 5 conv layers, 256 output channels, strides 1
blocks.append(ConvBlock([256//8, 256], 5, 256, 5, 1))
blocks.append(ConvBlock([256//8, 256], 1, 256, 5, 1))
blocks.append(ConvBlock([256//8, 256], 5, 256, 5, 1))

# C3 : 5 conv layers, 256 output channels, strides 2
blocks.append(ConvBlock([256//8, 256], 5, 256, 5, 2))
Expand All @@ -44,8 +44,8 @@ def create_conv_blocks():
for i in range(15, 21+1):
blocks.append(ConvBlock([512//8, 512], 5, 512, 5, 1))

# C22 : 1 conv layers, 640 output channels, strides 1
blocks.append(ConvBlock([640//8, 640], 5, 640, 5, 1, residual=False))
# C22 : 1 conv layer, 640 output channels, strides 1
blocks.append(ConvBlock([640//8, 640], 1, 640, 5, 1, residual=False))

return blocks

Expand Down Expand Up @@ -80,14 +80,14 @@ def get_config(self):
return tf.keras.optimizers.Adam(learning_rate=lr)

def train(num_units, num_vocab, num_lstms, lstm_units, out_dim,
lr, batch_size, num_epochs, data_path, vocab, mean, std_dev, num_features):
lr, num_epochs, data_path, vocab, mean, std_dev, num_features):
model = create_model(num_units=num_units, num_vocab=num_vocab,
num_lstms=num_lstms, lstm_units=lstm_units, out_dim=out_dim)

dev_dataset = create_dataset(data_path, "dev", vocab,
mean, std_dev, batch_size, num_features)
mean, std_dev, num_features)
train_dataset = create_dataset(data_path, "train", vocab,
mean, std_dev, batch_size, num_features)
mean, std_dev, num_features)

step = tf.Variable(1)
optimizer = create_optimizer(lr)
Expand Down Expand Up @@ -165,8 +165,7 @@ def train_step(x, y, x_len, y_len):
parser.add_argument("--out_dim", type=int, default=640, help="Label encoder output size")

# Optimization arguments
parser.add_argument("--lr", type=float, default=0.0025, help="Learning rate")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--lr", type=float, default=0.0015, help="Learning rate")
parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs")

# Train / validation data
Expand Down
4 changes: 0 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
########## Dependencies ##############
# librosa : pip install librosa
######################################

# Taken from https://github.com/rwth-i6/returnn/blob/master/GeneratingDataset.py
def _get_audio_features_mfcc(audio, sample_rate, window_len=0.025, step_len=0.010, num_feature_filters=40):
"""
Expand Down

1 comment on commit 61d9fab

@Wikidepia
Copy link

@Wikidepia Wikidepia commented on 61d9fab Apr 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, do you have any wandb or log when you train this model?
Thanks

Please sign in to comment.