forked from SeanNaren/deepspeech.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrated Hydra into the training script
- Loading branch information
sean.narenthiran
committed
Jun 20, 2020
1 parent
81cb575
commit f2c0c2d
Showing
3 changed files
with
142 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
training: | ||
no_cuda: false # Enable CPU only training | ||
finetune: false # Fine-tune the model from checkpoint "continue_from" | ||
seed: 123456 # Seed for generators | ||
dist_backend: nccl # If using distribution, the backend to be used | ||
epochs: 70 # Number of Training Epochs | ||
data: | ||
train_manifest: data/train_manifest.csv | ||
val_manifest: data/val_manifest.csv | ||
sample_rate: 16000 # The sample rate for the data/model features | ||
batch_size: 20 # Batch size for training | ||
num_workers: 4 # Number of workers used in data-loading | ||
labels_path: labels.json # Contains tokens for model output | ||
window_size: .02 # Window size for spectrogram generation (seconds) | ||
window_stride: .01 # Window stride for spectrogram generation (seconds) | ||
window: hamming # Window type for spectrogram generation | ||
model: | ||
rnn_type: lstm # Type of RNN to use in modeel, rnn/gru/lstm are supported | ||
hidden_size: 1024 # Hidden size of RNN Layer | ||
hidden_layers: 5 # Number of RNN layers | ||
bidirectional: true # Use BiRNNs. If False, uses lookahead conv | ||
optimizer: | ||
learning_rate: 3e-4 # Initial Learning Rate | ||
weight_decay: 1e-5 # Initial Weight Decay | ||
momentum: 0.9 | ||
adam: false # Replace SGD with AdamW | ||
eps: 1e-8 # Adam eps | ||
beta: (0.9, 0.999) # Adam betas | ||
max_norm: 400 # Norm cutoff to prevent explosion of gradients | ||
learning_anneal: 1.1 # Annealing applied to learning rate after each epoch | ||
checkpointing: | ||
continue_from: '' # Continue training from checkpoint model | ||
checkpoint: True # Enables epoch checkpoint saving of model | ||
checkpoint_per_iteration: 0 # Save checkpoint per N number of iterations. Default is disabled | ||
save_n_recent_models: 10 # Maximum number of checkpoints to save. If the max is reached, we delete older checkpoints | ||
save_folder: models/ # Location to save epoch models | ||
best_val_model_name: deepspeech_final.pth # Name to save best validated model within the save folder | ||
load_auto_checkpoint: false # Enable when handling interruptions. Automatically load the latest checkpoint from the save folder | ||
visualization: | ||
visdom: false # Turn on visdom graphing | ||
tensorboard: false # Turn on Tensorboard graphing | ||
log_dir: visualize/deepspeech_final # Location of Tensorboard log | ||
log_params: false # Log parameter values and gradients | ||
id: DeepSpeech training # Identifier for visdom/tensorboard run | ||
augmentation: | ||
speed_volume_perturb: false # Use random tempo and gain perturbations. | ||
spec_augment: false # Use simple spectral augmentation on mel spectograms. | ||
noise_dir: '' # Directory to inject noise into audio. If default, noise Inject not added | ||
noise_prob: 0.4 # Probability of noise being added per sample | ||
noise_min: 0.0 # Minimum noise level to sample from. (1.0 means all noise, not original signal) | ||
noise_max: 0.5 # Maximum noise levels to sample from. Maximum 1.0 | ||
apex: | ||
opt_level: O1 # Apex optimization level, check https://nvidia.github.io/apex/amp.html for more information | ||
loss_scale: 1 # Loss scaling used by Apex. Default is 1 due to warp-ctc not supporting scaling of gradients | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,4 +12,5 @@ matplotlib | |
flask | ||
sox | ||
sklearn | ||
soundfile | ||
soundfile | ||
hydra-core |
Oops, something went wrong.