
Commit

Added final eval code and command to eval SOTA model
MiscellaneousStuff committed May 5, 2022
1 parent b01e1d7 commit e3c9ed8
Showing 2 changed files with 47 additions and 2 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -51,4 +51,24 @@ You can clone this repository by doing the following:
git clone https://github.com/MiscellaneousStuff/semg_asr.git
git submodule init
git submodule update
```

## Evaluate

To evaluate the best trained model released with the report, run the
following command:

```bash
python3 evaluate.py \
--checkpoint_path "path_to_pretrained_model/ds2_DATASET_SILENT_SPEECH_EPOCHS_10_TEST_LOSS_1.8498832106590273_WER_0.6825681123095443" \
--dataset_path "path_to_dataset.csv" \
--semg_eval
```

A large number of models and datasets are evaluated in the report. To see
the full list of evaluation options and how to run them, run:

```bash
python3 evaluate.py --help
```
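
For example, a hypothetical run restricted to the closed-vocabulary slice of
the dataset might look like the following (the checkpoint and dataset paths
are placeholders; substitute your own):

```bash
python3 evaluate.py \
    --checkpoint_path "path_to_pretrained_model/your_checkpoint" \
    --dataset_path "path_to_dataset.csv" \
    --semg_eval \
    --closed_only \
    --batch_size 5
```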
29 changes: 27 additions & 2 deletions evaluate.py
@@ -53,6 +53,8 @@
flags.DEFINE_integer("batch_size", 5, "Sets the batch size for the evaluation")
flags.DEFINE_boolean("closed_only", False, \
"(Optional) Evaluate only on the closed vocabulary slice of the dataset")
flags.DEFINE_integer("print_top", 3, \
"(Optional) Set number of most accurate predictions to print")
flags.mark_flag_as_required("checkpoint_path")
flags.mark_flag_as_required("dataset_path")

@@ -61,6 +63,9 @@ def evaluate(model, test_loader, device, criterion, encoder):
    test_loss = 0
    test_cer, test_wer = [], []

    # Collect (target, prediction, WER) triples so the most accurate
    # transcriptions can be printed after evaluation
    scored_preds = []

    with torch.no_grad():
        for i, _data in enumerate(test_loader):
            spectrograms, labels, input_lengths, label_lengths = _data
@@ -78,8 +83,14 @@ def evaluate(model, test_loader, device, criterion, encoder):
                output.transpose(0, 1), labels, label_lengths, encoder)

            for j in range(len(decoded_preds)):
-                test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
-                test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+                cur_ground = decoded_targets[j]
+                cur_pred = decoded_preds[j]
+                cur_wer = wer(cur_ground, cur_pred)
+                cur_cer = cer(cur_ground, cur_pred)
+                test_cer.append(cur_cer)
+                test_wer.append(cur_wer)
+
+                scored_preds.append([cur_ground, cur_pred, cur_wer])

    avg_cer = sum(test_cer) / len(test_cer)
    avg_wer = sum(test_wer) / len(test_wer)
@@ -88,6 +99,20 @@ def evaluate(model, test_loader, device, criterion, encoder):
        'Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n'\
        .format(test_loss, avg_cer, avg_wer))

    # Sort predictions by WER so the most accurate transcriptions come first
    sorted_preds = sorted(scored_preds, key=lambda pred: pred[2])

    print_top = FLAGS.print_top
    if print_top > 0:
        # Keep only the print_top lowest-WER predictions
        sorted_preds = sorted_preds[0:min(print_top, len(sorted_preds))]
        for i, pred in enumerate(sorted_preds):
            # Get rid of unknown tokens in final output
            ground = pred[0].replace("<unk>", "")
            prediction = pred[1].replace("<unk>", "")

            score = pred[2]
            print(f"{i+1}.\n Target: {ground}\n Prediction: {prediction}\n WER: {score:.4f}")

def main(unused_argv):
    checkpoint_path = FLAGS.checkpoint_path
    semg_eval = FLAGS.semg_eval
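As a usage note beyond this diff, the new `--print_top` flag could be
exercised from the command line roughly as follows (paths are placeholders;
the flag sets how many of the most accurate transcriptions are printed,
defaulting to 3):

```bash
python3 evaluate.py \
    --checkpoint_path "path_to_pretrained_model/your_checkpoint" \
    --dataset_path "path_to_dataset.csv" \
    --semg_eval \
    --print_top 5
```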
