Timing levels
Mohammad Shoeybi authored and jaredcasper committed Aug 10, 2022
1 parent 83d7867 commit 77efccc
Showing 21 changed files with 495 additions and 219 deletions.
26 changes: 26 additions & 0 deletions megatron/arguments.py
@@ -411,6 +411,32 @@ def _add_logging_args(parser):
help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--timing-log-level', type=int,
default=0, choices=range(0,3),
help='Granularity level to measure and report timing. '
' 0: report only iteration time and make sure timing '
' does not introduce extra overhead.'
' 1: report timing for operations that are executed '
' a very limited number of times (basically once) '
' during each iteration (such as gradient all-reduce) '
' 2: report timing for operations that might be '
' executed numerous times during each iteration. '
'Note that setting the level to 1 or 2 might '
'cause an increase in iteration time.')
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
help='If not set, use barrier with level 1 time '
'measurements. Note that it is up to the user '
'to make sure calling barrier with their timers '
'will not result in hangs. This can happen if, for '
'example, the user adds a level 1 timer that is not '
'called by all ranks.',
dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
choices=['max', 'minmax', 'all'],
help='Options for logging timing:'
' max: report the max timing across all ranks'
' minmax: report min and max timings across all ranks'
' all: report timings of all ranks.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
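
For reference, here is a minimal, standalone sketch (not part of the commit) of how the three new logging flags combine. The flag names, defaults, and the barrier_with_L1_time destination mirror the added argparse code above; everything else is illustrative.

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('logging')
    group.add_argument('--timing-log-level', type=int, default=0,
                       choices=range(0, 3))
    group.add_argument('--no-barrier-with-level-1-timing',
                       action='store_false', dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'])

    # Ask for level-2 timing with per-rank reporting; the barrier before
    # level-1 timers stays on because --no-barrier-with-level-1-timing
    # is not passed (store_false defaults to True).
    args = parser.parse_args(['--timing-log-level', '2',
                              '--timing-log-option', 'all'])
    assert args.timing_log_level == 2
    assert args.barrier_with_L1_time
    assert args.timing_log_option == 'all'
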
90 changes: 5 additions & 85 deletions megatron/global_vars.py
@@ -17,14 +17,14 @@

import os
import sys
import time
from functools import reduce
import operator
import torch

from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer
from .microbatches import build_num_microbatches_calculator
from .timers import Timers

_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -108,7 +108,7 @@ def set_global_variables(args):
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
_set_timers(args)
_set_global_memory_buffer()

if args.exit_signal_handler:
@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args):
_GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
def _set_timers(args):
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
_GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)


def _set_global_memory_buffer():
"""Initialize global buffer"""
@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name):
assert var is None, '{} is already initialized.'.format(name)


class _Timer:
"""Timer."""

def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()

def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True

def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False

def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False

def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_


class Timers:
"""Group of timers."""

def __init__(self):
self.timers = {}

def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]

def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '-time', value, iteration)

def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)


class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
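
The Timers class deleted above is superseded by a level-aware implementation imported from megatron/timers.py, which this commit adds but which is not shown in this excerpt. Based only on its construction in _set_timers and the call sites in the optimizer diffs below, a rough sketch of that interface could look like the following; the class and attribute names here are illustrative, and the committed implementation will differ in detail.

    import time

    import torch


    class _DummyTimer:
        """Stand-in returned for timers above the configured log level."""

        def start(self, barrier=False):
            pass

        def stop(self):
            pass


    class _SketchTimer:
        """Wall-clock timer with optional barrier and CUDA synchronization."""

        def __init__(self, name):
            self.name = name
            self.elapsed_ = 0.0
            self._start_time = None

        def start(self, barrier=False):
            # Optionally line ranks up so the measurement starts together.
            if barrier and torch.distributed.is_initialized():
                torch.distributed.barrier()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self._start_time = time.time()

        def stop(self):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.elapsed_ += time.time() - self._start_time


    class SketchTimers:
        """Timer group that silently skips timers above the configured level."""

        def __init__(self, log_level, log_option):
            self.log_level = log_level    # 0, 1, or 2 (--timing-log-level)
            self.log_option = log_option  # 'max', 'minmax', or 'all'
            self.timers = {}
            self._dummy = _DummyTimer()
            self._dummy_names = set()

        def __call__(self, name, log_level=0):
            if name in self._dummy_names:
                return self._dummy
            if name in self.timers:
                return self.timers[name]
            if log_level > self.log_level:
                # Remember the name so later calls (e.g. .stop()) also
                # get the dummy instead of accidentally creating a timer.
                self._dummy_names.add(name)
                return self._dummy
            self.timers[name] = _SketchTimer(name)
            return self.timers[name]

With this shape, a call such as timers('params-all-gather', log_level=1).start(barrier=args.barrier_with_L1_time) does nothing when --timing-log-level is 0, which matches the stated goal that the default level adds no extra overhead, while the bare timers('params-all-gather').stop() at the end of the timed region still resolves to the same object.
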
20 changes: 12 additions & 8 deletions megatron/optimizer/distrib_optimizer.py
@@ -532,17 +532,20 @@ def reduce_model_grads(self, args, timers):
"""

# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
timers('layernorm-grads-all-reduce').stop()

# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
timers('embedding-grads-all-reduce').stop()

# Reduce-scatter setup.
timers('backward-params-all-reduce').start()
timers('grads-reduce-scatter', log_level=1).start(
barrier=args.barrier_with_L1_time)
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
@@ -563,7 +566,7 @@ def reduce_model_grads(self, args, timers):
group = data_parallel_group,
)

timers('backward-params-all-reduce').stop()
timers('grads-reduce-scatter').stop()


def gather_model_params(self, args, timers):
@@ -575,7 +578,8 @@ def gather_model_params(self, args, timers):
can be copied from param.main_grad to param.
"""

timers('backward-params-all-gather').start()
timers('params-all-gather', log_level=1).start(
barrier=args.barrier_with_L1_time)

data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
@@ -602,7 +606,7 @@ def gather_model_params(self, args, timers):
for param in param_map:
param.detach().copy_(param.main_grad)

timers('backward-params-all-gather').stop()
timers('params-all-gather').stop()


def _collect_main_grad_data_for_unscaling(self):
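
Continuing the sketch shown after the global_vars.py diff (the snippet below reuses that SketchTimers class, so it is not standalone on its own), this illustrates the level gating at call sites like the ones above; the timer names are chosen for illustration.

    # Requires the SketchTimers / _SketchTimer sketch defined earlier.
    timers = SketchTimers(log_level=1, log_option='minmax')

    measured = timers('grads-reduce-scatter', log_level=1)  # level 1 <= 1: real timer
    measured.start(barrier=False)
    measured.stop()

    skipped = timers('attention-forward', log_level=2)      # level 2 > 1: dummy
    skipped.start()
    skipped.stop()

    print(sorted(timers.timers))  # ['grads-reduce-scatter']

As the --no-barrier-with-level-1-timing help text warns, a level-1 timer started with barrier=True must be reached by every rank, otherwise the barrier hangs.
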
45 changes: 29 additions & 16 deletions megatron/optimizer/optimizer.py
@@ -294,21 +294,24 @@ def reduce_model_grads(self, args, timers):
"""All-reduce all grads, and all-reduce embeddings."""

# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
timers('layernorm-grads-all-reduce').stop()

# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
timers('grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
for model in self.models:
model.allreduce_gradients()
timers('backward-params-all-reduce').stop()
timers('grads-all-reduce').stop()

# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
timers('embedding-grads-all-reduce').stop()


class MixedPrecisionOptimizer(MegatronOptimizer):
@@ -416,7 +419,8 @@ def _unscale_main_grads_and_check_for_nan(self):
def step(self, args, timers):

# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()

@@ -425,7 +429,8 @@ def step(self, args, timers):
if self.grad_scaler:

# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=args.barrier_with_L1_time)
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()

@@ -438,25 +443,29 @@ def step(self, args, timers):
return False, None, None

# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()

# Count the zeros in the grads.
timers('optimizer-count-zeros').start()
timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()

# Step the optimizer.
timers('optimizer-inner-step').start()
timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step()
timers('optimizer-inner-step').stop()

# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()

@@ -725,7 +734,8 @@ def step(self, args, timers):
Always return successful since there is no overflow."""

# Copy main_grads to grads.
timers('optimizer-copy-to-main-grad').start()
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
@@ -739,20 +749,23 @@ def step(self, args, timers):
timers('optimizer-copy-to-main-grad').stop()

# Clip gradients.
timers('optimizer-clip-main-grad').start()
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()

# count the zeros in the grads
timers('optimizer-count-zeros').start()
timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()

# Update parameters.
timers('optimizer-inner-step').start()
timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step()
timers('optimizer-inner-step').stop()

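
The new --timing-log-option controls how per-rank times are reported; the committed aggregation logic lives in megatron/timers.py and is not shown here. The standalone sketch below shows one way the 'max', 'minmax', and 'all' choices could be realized with a single all-gather; the function name and the single-process gloo setup are illustrative only.

    import os

    import torch


    def aggregate_elapsed(elapsed_sec, log_option='minmax'):
        """Reduce one rank-local elapsed time according to log_option."""
        world_size = torch.distributed.get_world_size()
        local = torch.tensor([elapsed_sec], dtype=torch.float64)
        gathered = [torch.zeros_like(local) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, local)
        values = torch.cat(gathered)
        if log_option == 'max':
            return {'max': values.max().item()}
        if log_option == 'minmax':
            return {'min': values.min().item(), 'max': values.max().item()}
        # 'all': keep every rank's value.
        return {'rank-{}'.format(i): v.item() for i, v in enumerate(values)}


    if __name__ == '__main__':
        # Single-process gloo group just to make the example runnable.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29500')
        torch.distributed.init_process_group('gloo', rank=0, world_size=1)
        print(aggregate_elapsed(0.123))  # {'min': 0.123, 'max': 0.123}
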
