figure explanation, fine tuned
MiscellaneousStuff committed May 5, 2022
1 parent 94824eb commit d5e96e0
Showing 5 changed files with 114 additions and 8 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -132,4 +132,7 @@ dmypy.json
models/

# Ignore dataset CSV files
*.csv
*.csv

# Ignore testset visualisations
testset_visuals/
3 changes: 2 additions & 1 deletion create_dataset.py
@@ -130,7 +130,8 @@ def main(unused_argv):
if sentence_idx in valid_idxs:
dataset = "valid"

csv_writer.writerow([audio_path, text, dataset, modality])
csv_writer.writerow(\
[audio_path, text, dataset, modality, book, sentence_idx])
def entry_point():
app.run(main)

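For context, each metadata row written above now carries six fields (audio_path, text, dataset, modality, book, sentence_idx), and csv stores every one of them as text. A minimal sketch of how a consumer could read the rows back; the file name metadata.csv is an assumption, not something this commit defines:

import csv

# Hypothetical reader for the six-column rows written by create_dataset.py
with open("metadata.csv", newline="") as f:
    for audio_path, text, dataset, modality, book, sentence_idx in csv.reader(f):
        if dataset == "test":
            print(book, sentence_idx, audio_path)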
22 changes: 18 additions & 4 deletions datasets.py
@@ -39,16 +39,24 @@ def __init__(self, metadata_path, dataset_type=None):
if dataset_type:
for fi in self._flist:
line = fi
_, _, cur_dataset_type, _ = line
_, _, cur_dataset_type, _, _, _ = line
if cur_dataset_type == dataset_type:
fis.append(fi)
else:
Exception("No dataset type specified for SilentSpeech() dataset.""")
self._flist = fis

def get_exact(self, book, sentence_idx):
lines = [fi for fi in self._flist
if fi[-2] == book and fi[-1] == sentence_idx]
line = lines[0]
cur_path, text, dataset_type, _, _, _ = line
waveform, sr = torchaudio.load(cur_path)
return (waveform, sr, text, dataset_type)
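A usage sketch for the new get_exact lookup. Because the CSV fields come back as strings, book and sentence_idx are matched as strings here; the metadata path and key values are placeholders:

# Hypothetical lookup of one ground-truth utterance by its (book, sentence_idx) key
ground = SilentSpeechDataset("ground_metadata.csv", dataset_type="test")
waveform, sr, text, dataset_type = ground.get_exact("some_book", "42")
print(waveform.shape, sr, text)

Note that lines[0] raises an IndexError when no row matches, so the key has to exist in the chosen split.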

def __getitem__(self, n):
line = self._flist[n]
cur_path, text, dataset_type, _ = line
cur_path, text, dataset_type, _, _, _ = line
waveform, sr = torchaudio.load(cur_path)
return (waveform, sr, text, dataset_type)

@@ -70,7 +78,7 @@ def __init__(self, metadata_path, dataset_type=None, \
if dataset_type:
for fi in self._flist:
line = fi
_, _, cur_dataset_type, modality = line
_, _, cur_dataset_type, modality, _, _ = line
if cur_dataset_type == dataset_type:
if silent_only and modality == "silent":
fis.append(fi)
@@ -85,9 +93,15 @@ def __init__(self, metadata_path, dataset_type=None, \

self._flist = fis

def get_item_vis(self, n):
line = self._flist[n]
cur_path, text, dataset_type, _, book, sentence_idx = line
mel_spectrogram = torch.load(cur_path)
return (mel_spectrogram, text, dataset_type, book, sentence_idx)
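And a matching sketch for get_item_vis, which exposes the book and sentence index alongside the predicted mel spectrogram so the ground-truth clip can be fetched with get_exact above; the CSV path is again a placeholder:

# Hypothetical retrieval of a predicted spectrogram together with its ground-truth key
preds = SilentSpeechPredDataset("pred_metadata.csv", dataset_type="test", silent_only=True)
mel, text, dataset_type, book, sentence_idx = preds.get_item_vis(0)
print(book, sentence_idx)  # key into the ground-truth metadata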

def __getitem__(self, n):
line = self._flist[n]
cur_path, text, dataset_type, _ = line
cur_path, text, dataset_type, _, _, _ = line
mel_spectrogram = torch.load(cur_path)
return (mel_spectrogram, text, dataset_type)

2 changes: 1 addition & 1 deletion evaluate.py
@@ -69,7 +69,7 @@ def evaluate(model, test_loader, device, criterion, encoder):

with torch.no_grad():
for i, _data in enumerate(test_loader):
spectrograms, labels, input_lengths, label_lengths = _data
spectrograms, labels, input_lengths, label_lengths = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)

output = model(spectrograms) # (batch, time, n_class)
90 changes: 89 additions & 1 deletion visualise.py
@@ -23,5 +23,93 @@
ground truth audio files along with their predicted mel spectrograms from
the transduction model."""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import librosa

import torch
import torch.utils.data as data

from absl import flags
from absl import app

from datasets import SilentSpeechDataset, SilentSpeechPredDataset
from preprocessing import valid_audio_transforms

FLAGS = flags.FLAGS
flags.DEFINE_string("pred_dataset_path", None, \
"Path to pred *.csv file which defines the dataset to evaluate")
flags.DEFINE_string("ground_dataset_path", None, \
"Path to ground *.csv file which defines the dataset to evaluate")
flags.DEFINE_string("testset_path", None, "Path to transduction model testset.json")
flags.DEFINE_boolean("closed_only", False, \
"(Optional) Evaluate only on the closed vocabulary slice of the dataset")
flags.DEFINE_integer("max_examples", 10, "Number of testset examples to visualise")
flags.mark_flag_as_required("pred_dataset_path")
flags.mark_flag_as_required("ground_dataset_path")

def stack_mel_spectrogram(data):
# Stack the chunks of `data` vertically into a single 2-D array
new_data = data[0]
for i in range(1, data.shape[0]):
new_data = np.vstack((new_data, data[i]))

return new_data
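If data is a 3-D array of per-chunk mel frames, the loop above is equivalent to a single reshape. A small self-check under that assumption, using the numpy import at the top of this module; the shape (3, 10, 80) is arbitrary:

data = np.random.rand(3, 10, 80)   # (chunks, frames, n_mels), example shape only
flat = stack_mel_spectrogram(data)
assert flat.shape == (30, 80)
assert np.array_equal(flat, data.reshape(-1, data.shape[-1]))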

def plot_mel_spectrograms(pred, y, text):
fig, ax = plt.subplots(2)  # two stacked rows: predicted (top), ground truth (bottom)

# ax[0].set_title(f"Mel Spectrogram (Predicted)")
pred = np.swapaxes(pred, 0, 1)
cax = ax[0].imshow(pred, interpolation='nearest', cmap=cm.coolwarm, origin='lower')

# ax[1].set_title(f"Mel Spectrogram (Ground Truth)")
y = np.swapaxes(y, 0, 1)
cax = ax[1].imshow(y, interpolation='nearest', cmap=cm.coolwarm, origin='lower')

return fig, ax
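A quick standalone exercise of the plotting helper with random data, relying on the numpy and matplotlib imports at the top of this module; the shapes are illustrative:

pred = np.random.rand(120, 80)    # (frames, n_mels), placeholder prediction
truth = np.random.rand(120, 80)   # matching placeholder ground truth
fig, ax = plot_mel_spectrograms(pred, truth, "example sentence")
fig.savefig("example_pair.png")

The text argument is currently unused inside the helper; the sentence is only printed later in main.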

def main(unused_argv):
g_dataset_path = FLAGS.ground_dataset_path
p_dataset_path = FLAGS.pred_dataset_path
closed_only = FLAGS.closed_only

# For each predicted (silent) test example, look up the matching ground-truth
# utterance by its (book, sentence_idx) key so the two mel spectrograms can be
# plotted together.

pred_test_dataset = SilentSpeechPredDataset(\
p_dataset_path, dataset_type="test", silent_only=True)

ground_test_dataset = SilentSpeechDataset(\
g_dataset_path, dataset_type="test")

for i in range(min(len(pred_test_dataset) - 1, FLAGS.max_examples)):
p_mel_spectrogram, p_text, \
_, book, sentence_idx = pred_test_dataset.get_item_vis(i)
waveform, sr, g_text, _ = ground_test_dataset.get_exact(book, sentence_idx)

g_mel_spectrogram = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)

g_mel_spectrogram = torch.log(g_mel_spectrogram+1e-5)

fig, ax = plot_mel_spectrograms(\
stack_mel_spectrogram(p_mel_spectrogram),
stack_mel_spectrogram(g_mel_spectrogram),
g_text)

print(g_text, sentence_idx)

ax1 = ax[0].get_window_extent().transformed(fig.dpi_scale_trans.inverted()).expanded(1.2, 1.3)
ax2 = ax[1].get_window_extent().transformed(fig.dpi_scale_trans.inverted()).expanded(1.2, 1.3)
fig.savefig(f"./testset_visuals/{sentence_idx}_p.png", bbox_inches=ax1)
fig.savefig(f"./testset_visuals/{sentence_idx}_g.png", bbox_inches=ax2)

def entry_point():
app.run(main)

if __name__ == "__main__":
    app.run(main)
