apex.parallel
class apex.parallel.DistributedDataParallel(module, message_size=10000000, delay_allreduce=False, shared_param=None, allreduce_trigger_params=None, retain_allreduce_buffers=False, allreduce_always_fp32=False, num_allreduce_streams=1, allreduce_communicators=None, gradient_average=True, gradient_predivide_factor=1.0, gradient_average_split_factor=None, prof=False)

apex.parallel.DistributedDataParallel is a module wrapper that enables easy multiprocess distributed data parallel training, similar to torch.nn.parallel.DistributedDataParallel. Parameters are broadcast across participating processes on initialization, and gradients are allreduced and averaged over processes during backward().

DistributedDataParallel is optimized for use with NCCL. It achieves high performance by overlapping communication with computation during backward() and by bucketing smaller gradient transfers to reduce the total number of transfers required.

DistributedDataParallel is designed to work with the upstream launch utility script torch.distributed.launch with --nproc_per_node <= number of GPUs per node. When used with this launcher, DistributedDataParallel assumes a 1:1 mapping of processes to GPUs. It also assumes that your script calls torch.cuda.set_device(args.local_rank) before creating the model.

https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage. https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example that combines DistributedDataParallel with mixed precision training.

Parameters
module – Network definition to be run in multi-gpu/distributed mode.
message_size (int, default=1e7) – Minimum number of elements in a communication bucket.
delay_allreduce (bool, default=False) – Delay all communication to the end of the backward pass. This disables overlapping communication with computation.
allreduce_trigger_params (list, optional, default=None) – If supplied, should contain a list of parameters drawn from the model. Allreduces will be kicked off whenever one of these parameters receives its gradient (as opposed to when a bucket of size message_size is full). At the end of backward(), a cleanup allreduce to catch any remaining gradients will also be performed automatically. If allreduce_trigger_params is supplied, the message_size argument will be ignored.
allreduce_always_fp32 (bool, default=False) – Convert any FP16 gradients to FP32 before allreducing. This can improve stability for widely scaled-out runs.
gradient_average (bool, default=True) – Option to toggle whether or not DDP averages the allreduced gradients over processes. For proper scaling, the default value of True is recommended.
gradient_predivide_factor (float, default=1.0) – Allows performing the average of gradients over processes partially before and partially after the allreduce. Before allreduce: grads.mul_(1.0/gradient_predivide_factor). After allreduce: grads.mul_(gradient_predivide_factor/world_size). This can reduce the stress on the dynamic range of FP16 allreduces for widely scaled-out runs.

Warning

If gradient_average=False, the pre-allreduce division (grads.mul_(1.0/gradient_predivide_factor)) will still be applied, but the post-allreduce gradient averaging (grads.mul_(gradient_predivide_factor/world_size)) will be omitted.
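The predivide arithmetic can be checked with a small pure-Python sketch in which a plain elementwise sum over a hypothetical list of per-process gradients stands in for the allreduce. It is not apex code, and the helper names are illustrative assumptions. It shows that predividing by gradient_predivide_factor and then scaling by gradient_predivide_factor/world_size gives the same average as dividing once by world_size, and that with gradient_average=False the result is still predivided.

```python
def simulated_allreduce(per_process_grads):
    """Sum gradients elementwise across processes (stand-in for an allreduce)."""
    return [sum(vals) for vals in zip(*per_process_grads)]


def average_with_predivide(per_process_grads, gradient_predivide_factor,
                           gradient_average=True):
    """Hypothetical model of the split averaging described for DDP."""
    world_size = len(per_process_grads)
    # Before allreduce: grads.mul_(1.0 / gradient_predivide_factor)
    prescaled = [[g * (1.0 / gradient_predivide_factor) for g in grads]
                 for grads in per_process_grads]
    reduced = simulated_allreduce(prescaled)
    if gradient_average:
        # After allreduce: grads.mul_(gradient_predivide_factor / world_size)
        reduced = [g * (gradient_predivide_factor / world_size) for g in reduced]
    return reduced


grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # four hypothetical processes
print(average_with_predivide(grads, 2.0))                          # [4.0, 5.0]
print(average_with_predivide(grads, 2.0, gradient_average=False))  # [8.0, 10.0]
```

The second call illustrates the warning: the sum is still divided by gradient_predivide_factor even though no final averaging is applied.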
forward(*inputs, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
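The bucketing idea described for DistributedDataParallel, where small gradient transfers are grouped until a bucket holds at least message_size elements, can be sketched in pure Python. This is a simplified, hypothetical model for illustration only: it assumes gradients arrive in a fixed order and are described only by their element counts, and build_buckets is not part of apex. The real implementation also has to handle details such as gradient arrival order, dtypes, and communication streams.

```python
def build_buckets(grad_numels, message_size):
    """Group gradients (given by element count, in arrival order) into buckets.

    A bucket is closed as soon as it holds at least `message_size` elements,
    following the "minimum number of elements in a communication bucket" rule.
    Any leftover gradients form a final, possibly smaller, bucket.
    Returns a list of buckets, each a list of gradient indices.
    """
    buckets, current, current_size = [], [], 0
    for idx, numel in enumerate(grad_numels):
        current.append(idx)
        current_size += numel
        if current_size >= message_size:
            buckets.append(current)
            current, current_size = [], 0
    if current:
        buckets.append(current)
    return buckets


# Gradient sizes (in elements) for six hypothetical parameters.
sizes = [4_000_000, 3_000_000, 5_000_000, 11_000_000, 200, 500_000]
print(build_buckets(sizes, message_size=10_000_000))  # [[0, 1, 2], [3], [4, 5]]
```

Three allreduces replace six here: a single large gradient fills a bucket on its own, and the small trailing gradients share the final bucket.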
class apex.parallel.Reducer(module_or_grads_list)

apex.parallel.Reducer is a simple class that helps allreduce a module's parameters across processes. Reducer is intended to give the user additional control: unlike DistributedDataParallel, Reducer will not automatically allreduce parameters during backward(). Instead, Reducer waits for the user to call <reducer_instance>.reduce() manually. This enables, for example, delaying the allreduce so that it is carried out once every several iterations instead of every iteration.

Like DistributedDataParallel, Reducer averages any tensors it allreduces over the number of participating processes.

Reducer is designed to work with the upstream launch utility script torch.distributed.launch with --nproc_per_node <= number of GPUs per node. When used with this launcher, Reducer assumes a 1:1 mapping of processes to GPUs. It also assumes that your script calls torch.cuda.set_device(args.local_rank) before creating the model.

Parameters
module_or_grads_list – Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they’re all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module’s parameters at the beginning of training.
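The delayed-allreduce pattern that Reducer enables can be modeled without GPUs. With real PyTorch modules, gradients accumulate in .grad across backward() calls until they are zeroed, so calling reduce() once every few iterations averages the accumulated gradients over processes. The sketch below is a pure-Python stand-in under that assumption: the simulated processes and the helpers local_accumulate and reduce_average are hypothetical, and it calls neither apex nor torch.distributed.

```python
def local_accumulate(step_grads):
    """Sum one process's per-iteration gradients, as repeated backward()
    calls without zeroing would accumulate them."""
    return [sum(vals) for vals in zip(*step_grads)]


def reduce_average(per_process_grads):
    """Average gradients elementwise over processes (the effect of reduce())."""
    world_size = len(per_process_grads)
    return [sum(vals) / world_size for vals in zip(*per_process_grads)]


# Hypothetical per-iteration gradients: two processes, three iterations,
# two parameters each.
process_step_grads = [
    [[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]],  # process 0
    [[0.0, 4.0], [1.0, 3.0], [2.0, 2.0]],  # process 1
]

# Each process accumulates locally: [6.0, 3.0] and [3.0, 9.0].
accumulated = [local_accumulate(steps) for steps in process_step_grads]
# One averaged allreduce after three iterations instead of one per iteration.
averaged = reduce_average(accumulated)
print(averaged)  # [4.5, 6.0]
```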
class apex.parallel.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False)

Synchronized batch normalization module extended from torch.nn.BatchNormNd with the added stats reduction across multiple processes. apex.parallel.SyncBatchNorm is designed to work with DistributedDataParallel.

When running in training mode, the layer reduces stats across all processes to increase the effective batch size for the normalization layer. This is useful in applications where the batch size on each process is small enough that it would otherwise diminish the converged accuracy of the model. The layer uses the collective communication package torch.distributed.

When running in evaluation mode, the layer falls back to torch.nn.functional.batch_norm.
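To see why reducing stats across processes is equivalent to normalizing over the combined batch, the sketch below merges per-process (count, sum, sum of squares) statistics for a single channel and checks them against statistics computed over the concatenated batch. It is a hypothetical pure-Python illustration: local_stats and merge_stats are not apex APIs, and apex's CUDA kernels may combine statistics differently (for example with a Welford-style merge) for numerical robustness.

```python
import math


def local_stats(values):
    """Per-process statistics for one channel: (count, sum, sum of squares)."""
    return len(values), sum(values), sum(v * v for v in values)


def merge_stats(stats_per_process):
    """Combine per-process stats (the role of the cross-process reduction) into
    the global mean and biased variance used for normalization in training."""
    count = sum(s[0] for s in stats_per_process)
    total = sum(s[1] for s in stats_per_process)
    total_sq = sum(s[2] for s in stats_per_process)
    mean = total / count
    var = total_sq / count - mean * mean
    return mean, var


# One channel's activations on three hypothetical processes with tiny batches.
per_process = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]
mean, var = merge_stats([local_stats(v) for v in per_process])

# Reference: statistics over the concatenated (effective) batch.
flat = [v for values in per_process for v in values]
ref_mean = sum(flat) / len(flat)
ref_var = sum((v - ref_mean) ** 2 for v in flat) / len(flat)

print(mean, math.isclose(var, ref_var))
```

Each process only needs to exchange three numbers per channel, yet every process ends up normalizing with the statistics of all six samples.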
Parameters
num_features – \(C\) from an expected input of size \((N, C, L)\) or \(L\) from input of size \((N, L)\)
eps – a value added to the denominator for numerical stability. Default: 1e-5
momentum – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine – a boolean value that when set to True, this module has learnable affine parameters. Default: True
track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True
process_group – the process group within which the stats of the mini-batch are synchronized. Pass None to use the default process group.
channel_last – a boolean value that when set to True, this module takes the last dimension of the input tensor to be the channel dimension. Default: False
Examples::

    >>> import torch
    >>> import apex
    >>> # channel first tensor
    >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
    >>> inp = torch.randn(10, 100, 14, 14).cuda()
    >>> out = sbn(inp)
    >>> inp = torch.randn(3, 100, 20).cuda()
    >>> out = sbn(inp)
    >>> # channel last tensor
    >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
    >>> inp = torch.randn(10, 14, 14, 100).cuda()
    >>> out = sbn(inp)
forward(input, z=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Utility functions
apex.parallel.convert_syncbn_model(module, process_group=None, channel_last=False)

Recursively traverses module and its children to replace all instances of torch.nn.modules.batchnorm._BatchNorm with apex.parallel.SyncBatchNorm.

All torch.nn.BatchNorm*N*d classes are subclasses of torch.nn.modules.batchnorm._BatchNorm, so this function lets you easily switch to synchronized batch normalization.

Parameters
module (torch.nn.Module) – input module
Example::

    >>> # model is an instance of torch.nn.Module
    >>> import apex
    >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
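The recursive replacement that convert_syncbn_model performs can be illustrated with a toy module tree. The sketch below uses hypothetical stand-in classes (Node, BatchNormLike, SyncBatchNormLike) instead of torch.nn modules so that it runs without PyTorch or apex. It only shows the traversal pattern, not apex's implementation; a real conversion also has to carry over each layer's configuration, affine parameters, and running statistics.

```python
class Node:
    """Minimal stand-in for torch.nn.Module: a named collection of children."""
    def __init__(self, **children):
        self.children = dict(children)


class BatchNormLike(Node):
    """Stand-in for torch.nn.modules.batchnorm._BatchNorm."""
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features


class SyncBatchNormLike(Node):
    """Stand-in for apex.parallel.SyncBatchNorm."""
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features


def convert_tree(node):
    """Return `node` with every BatchNormLike in its subtree replaced."""
    if isinstance(node, BatchNormLike):
        # A real conversion would also copy eps, momentum, weights and running stats.
        new = SyncBatchNormLike(node.num_features)
    else:
        new = node
    # Recurse into the children and attach the converted versions to `new`.
    for name, child in list(node.children.items()):
        new.children[name] = convert_tree(child)
    return new


model = Node(conv=Node(), bn1=BatchNormLike(64),
             block=Node(bn2=BatchNormLike(128), act=Node()))
model = convert_tree(model)
print(type(model.children["bn1"]).__name__)                     # SyncBatchNormLike
print(type(model.children["block"].children["bn2"]).__name__)  # SyncBatchNormLike
```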