Merge pull request SeanNaren#323 from SeanNaren/multiproc
Add GPU IDs for multiproc script, sync losses
Sean Naren committed Jul 12, 2018
2 parents 27ca24c + 90403f9 commit edf86aa
Showing 4 changed files with 31 additions and 5 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -191,6 +191,12 @@ multiproc will open a log for all processes other than the main process.

We suggest using the gloo backend, which defaults to TCP if InfiniBand isn't available. NCCL2 can also be used as a backend. More information [here](https://pytorch.org/docs/master/distributed.html#distributed-basics).
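
For reference, the backend choice takes effect when the process group is initialised. A minimal sketch of that call is below (illustrative values only, not the exact arguments the training script passes):

```
import torch.distributed as dist

# Illustrative values; in practice these come from command-line arguments
world_size = 2   # total number of training processes
rank = 0         # this process's index within the group

dist.init_process_group(backend='gloo',  # or 'nccl' if NCCL2 is available
                        init_method='tcp://127.0.0.1:1550',
                        world_size=world_size, rank=rank)
```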

You can also specify which GPU IDs to use, rather than letting the script claim all available GPUs:

```
python -m multiproc train.py --visdom --cuda --device-ids 0,1,2,3 # Add your other parameters as normal; training will run only on GPUs 0-3
```

### Noise Augmentation/Injection

There is support for two types of noise: noise augmentation and noise injection.
8 changes: 8 additions & 0 deletions data/utils.py
@@ -5,6 +5,7 @@
import os
from tqdm import tqdm
import subprocess
import torch.distributed as dist


def create_manifest(data_path, output_path, min_duration=None, max_duration=None):
@@ -34,3 +35,10 @@ def func(element):

    duration_file_paths.sort(key=func)
    return [x[0] for x in duration_file_paths]  # Remove durations


def reduce_tensor(tensor, world_size):
    # Sum the tensor across all processes, then divide so every rank ends up with the average
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= world_size
    return rt
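
A minimal usage sketch for `reduce_tensor` (assuming `torch.distributed` has already been initialised and one copy of this code runs in every process):

```
import torch
from data.utils import reduce_tensor

# Every process passes its own local value; all of them get back the same mean.
local_loss = torch.tensor([2.5])   # this process's per-batch loss (illustrative)
world_size = 2                     # total number of processes in the group
mean_loss = reduce_tensor(local_loss, world_size)
print(mean_loss.item())            # identical on every rank
```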

12 changes: 11 additions & 1 deletion multiproc.py
@@ -4,6 +4,13 @@

argslist = list(sys.argv)[1:]
world_size = torch.cuda.device_count()
device_ids = None
if '--device-ids' in argslist:  # Manually specified GPU IDs
    device_ids = argslist[argslist.index('--device-ids') + 1].strip().split(',')
    world_size = len(device_ids)
    # Remove GPU IDs since these are not for the training script
    argslist.pop(argslist.index('--device-ids') + 1)
    argslist.pop(argslist.index('--device-ids'))

if '--world-size' in argslist:
    argslist[argslist.index('--world-size') + 1] = str(world_size)
@@ -20,7 +27,10 @@
argslist.append('--rank')
argslist.append(str(i))
if '--gpu-rank' in argslist:
    argslist[argslist.index('--gpu-rank') + 1] = str(i)
    if device_ids:
        argslist[argslist.index('--gpu-rank') + 1] = str(device_ids[i])
    else:
        argslist[argslist.index('--gpu-rank') + 1] = str(i)
else:
    argslist.append('--gpu-rank')
    argslist.append(str(i))
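
To make the rewriting above concrete, here is a rough illustration (simplified, not the script itself) of the per-process arguments produced when `--device-ids 2,3` is passed, where `i` is the index of the spawned process:

```
# Simplified illustration of the rewriting for `--device-ids 2,3`
device_ids = ['2', '3']
world_size = len(device_ids)   # two training processes will be spawned
for i in range(world_size):
    # each process gets a unique rank but is pinned to the requested GPU
    print('python train.py ... --world-size {} --rank {} --gpu-rank {}'.format(
        world_size, i, device_ids[i]))
# prints:
#   python train.py ... --world-size 2 --rank 0 --gpu-rank 2
#   python train.py ... --world-size 2 --rank 1 --gpu-rank 3
```
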
10 changes: 6 additions & 4 deletions train.py
@@ -9,6 +9,7 @@
from warpctc_pytorch import CTCLoss

from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler, DistributedBucketingSampler
from data.utils import reduce_tensor
from decoder import GreedyDecoder
from model import DeepSpeech, supported_rnns

@@ -248,13 +249,14 @@ def update(self, val, n=1):
loss = criterion(out, targets, output_sizes, target_sizes)
loss = loss / inputs.size(0)  # average the loss by minibatch

loss_sum = loss.data.sum()
inf = float("inf")
if loss_sum == inf or loss_sum == -inf:
    print("WARNING: received an inf loss, setting loss value to 0")
    loss_value = 0
if args.distributed:
    loss_value = reduce_tensor(loss, args.world_size)[0]
else:
    loss_value = loss.item()
if loss_value == inf or loss_value == -inf:
    print("WARNING: received an inf loss, setting loss value to 0")
    loss_value = 0

avg_loss += loss_value
losses.update(loss_value, inputs.size(0))
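
The effect of the sync is that every process records the mean of all per-GPU losses instead of only its own. The arithmetic, shown with illustrative values and without any distributed calls:

```
import torch

# Two processes with local losses 2.0 and 4.0 (illustrative values)
local_losses = [torch.tensor(2.0), torch.tensor(4.0)]
world_size = len(local_losses)
summed = sum(local_losses)       # what all_reduce(SUM) leaves on every rank
averaged = summed / world_size   # what reduce_tensor returns everywhere
print(averaged.item())           # 3.0, logged identically by both processes
```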
