
Commit

done with refactoring and checks
mohammad committed Dec 30, 2020
1 parent 9a01031 commit b77d906
Showing 2 changed files with 79 additions and 131 deletions.
1 change: 1 addition & 0 deletions megatron/optimizer/__init__.py
@@ -14,6 +14,7 @@
# limitations under the License.

from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron.model import import_layernorm
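
Not part of this change, but for orientation: a minimal sketch of how an optimizer could be built from the `FusedAdam` import above. The toy model, parameter grouping, and hyperparameters are illustrative assumptions, not the repository's actual setup code.

```python
import torch
from apex.optimizers import FusedAdam as Adam

# Illustrative only: a toy CUDA model and hand-picked hyperparameters.
model = torch.nn.Linear(1024, 1024).cuda()
decay, no_decay = [], []
for name, param in model.named_parameters():
    # Common convention: no weight decay on biases.
    (no_decay if name.endswith('bias') else decay).append(param)

optimizer = Adam(
    [{'params': decay, 'weight_decay': 0.01},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=1.5e-4, betas=(0.9, 0.999), eps=1e-8)
```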

209 changes: 78 additions & 131 deletions megatron/optimizer/optimizer.py
@@ -19,13 +19,15 @@
from abc import abstractmethod

import torch
from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0

from .clip_grads import clip_grad_norm_fp32


def _zero_grad_group_helper(group, set_to_none):
@@ -43,95 +45,6 @@ def _zero_grad_group_helper(group, set_to_none):
param.grad.zero_()
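
The body of this helper is mostly cut off by the hunk above. As a sketch (not the exact Megatron code), the standard set_to_none pattern it appears to follow, mirroring `torch.optim.Optimizer.zero_grad` semantics, looks like this:

```python
def zero_grad_group(group, set_to_none):
    """Sketch of the set_to_none pattern; names are illustrative."""
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            # Free the gradient tensor entirely; cheapest option.
            param.grad = None
        else:
            # Keep the tensor: detach it from any graph, then zero in place.
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()
```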


def _clip_grad_norm(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with
added functionality to handle model-parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""

if isinstance(parameters, torch.Tensor):
parameters = [parameters]

# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
for param in parameters:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
grad_not_none = param.grad is not None
is_not_shared = not hasattr(param, 'shared') or not param.shared
is_not_tp_duplicate = param.tensor_model_parallel or \
(mpu.get_tensor_model_parallel_rank() == 0)
grad = param.grad.detach()
if grad_not_none:
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)

# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0

# Calculate norm.
if norm_type == inf:
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()

else:
if norm_type == 2.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
total_norm = grad_norm ** norm_type

else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type

# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
total_norm = total_norm.item() ** (1.0 / norm_type)

# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)

return total_norm
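
This function is being moved into clip_grads.py as clip_grad_norm_fp32. For clarity, here is a simplified pure-PyTorch sketch of its 2-norm path without the fused amp_C kernels; only the clip formula and the model-parallel all-reduce mirror the code above, everything else is an illustrative simplification.

```python
import torch
import torch.distributed as dist

def clip_grad_norm_model_parallel(grads, grads_for_norm, max_norm, mp_group):
    """Simplified 2-norm clipping across model-parallel ranks.

    grads:          all local fp32 gradients, scaled in place.
    grads_for_norm: the subset that is not shared and not a tensor-parallel
                    duplicate, so each gradient is counted exactly once.
    """
    # Local sum of squared 2-norms over the non-duplicated gradients.
    norm_sq = torch.zeros(1, dtype=torch.float32,
                          device=grads_for_norm[0].device)
    for grad in grads_for_norm:
        norm_sq += grad.norm(2) ** 2
    # Sum the squared norms across all model-parallel GPUs, then take the root.
    dist.all_reduce(norm_sq, op=dist.ReduceOp.SUM, group=mp_group)
    total_norm = norm_sq.item() ** 0.5
    # Same clip coefficient as above: max_norm / (total_norm + 1e-6).
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for grad in grads:
            grad.mul_(clip_coeff)
    return total_norm
```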



class MegatronOptimizer(ABC):

@@ -145,7 +58,7 @@ def clip_grad_norm(self, clip_grad):
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
- _clip_grad_norm(params, clip_grad)
+ clip_grad_norm_fp32(params, clip_grad)

@abstractmethod
def zero_grad(self, set_to_none=True):
@@ -283,16 +196,7 @@ def get_loss_scale(self):
return self.grad_scaler.scale


@torch.no_grad()
def step(self):

timers = get_timers()

# ==================================================
# Copy gradients from model params to master params.
# ==================================================

timers('optimizer-copy-to-master-grad').start()
def _copy_model_grads_to_master_grads(self):
# This only needs to be done for the fp16 group.
model_grads = []
master_grads = []
@@ -302,26 +206,28 @@ def step(self):
if model_param.grad is not None:
if master_param.grad is None:
master_param.grad = torch.empty_like(master_param)
- model_grads.append(model_param.grad)
- master_grads.append(master_param.grad)
+ model_grads.append(model_param.grad.data)
+ master_grads.append(master_param.grad.data)
self._dummy_overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
self._dummy_overflow_buf,
[model_grads, master_grads],
1.0)
timers('optimizer-copy-to-master-grad').stop()
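
In the new _copy_model_grads_to_master_grads helper, calling amp_C.multi_tensor_scale with a factor of 1.0 is just a fused multi-tensor copy. An unfused, per-tensor equivalent is sketched below; the standalone signature is illustrative and assumes the same fp16/fp32 group pairing.

```python
import torch

def copy_model_grads_to_master_grads(fp16_groups, fp32_from_fp16_groups):
    """Unfused sketch of the fused multi_tensor_scale(..., 1.0) copy above."""
    for model_group, master_group in zip(fp16_groups, fp32_from_fp16_groups):
        for model_param, master_param in zip(model_group, master_group):
            if model_param.grad is None:
                continue
            if master_param.grad is None:
                master_param.grad = torch.empty_like(master_param)
            # copy_ casts the fp16 gradient up to fp32 implicitly.
            master_param.grad.data.copy_(model_param.grad.data)
```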

# ==============================
# Unscale and check for inf/nan.
# ==============================

timers('optimizer-unscale-and-check-inf').start()
def _unscale_master_grads_and_check_for_nan(self):
master_grads = []
# fp32 params from fp16 ones.
for master_group in self.fp32_from_fp16_groups:
for master_param in master_group:
if master_param.grad is not None:
master_grads.append(master_param.grad.data)
# Append fp32 parameters.
for master_group in self.fp32_from_fp32_groups:
for master_param in master_group:
if master_param.grad is not None:
- master_grads.append(master_param.grad)
+ master_grads.append(master_param.grad.data)
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
@@ -331,13 +237,52 @@ def step(self):
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())

# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
return found_inf_flag
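
The fused unscale call itself is elided by the hunk above; conceptually, _unscale_master_grads_and_check_for_nan multiplies each master gradient by the inverse loss scale, flags any inf/nan, and max-reduces the flag over the model-parallel group so every rank agrees on skipping the update. A plain-PyTorch sketch, passing inv_scale explicitly as an assumption about the grad scaler's interface:

```python
import torch
import torch.distributed as dist

def unscale_and_check_for_nan(master_grads, inv_scale, mp_group):
    """Sketch: unscale fp32 grads in place and detect inf/nan globally."""
    found_inf = torch.zeros(1, dtype=torch.float32,
                            device=master_grads[0].device)
    for grad in master_grads:
        grad.mul_(inv_scale)                 # undo the loss scaling
        if not torch.isfinite(grad).all():   # any inf/nan after unscaling?
            found_inf.fill_(1.0)
    # If any model-parallel rank overflowed, every rank must skip the step.
    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=mp_group)
    return found_inf.item() > 0
```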


def _copy_master_params_to_model_params(self):
# Only needed for the fp16 params.
model_data = []
master_data = []
for model_group, master_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups):
for model_param, master_param in zip(model_group, master_group):
model_data.append(model_param.data)
master_data.append(master_param.data)
self._dummy_overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
self._dummy_overflow_buf,
[master_data, model_data],
1.0)


@torch.no_grad()
def step(self):

timers = get_timers()

# ==================================================
# Copy gradients from model params to master params.
# ==================================================
timers('optimizer-copy-to-master-grad').start()
self._copy_model_grads_to_master_grads()
timers('optimizer-copy-to-master-grad').stop()

# ==============================
# Unscale and check for inf/nan.
# ==============================
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_master_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()

# ==================================
# We are done with scaling gradients
# so we can update the loss scale.
# ==================================
- found_inf_flag = (self.found_inf.item() > 0)
self.grad_scaler.update(found_inf_flag)

# =====================================
@@ -349,38 +294,25 @@ def step(self):
# ==========================
# Clip the master gradients.
# ==========================

timers('optimizer-clip-master-grad').start()
self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-master-grad').stop()

# ===================
# Step the optimizer.
# ===================

self.optimizer.step()

# =================================
# Update params from master params.
# =================================

timers('optimizer-copy-master-to-model-params').start()
# Only needed for the fp16 params.
model_data = []
master_data = []
for model_group, master_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups):
for model_param, master_param in zip(model_group, master_group):
model_data.append(model_param.data)
master_data.append(master_param.data)
self._dummy_overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
self._dummy_overflow_buf,
[master_data, model_data],
1.0)
self._copy_master_params_to_model_params()
timers('optimizer-copy-master-to-model-params').stop()

# ==================
# Successful update.
# ==================
return True
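
After the refactor, step() reads as a sequence of helper calls: copy fp16 grads to fp32 master grads, unscale and check for inf/nan, update the loss scale, skip the update on overflow (that branch is elided above), clip, run the inner optimizer, and copy master weights back into the fp16 params. A hedged sketch of how a caller might consume the return value, assuming False means the update was skipped:

```python
def train_step(model, optimizer, loss_fn, inputs, targets):
    """Illustrative caller; not the Megatron training loop."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    # Scale the loss so small fp16 gradients do not underflow in backward.
    scaled_loss = loss * optimizer.get_loss_scale()
    scaled_loss.backward()
    updated = optimizer.step()   # assumed False on overflow: params unchanged
    return loss.detach() if updated else None
```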


@@ -393,13 +325,28 @@ def state_dict(self):


def load_state_dict(self, state_dict):
- # Defer to the class to load.
- self.optimizer.load_state_dict(state_dict['optimizer'])
- self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
# Optimizer.
optimizer_key = 'optimizer'
if optimizer_key not in state_dict:
optimizer_key = 'optimizer_state_dict'
print_rank_0('***WARNING*** loading optimizer from '
'an old checkpoint ...')
self.optimizer.load_state_dict(state_dict[optimizer_key])

# Grad scaler.
if 'grad_scaler' not in state_dict:
print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

# Copy data for the master params.
fp32_from_fp16_params_key = 'fp32_from_fp16_params'
if fp32_from_fp16_params_key not in state_dict:
fp32_from_fp16_params_key = 'fp32_from_fp16'
for current_group, saved_group in zip(
self.fp32_from_fp16_groups,
- state_dict['fp32_from_fp16_params']):
+ state_dict[fp32_from_fp16_params_key]):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
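
The key fallbacks above keep checkpoints written before this refactor loadable. A small save/load sketch; the wrapper function names and file path are illustrative, and `optimizer` is assumed to be an instance of the fp16 optimizer defined in this file.

```python
import torch

def save_fp16_optimizer(optimizer, path='fp16_optimizer.pt'):
    # state_dict() is assumed to bundle 'optimizer', 'grad_scaler' and
    # 'fp32_from_fp16_params' -- the new key names handled above.
    torch.save(optimizer.state_dict(), path)

def load_fp16_optimizer(optimizer, path='fp16_optimizer.pt'):
    # Older checkpoints that used 'optimizer_state_dict' / 'fp32_from_fp16'
    # still load, thanks to the fallbacks in load_state_dict above.
    optimizer.load_state_dict(torch.load(path))
```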

