From f2c0c2d808020e01902e3fd126ad90f570da30cf Mon Sep 17 00:00:00 2001
From: "sean.narenthiran"
Date: Sat, 20 Jun 2020 18:53:36 +0100
Subject: [PATCH] Integrated Hydra into the training script

---
 config.yaml      |  55 ++++++++++++
 requirements.txt |   3 +-
 train.py         | 225 ++++++++++++++++++-----------------------------
 3 files changed, 142 insertions(+), 141 deletions(-)
 create mode 100644 config.yaml

diff --git a/config.yaml b/config.yaml
new file mode 100644
index 00000000..876d345e
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,55 @@
+training:
+  no_cuda: false # Enable CPU only training
+  finetune: false # Fine-tune the model from checkpoint "continue_from"
+  seed: 123456 # Seed for generators
+  dist_backend: nccl # If using distributed training, the backend to be used
+  epochs: 70 # Number of training epochs
+data:
+  train_manifest: data/train_manifest.csv
+  val_manifest: data/val_manifest.csv
+  sample_rate: 16000 # The sample rate for the data/model features
+  batch_size: 20 # Batch size for training
+  num_workers: 4 # Number of workers used in data-loading
+  labels_path: labels.json # Contains tokens for model output
+  window_size: .02 # Window size for spectrogram generation (seconds)
+  window_stride: .01 # Window stride for spectrogram generation (seconds)
+  window: hamming # Window type for spectrogram generation
+model:
+  rnn_type: lstm # Type of RNN to use in model; rnn/gru/lstm are supported
+  hidden_size: 1024 # Hidden size of RNN layer
+  hidden_layers: 5 # Number of RNN layers
+  bidirectional: true # Use BiRNNs. If false, uses lookahead convolution
+optimizer:
+  learning_rate: 3.0e-4 # Initial learning rate
+  weight_decay: 1.0e-5 # Initial weight decay
+  momentum: 0.9
+  adam: false # Replace SGD with AdamW
+  eps: 1.0e-8 # Adam eps
+  betas: [0.9, 0.999] # Adam betas
+  max_norm: 400 # Norm cutoff to prevent explosion of gradients
+  learning_anneal: 1.1 # Annealing applied to learning rate after each epoch
+checkpointing:
+  continue_from: '' # Continue training from checkpoint model
+  checkpoint: true # Enables epoch checkpoint saving of model
+  checkpoint_per_iteration: 0 # Save a checkpoint every N iterations. Default is disabled
+  save_n_recent_models: 10 # Maximum number of checkpoints to save. If the max is reached, we delete older checkpoints
+  save_folder: models/ # Location to save epoch models
+  best_val_model_name: deepspeech_final.pth # Name to save best validated model within the save folder
+  load_auto_checkpoint: false # Enable when handling interruptions. Automatically load the latest checkpoint from the save folder
+visualization:
+  visdom: false # Turn on visdom graphing
+  tensorboard: false # Turn on Tensorboard graphing
+  log_dir: visualize/deepspeech_final # Location of Tensorboard log
+  log_params: false # Log parameter values and gradients
+  id: DeepSpeech training # Identifier for visdom/tensorboard run
+augmentation:
+  speed_volume_perturb: false # Use random tempo and gain perturbations
+  spec_augment: false # Use simple spectral augmentation on mel spectrograms
+  noise_dir: '' # Directory to inject noise into audio. If empty, noise injection is disabled
+  noise_prob: 0.4 # Probability of noise being added per sample
+  noise_min: 0.0 # Minimum noise level to sample from. (1.0 means all noise, not original signal)
+  noise_max: 0.5 # Maximum noise level to sample from. Maximum 1.0
+apex:
+  opt_level: O1 # Apex optimization level, check https://nvidia.github.io/apex/amp.html for more information
+  loss_scale: 1 # Loss scaling used by Apex. Default is 1 due to warp-ctc not supporting scaling of gradients
+
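Note: the nested groups above map one-to-one onto the dotted attribute access used in train.py below (cfg.training.epochs, cfg.optimizer.betas, and so on). A minimal sketch of that resolution, using OmegaConf (the config library Hydra is built on, installed as a dependency of hydra-core); the printed values are the defaults from this file:

    # sketch.py -- illustrative only, not part of the patch
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")   # the file added above
    print(cfg.training.epochs)            # 70
    print(cfg.optimizer.betas)            # [0.9, 0.999]
    print(cfg.model.rnn_type)             # lstm
    # Keys stay mutable, which main() relies on when it rewrites
    # cfg.checkpointing.continue_from after finding a checkpoint:
    cfg.checkpointing.continue_from = 'models/example.pth'  # hypothetical path
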
diff --git a/requirements.txt b/requirements.txt
index 8c2ae21e..bed87dee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ matplotlib
 flask
 sox
 sklearn
-soundfile
\ No newline at end of file
+soundfile
+hydra-core
\ No newline at end of file
diff --git a/train.py b/train.py
index 5581d51d..49db3fd8 100644
--- a/train.py
+++ b/train.py
@@ -1,15 +1,17 @@
-import argparse
 import json
 import os
 import random
 import time
 
+import hydra
 import numpy as np
 import torch.distributed as dist
 import torch.utils.data.distributed
 from apex import amp
-from warpctc_pytorch import CTCLoss
+from hydra.utils import to_absolute_path
+from omegaconf import DictConfig
 from torch.nn.parallel import DistributedDataParallel
+from warpctc_pytorch import CTCLoss
 
 from data.data_loader import SpectrogramDataset, DSRandomSampler, DSElasticDistributedSampler, AudioDataLoader
 from decoder import GreedyDecoder
@@ -19,75 +21,6 @@
 from test import evaluate
 from utils import check_loss, CheckpointHandler
 
-parser = argparse.ArgumentParser(description='DeepSpeech training')
-parser.add_argument('--train-manifest', metavar='DIR',
-                    help='path to train manifest csv', default='data/train_manifest.csv')
-parser.add_argument('--val-manifest', metavar='DIR',
-                    help='path to validation manifest csv', default='data/val_manifest.csv')
-parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate')
-parser.add_argument('--batch-size', default=20, type=int, help='Batch size for training')
-parser.add_argument('--num-workers', default=4, type=int, help='Number of workers used in data-loading')
-parser.add_argument('--labels-path', default='labels.json', help='Contains all characters for transcription')
-parser.add_argument('--window-size', default=.02, type=float, help='Window size for spectrogram in seconds')
-parser.add_argument('--window-stride', default=.01, type=float, help='Window stride for spectrogram in seconds')
-parser.add_argument('--window', default='hamming', help='Window type for spectrogram generation')
-parser.add_argument('--hidden-size', default=1024, type=int, help='Hidden size of RNNs')
-parser.add_argument('--hidden-layers', default=5, type=int, help='Number of RNN layers')
-parser.add_argument('--rnn-type', default='lstm', help='Type of the RNN. rnn|gru|lstm are supported')
-parser.add_argument('--epochs', default=70, type=int, help='Number of training epochs')
-parser.add_argument('--no-cuda', dest='no_cuda', action='store_true', help='Enable CPU only training')
-parser.add_argument('--lr', '--learning-rate', default=3e-4, type=float, help='initial learning rate')
-parser.add_argument('--wd', '--weight_decay', default=1e-5, type=float, help='Initial weight decay')
-parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
-parser.add_argument('--adam', dest='adam', action='store_true', help='Replace SGD with Adam')
-parser.add_argument('--eps', default=1e-8, type=float, help='ADAM eps')
-parser.add_argument('--betas', default=(0.9, 0.999), nargs='+', help='ADAM betas')
-parser.add_argument('--max-norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients')
-parser.add_argument('--learning-anneal', default=1.1, type=float, help='Annealing applied to learning rate every epoch')
-parser.add_argument('--checkpoint', dest='checkpoint', action='store_true',
-                    help='Enables epoch checkpoint saving of model')
-parser.add_argument('--checkpoint-per-iteration', default=0, type=int,
-                    help='Save checkpoint per N number of iterations. Default is disabled')
-parser.add_argument('--save-n-recent-models', default=0, type=int,
-                    help='Maximum number of checkpoints to save. If the max is reached, we delete older checkpoints.'
-                         'Default is there is no maximum number, so we save all checkpoints.')
-parser.add_argument('--visdom', dest='visdom', action='store_true', help='Turn on visdom graphing')
-parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help='Turn on tensorboard graphing')
-parser.add_argument('--log-dir', default='visualize/deepspeech_final', help='Location of tensorboard log')
-parser.add_argument('--log-params', dest='log_params', action='store_true', help='Log parameter values and gradients')
-parser.add_argument('--id', default='Deepspeech training', help='Identifier for visdom/tensorboard run')
-parser.add_argument('--save-folder', default='models/', help='Location to save epoch models')
-parser.add_argument('--best-val-model-name', default='deepspeech_final.pth',
-                    help='Location to save best validated model within the save folder')
-parser.add_argument('--continue-from', default='', help='Continue from checkpoint model')
-parser.add_argument('--finetune', dest='finetune', action='store_true',
-                    help='Finetune the model from checkpoint "continue_from"')
-parser.add_argument('--speed-volume-perturb', dest='speed_volume_perturb', action='store_true',
-                    help='Use random tempo and gain perturbations.')
-parser.add_argument('--spec-augment', dest='spec_augment', action='store_true',
-                    help='Use simple spectral augmentation on mel spectograms.')
-parser.add_argument('--noise-dir', default=None,
-                    help='Directory to inject noise into audio. If default, noise Inject not added')
-parser.add_argument('--noise-prob', default=0.4, help='Probability of noise being added per sample')
-parser.add_argument('--noise-min', default=0.0,
-                    help='Minimum noise level to sample from. (1.0 means all noise, not original signal)', type=float)
-parser.add_argument('--noise-max', default=0.5,
-                    help='Maximum noise levels to sample from. Maximum 1.0', type=float)
-parser.add_argument('--no-bidirectional', dest='bidirectional', action='store_false', default=True,
-                    help='Turn off bi-directional RNNs, introduces lookahead convolution')
-parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
-parser.add_argument('--load_auto_checkpoint', dest='load_auto_checkpoint', action='store_true',
-                    help='Enable when handling interruptions. Automatically load the latest checkpoint from the '
-                         'save folder')
-parser.add_argument('--seed', default=123456, type=int, help='Seed to generators')
-parser.add_argument('--opt-level', type=str,
-                    help='Apex optimization level,'
-                         'check https://nvidia.github.io/apex/amp.html for more information')
-parser.add_argument('--keep-batchnorm-fp32', type=str, default=None,
-                    help='Overrides Apex keep_batch_norm_fp32 flag')
-parser.add_argument('--loss-scale', default=1,
-                    help='Loss scaling used by Apex. Default is 1 due to warp-ctc not supporting scaling of gradients')
-
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
@@ -108,19 +41,19 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
 
-if __name__ == '__main__':
-    args = parser.parse_args()
-
+def main(cfg):
     # Set seeds for determinism
-    torch.manual_seed(args.seed)
-    torch.cuda.manual_seed_all(args.seed)
-    np.random.seed(args.seed)
-    random.seed(args.seed)
+    torch.manual_seed(cfg.training.seed)
+    torch.cuda.manual_seed_all(cfg.training.seed)
+    np.random.seed(cfg.training.seed)
+    random.seed(cfg.training.seed)
 
     main_proc = True
-    device = torch.device("cpu" if args.no_cuda else "cuda")
-    args.distributed = os.environ.get("LOCAL_RANK")  # If local rank exists, distributed env
-    if args.distributed:
+    device = torch.device("cpu" if cfg.training.no_cuda else "cuda")
+
+    is_distributed = os.environ.get("LOCAL_RANK")  # If local rank exists, distributed env
+
+    if is_distributed:
         # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops
         # because NCCL uses a spin-lock on the device. Set this env var and
        # to enable a watchdog thread that will destroy stale NCCL communicators
@@ -130,107 +63,110 @@ def update(self, val, n=1):
         torch.cuda.set_device(device_id)
         print(f"Setting CUDA Device to {device_id}")
 
-        dist.init_process_group(backend=args.dist_backend)
+        dist.init_process_group(backend=cfg.training.dist_backend)
         main_proc = device_id == 0  # Main process handles saving of models and reporting
 
-    checkpoint_handler = CheckpointHandler(save_folder=args.save_folder,
-                                           best_val_model_name=args.best_val_model_name,
-                                           checkpoint_per_iteration=args.checkpoint_per_iteration,
-                                           save_n_recent_models=args.save_n_recent_models)
+    checkpoint_handler = CheckpointHandler(save_folder=to_absolute_path(cfg.checkpointing.save_folder),
+                                           best_val_model_name=cfg.checkpointing.best_val_model_name,
+                                           checkpoint_per_iteration=cfg.checkpointing.checkpoint_per_iteration,
+                                           save_n_recent_models=cfg.checkpointing.save_n_recent_models)
 
-    if main_proc and args.visdom:
-        visdom_logger = VisdomLogger(args.id, args.epochs)
-    if main_proc and args.tensorboard:
-        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir, args.log_params)
+    if main_proc and cfg.visualization.visdom:
+        visdom_logger = VisdomLogger(id=cfg.visualization.id,
+                                     num_epochs=cfg.training.epochs)
+    if main_proc and cfg.visualization.tensorboard:
+        tensorboard_logger = TensorBoardLogger(id=cfg.visualization.id,
+                                               log_dir=to_absolute_path(cfg.visualization.log_dir),
+                                               log_params=cfg.visualization.log_params)
 
-    if args.load_auto_checkpoint:
+    if cfg.checkpointing.load_auto_checkpoint:
         latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
         if latest_checkpoint:
-            args.continue_from = latest_checkpoint
+            cfg.checkpointing.continue_from = latest_checkpoint
 
-    if args.continue_from:  # Starting from previous model
-        state = TrainingState.load_state(state_path=args.continue_from)
+    if cfg.checkpointing.continue_from:  # Starting from previous model
+        state = TrainingState.load_state(state_path=to_absolute_path(cfg.checkpointing.continue_from))
         model = state.model
-        if args.finetune:
-            state.init_finetune_states(args.epochs)
+        if cfg.training.finetune:
+            state.init_finetune_states(cfg.training.epochs)
 
-        if main_proc and args.visdom:  # Add previous scores to visdom graph
+        if main_proc and cfg.visualization.visdom:  # Add previous scores to visdom graph
             visdom_logger.load_previous_values(state.epoch, state.results)
-        if main_proc and args.tensorboard:  # Previous scores to tensorboard logs
+        if main_proc and cfg.visualization.tensorboard:  # Previous scores to tensorboard logs
             tensorboard_logger.load_previous_values(state.epoch, state.results)
     else:  # Initialise new model training
-        with open(args.labels_path) as label_file:
+        with open(to_absolute_path(cfg.data.labels_path)) as label_file:
            labels = json.load(label_file)
 
-        audio_conf = dict(sample_rate=args.sample_rate,
-                          window_size=args.window_size,
-                          window_stride=args.window_stride,
-                          window=args.window,
-                          noise_dir=args.noise_dir,
-                          noise_prob=args.noise_prob,
-                          noise_levels=(args.noise_min, args.noise_max))
+        audio_conf = dict(sample_rate=cfg.data.sample_rate,
+                          window_size=cfg.data.window_size,
+                          window_stride=cfg.data.window_stride,
+                          window=cfg.data.window)
+        if cfg.augmentation.noise_dir:
+            audio_conf.update(noise_dir=to_absolute_path(cfg.augmentation.noise_dir),
+                              noise_prob=cfg.augmentation.noise_prob,
+                              noise_levels=(cfg.augmentation.noise_min, cfg.augmentation.noise_max))
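Note: the to_absolute_path calls above are needed because Hydra changes the process working directory to a per-run output directory before the task function runs; paths written relative to the launch directory would otherwise fail to resolve. A minimal standalone sketch of that behaviour (hypothetical script, using the same Hydra 0.11-style config_path API as this patch):

    import os

    import hydra
    from hydra.utils import to_absolute_path


    @hydra.main(config_path="config.yaml")
    def app(cfg):
        print(os.getcwd())                                  # per-run output dir created by Hydra
        print(to_absolute_path("data/train_manifest.csv"))  # resolved against the original launch dir


    if __name__ == '__main__':
        app()
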
 
-        rnn_type = args.rnn_type.lower()
+        rnn_type = cfg.model.rnn_type.lower()
         assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
-        model = DeepSpeech(rnn_hidden_size=args.hidden_size,
-                           nb_layers=args.hidden_layers,
+        model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
+                           nb_layers=cfg.model.hidden_layers,
                            labels=labels,
                            rnn_type=supported_rnns[rnn_type],
                            audio_conf=audio_conf,
-                           bidirectional=args.bidirectional)
+                           bidirectional=cfg.model.bidirectional)
 
     state = TrainingState(model=model)
-    state.init_results_tracking(epochs=args.epochs)
+    state.init_results_tracking(epochs=cfg.training.epochs)
 
     # Data setup
     evaluation_decoder = GreedyDecoder(model.labels)  # Decoder used for validation
     train_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
-                                       manifest_filepath=args.train_manifest,
+                                       manifest_filepath=to_absolute_path(cfg.data.train_manifest),
                                        labels=model.labels,
                                        normalize=True,
-                                       speed_volume_perturb=args.speed_volume_perturb,
-                                       spec_augment=args.spec_augment)
+                                       speed_volume_perturb=cfg.augmentation.speed_volume_perturb,
+                                       spec_augment=cfg.augmentation.spec_augment)
     test_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
-                                      manifest_filepath=args.val_manifest,
+                                      manifest_filepath=to_absolute_path(cfg.data.val_manifest),
                                       labels=model.labels,
                                       normalize=True,
                                       speed_volume_perturb=False,
                                       spec_augment=False)
-    if not args.distributed:
+    if not is_distributed:
         train_sampler = DSRandomSampler(dataset=train_dataset,
-                                        batch_size=args.batch_size,
+                                        batch_size=cfg.data.batch_size,
                                         start_index=state.training_step)
     else:
         train_sampler = DSElasticDistributedSampler(dataset=train_dataset,
-                                                    batch_size=args.batch_size,
+                                                    batch_size=cfg.data.batch_size,
                                                     start_index=state.training_step)
     train_loader = AudioDataLoader(dataset=train_dataset,
-                                   num_workers=args.num_workers,
+                                   num_workers=cfg.data.num_workers,
                                    batch_sampler=train_sampler)
     test_loader = AudioDataLoader(dataset=test_dataset,
-                                  num_workers=args.num_workers,
-                                  batch_size=args.batch_size)
+                                  num_workers=cfg.data.num_workers,
+                                  batch_size=cfg.data.batch_size)
 
     model = model.to(device)
     parameters = model.parameters()
-    if args.adam:
+    if cfg.optimizer.adam:
         optimizer = torch.optim.AdamW(parameters,
-                                      lr=args.lr,
-                                      betas=args.betas,
-                                      eps=args.eps,
-                                      weight_decay=args.wd)
+                                      lr=cfg.optimizer.learning_rate,
+                                      betas=cfg.optimizer.betas,
+                                      eps=cfg.optimizer.eps,
+                                      weight_decay=cfg.optimizer.weight_decay)
     else:
         optimizer = torch.optim.SGD(parameters,
-                                    lr=args.lr,
-                                    momentum=args.momentum,
+                                    lr=cfg.optimizer.learning_rate,
+                                    momentum=cfg.optimizer.momentum,
                                     nesterov=True,
-                                    weight_decay=args.wd)
+                                    weight_decay=cfg.optimizer.weight_decay)
 
     model, optimizer = amp.initialize(model, optimizer,
-                                      opt_level=args.opt_level,
-                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
-                                      loss_scale=args.loss_scale)
+                                      opt_level=cfg.apex.opt_level,
+                                      loss_scale=cfg.apex.loss_scale)
     if state.optim_state is not None:
         optimizer.load_state_dict(state.optim_state)
         amp.load_state_dict(state.amp_state)
@@ -239,7 +175,7 @@ def update(self, val, n=1):
     state.track_optim_state(optimizer)
     state.track_amp_state(amp)
 
-    if args.distributed:
+    if is_distributed:
         model = DistributedDataParallel(model, device_ids=[device_id])
     print(model)
     print("Number of parameters: %d" % DeepSpeech.get_param_size(model))
@@ -249,7 +185,7 @@ def update(self, val, n=1):
     data_time = AverageMeter()
     losses = AverageMeter()
 
-    for epoch in range(state.epoch, args.epochs):
+    for epoch in range(state.epoch, cfg.training.epochs):
         model.train()
         end = time.time()
         start_epoch_time = time.time()
@@ -280,7 +216,7 @@ def update(self, val, n=1):
                 # compute gradient
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
-                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
+                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), cfg.optimizer.max_norm)
                 optimizer.step()
             else:
                 print(error)
@@ -299,7 +235,7 @@ def update(self, val, n=1):
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                 (epoch + 1), (i + 1), len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses))
 
-            if main_proc and args.checkpoint_per_iteration:
+            if main_proc and cfg.checkpointing.checkpoint_per_iteration:
                 checkpoint_handler.save_iter_checkpoint_model(epoch=epoch, i=i, state=state)
             del loss, out, float_out
 
@@ -326,16 +262,16 @@ def update(self, val, n=1):
               'Average WER {wer:.3f}\t'
               'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))
 
-        if main_proc and args.visdom:
+        if main_proc and cfg.visualization.visdom:
             visdom_logger.update(epoch, state.result_state)
-        if main_proc and args.tensorboard:
+        if main_proc and cfg.visualization.tensorboard:
             tensorboard_logger.update(epoch, state.result_state, model.named_parameters())
 
-        if main_proc and args.checkpoint:  # Save epoch checkpoint
+        if main_proc and cfg.checkpointing.checkpoint:  # Save epoch checkpoint
             checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state)
         # anneal lr
         for g in optimizer.param_groups:
-            g['lr'] = g['lr'] / args.learning_anneal
+            g['lr'] = g['lr'] / cfg.optimizer.learning_anneal
             print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))
 
         if main_proc and (state.best_wer is None or state.best_wer > wer):
@@ -343,3 +279,12 @@ def update(self, val, n=1):
             state.set_best_wer(wer)
             state.reset_avg_loss()
         state.reset_training_step()  # Reset training step for next epoch
+
+
+@hydra.main(config_path="config.yaml")
+def hydra_main(cfg: DictConfig):
+    main(cfg=cfg)
+
+
+if __name__ == '__main__':
+    hydra_main()
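Note: with argparse gone, every training option becomes a dotted override on the Hydra command line (e.g. python train.py training.epochs=10 optimizer.adam=true), and because main() takes a plain config object rather than parsed flags, the training logic can also be driven programmatically. A hypothetical sketch, not part of the patch; the key names mirror config.yaml:

    # smoke_test.py -- hypothetical programmatic invocation of train.main()
    from omegaconf import OmegaConf

    from train import main

    if __name__ == '__main__':
        cfg = OmegaConf.load("config.yaml")
        cfg.training.epochs = 1      # illustrative override: single-epoch smoke run
        cfg.training.no_cuda = True  # illustrative override: force CPU
        main(cfg=cfg)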