Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Noise inject #35 #53

Merged
merged 3 commits into from
May 8, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added noise injection capabilities, small refactor
  • Loading branch information
SeanNaren committed May 7, 2017
commit 4b198b18f263edb0578163bea3f16ff12a61194d
80 changes: 66 additions & 14 deletions data/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import os
from tempfile import NamedTemporaryFile

import librosa
import numpy as np
import scipy.signal
import torch
import torchaudio
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchaudio
from tempfile import NamedTemporaryFile
import os

windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman,
'bartlett': scipy.signal.bartlett}


def load_audio(path):
sound, _ = torchaudio.load(path.encode('utf-8')) # py3 fix
sound = sound.numpy()
if len(sound.shape) > 1:
if sound.shape[1] == 1:
#single channel, we just squeeze it
sound = sound.squeeze()
else:
#multiple channels, average
sound = sound.mean(axis=1)
sound = sound.mean(axis=1) # multiple channels, average
return sound


class AudioParser(object):
def parse_transcript(self, transcript_path):
"""
Expand All @@ -39,10 +40,52 @@ def parse_audio(self, audio_path):
raise NotImplementedError


class NoiseInjection(object):
def __init__(self,
path=None,
sr=16000,
noise_levels=(0, 0.5)):
"""
Adds noise to an input signal with specific SNR.
Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py
"""
self.paths = librosa.util.find_files(path)
self.sr = sr
self.noise_levels = noise_levels

def inject_noise(self, data):
noise_src = load_audio(np.random.choice(self.paths))
noise_offset_fraction = np.random.rand()
noise_level = np.random.uniform(*self.noise_levels)

noise_dst = np.zeros_like(data)

src_offset = int(len(noise_src) * noise_offset_fraction)
src_left = len(noise_src) - src_offset

dst_offset = 0
dst_left = len(data)

while dst_left > 0:
copy_size = min(dst_left, src_left)
np.copyto(noise_dst[dst_offset:dst_offset + copy_size],
noise_src[src_offset:src_offset + copy_size])
if src_left > dst_left:
dst_left = 0
else:
dst_left -= copy_size
dst_offset += copy_size
src_left = len(noise_src)
src_offset = 0

data += noise_level * noise_dst
return data


class SpectrogramParser(AudioParser):
def __init__(self, audio_conf, normalize=False, augment=False):
"""
Parses audio file into spectrogram with optional normalization
Parses audio file into spectrogram with optional normalization and various augmentations
:param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
:param normalize(default False): Apply standard mean and deviation normalization to audio tensor
:param augment(default False): Apply random tempo and gain perturbations
Expand All @@ -54,12 +97,20 @@ def __init__(self, audio_conf, normalize=False, augment=False):
self.window = windows.get(audio_conf['window'], windows['hamming'])
self.normalize = normalize
self.augment = augment
self.noiseInjector = NoiseInjection(audio_conf['noise_dir'], self.sample_rate,
audio_conf['noise_levels']) if audio_conf.get(
'noise_dir') is not None else None
self.noise_prob = audio_conf.get('noise_prob')

def parse_audio(self, audio_path):
if not self.augment:
y = load_audio(audio_path)
else:
if self.augment:
y = load_randomly_augmented_audio(audio_path)
else:
y = load_audio(audio_path)
if self.noiseInjector:
add_noise = np.random.binomial(1, self.noise_prob)
if add_noise:
y = self.noiseInjector.inject_noise(y)
n_fft = int(self.sample_rate * self.window_size)
win_length = n_fft
hop_length = int(self.sample_rate * self.window_stride)
Expand Down Expand Up @@ -147,7 +198,6 @@ def func(p):
return inputs, targets, input_percentages, target_sizes



class AudioDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
"""
Expand All @@ -156,15 +206,17 @@ def __init__(self, *args, **kwargs):
super(AudioDataLoader, self).__init__(*args, **kwargs)
self.collate_fn = _collate_fn

def augment_audio_with_sox(path, sample_rate, tempo=1.0, gain=0.0):

def augment_audio_with_sox(path, sample_rate, tempo, gain):
"""
Changes tempo and gain of the recording with sox and loads it.
"""
with NamedTemporaryFile(suffix=".wav") as augmented_file:
augmented_filename = augmented_file.name
sox_augment_params = ["tempo", "{:.3f}".format(tempo), "gain", "{:.3f}".format(gain)]
sox_params = "sox {} -r {} -c 1 -b 16 {} {} >/dev/null 2>&1".format(path, sample_rate,
augmented_filename, " ".join(sox_augment_params))
augmented_filename,
" ".join(sox_augment_params))
os.system(sox_params)
y = load_audio(path)
return y
Expand All @@ -182,4 +234,4 @@ def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.
gain_value = np.random.uniform(low=low_gain, high=high_gain)
audio = augment_audio_with_sox(path=path, sample_rate=sample_rate,
tempo=tempo_value, gain=gain_value)
return audio
return audio
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
python-levenshtein
librosa
torch
visdom
wget
12 changes: 11 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
parser.add_argument('--continue_from', default='', help='Continue from checkpoint model')
parser.add_argument('--rnn_type', default='lstm', help='Type of the RNN. rnn|gru|lstm are supported')
parser.add_argument('--augment', dest='augment', action='store_true', help='Use random tempo and gain perturbations.')
parser.add_argument('--noise_dir', default='',
help='Directory to inject noise into audio. If default, noise Inject not added')
parser.add_argument('--noise_prob', default=0.4, help='Probability of noise being added per sample')
parser.add_argument('--noise_min', default=0.0,
help='Minimum noise level to sample from. (1.0 means all noise, not original signal)', type=float)
parser.add_argument('--noise_max', default=0.5,
help='Maximum noise levels to sample from. Maximum 1.0', type=float)
parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False)


Expand Down Expand Up @@ -121,7 +128,10 @@ def main():
audio_conf = dict(sample_rate=args.sample_rate,
window_size=args.window_size,
window_stride=args.window_stride,
window=args.window)
window=args.window,
noise_dir=args.noise_dir,
noise_prob=args.noise_prob,
noise_levels=(args.noise_min, args.noise_max))

train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels,
normalize=True, augment=args.augment)
Expand Down