Fix some bugs, add exit-duration capability
mshoeybi authored and deepakn94 committed Dec 19, 2020
1 parent 5131590 commit a31833c
Showing 7 changed files with 158 additions and 42 deletions.
1 change: 1 addition & 0 deletions megatron/__init__.py
@@ -26,6 +26,7 @@
)

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -326,6 +326,8 @@ def _add_training_args(parser):
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
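As a usage illustration (not part of the commit), the sketch below shows how the new `--exit-duration-in-mins` flag sits alongside the existing `--exit-interval` flag and how a wall-clock budget check could be derived from it; the standalone parser and the 230-minute value are assumptions for the example.

```python
import argparse
import time

def add_exit_args(parser):
    # Mirrors the two exit conditions in _add_training_args above:
    # exit on an iteration boundary, or exit on a wall-clock budget.
    group = parser.add_argument_group(title='exit conditions')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is '
                            'divisible by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    return parser

args = add_exit_args(argparse.ArgumentParser()).parse_args(
    ['--exit-duration-in-mins', '230'])

train_start = time.time()
elapsed_mins = (time.time() - train_start) / 60.0
should_exit = (args.exit_duration_in_mins is not None
               and elapsed_mins > args.exit_duration_in_mins)
print(should_exit)
```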
22 changes: 17 additions & 5 deletions megatron/data/dataset_utils.py
@@ -418,11 +418,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type)

# Blend.
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)

# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)

return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
21 changes: 15 additions & 6 deletions megatron/data/gpt2_dataset.py
@@ -55,14 +55,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
train_datasets.append(train_ds)
valid_datasets.append(valid_ds)
test_datasets.append(test_ds)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)

# Blend.
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights)
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)

return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
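Both dataset modules now apply the same guard: per-split datasets that come back as None are skipped, and a BlendableDataset is only built when the surviving list is non-empty. Below is a minimal standalone sketch of that pattern; the `blend_non_empty` helper is a stand-in for BlendableDataset, not the real class.

```python
from typing import List, Optional, Sequence

def blend_non_empty(datasets: List[Optional[list]], weights: Sequence[float]):
    # Drop splits that were not built (None / empty), and only blend
    # when something survives; otherwise propagate None to the caller.
    kept = [d for d in datasets if d]
    if not kept:
        return None
    return {'datasets': kept, 'weights': list(weights)}

print(blend_non_empty([None, None], [0.7, 0.3]))       # None: nothing to blend
print(blend_non_empty([[0, 1, 2], None], [0.7, 0.3]))  # one surviving split
```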
9 changes: 7 additions & 2 deletions megatron/global_vars.py
@@ -43,8 +43,13 @@ def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def update_num_microbatches(consumed_samples):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
def get_current_global_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
consistency_check)


def get_tokenizer():
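A compact stand-in for the accessor pattern in global_vars.py, showing the two new entry points (`get_current_global_batch_size` and the `consistency_check` keyword). The `_ConstantCalculator` class and the numbers are assumptions for this sketch; the real calculators live in megatron/microbatches.py.

```python
class _ConstantCalculator:
    # Simplified constant-batch-size calculator; the ramp-up variant in
    # microbatches.py recomputes these fields inside update().
    def __init__(self, global_batch_size, micro_batch_times_data_parallel):
        assert global_batch_size % micro_batch_times_data_parallel == 0
        self.current_global_batch_size = global_batch_size
        self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel

    def get(self):
        return self.num_micro_batches

    def get_current_global_batch_size(self):
        return self.current_global_batch_size

    def update(self, consumed_samples, consistency_check):
        pass  # constant schedule: nothing to recompute

_GLOBAL_NUM_MICROBATCHES_CALCULATOR = _ConstantCalculator(1024, 32)

def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()

def get_current_global_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()

def update_num_microbatches(consumed_samples, consistency_check=True):
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)

update_num_microbatches(0, consistency_check=False)
print(get_current_global_batch_size(), get_num_microbatches())  # 1024 32
```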
38 changes: 22 additions & 16 deletions megatron/microbatches.py
@@ -56,12 +56,16 @@ class NumMicroBatchesCalculator(ABC):

def __init__(self):
self.num_micro_batches = None
self.current_global_batch_size = None

def get(self):
return self.num_micro_batches

def get_current_global_batch_size(self):
return self.current_global_batch_size

@abstractmethod
def update(self, consumed_samples):
def update(self, consumed_samples, consistency_check):
pass


@@ -78,8 +82,9 @@ def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
self.num_micro_batches = global_batch_size // \
micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size

def update(self, consumed_samples):
def update(self, consumed_samples, consistency_check):
pass


@@ -128,24 +133,25 @@ def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
self.rampup_samples_per_increment = self.ramup_samples / num_increments

# Initialize number of microbatches.
self.update(0)
self.update(0, False)


def update(self, consumed_samples):
def update(self, consumed_samples, consistency_check):

if consumed_samples > self.ramup_samples:
current_global_batch_size = self.global_batch_size
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
assert current_global_batch_size <= self.global_batch_size

assert current_global_batch_size % \
self.micro_batch_times_data_parallel_size == 0, 'current global ' \
'batch size ({}) is not divisible by micro-batch-size ({}) times' \
'data parallel size ({})'.format(current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size)
self.num_micro_batches = current_global_batch_size // \
self.current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
assert self.current_global_batch_size <= self.global_batch_size

if consistency_check:
assert self.current_global_batch_size % \
self.micro_batch_times_data_parallel_size == 0, 'current global ' \
'batch size ({}) is not divisible by micro-batch-size ({}) times' \
'data parallel size ({})'.format(self.current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size)
self.num_micro_batches = self.current_global_batch_size // \
self.micro_batch_times_data_parallel_size
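A worked example of the ramp-up arithmetic above, under assumed values (start at 32, grow by 16 up to 512 over 10M samples, micro-batch size 4, data-parallel size 4); it also exercises the divisibility check that update() now runs only when consistency_check is set.

```python
# Assumed ramp-up configuration (illustrative only).
start_batch_size = 32
batch_size_increment = 16
global_batch_size = 512
ramup_samples = 10_000_000           # spelled as in the source attribute name
micro_batch_size = 4
data_parallel_size = 4
micro_times_dp = micro_batch_size * data_parallel_size            # 16

num_increments = (global_batch_size - start_batch_size) // batch_size_increment
samples_per_increment = ramup_samples / num_increments            # ~333,333 samples

def current_global_batch_size(consumed_samples):
    if consumed_samples > ramup_samples:
        return global_batch_size
    steps = int(consumed_samples / samples_per_increment)
    return start_batch_size + steps * batch_size_increment

for consumed in (0, 500_000, 5_000_000, 20_000_000):
    gbs = current_global_batch_size(consumed)
    # The consistency check: the current global batch size must split evenly
    # into micro-batches across data-parallel ranks.
    assert gbs % micro_times_dp == 0
    print(consumed, gbs, gbs // micro_times_dp)
# 0          -> 32,  2 microbatches
# 500_000    -> 48,  3 microbatches
# 5_000_000  -> 272, 17 microbatches
# 20_000_000 -> 512, 32 microbatches
```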
107 changes: 94 additions & 13 deletions megatron/training.py
@@ -18,13 +18,18 @@
from datetime import datetime
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import update_num_microbatches
from megatron import mpu
@@ -44,6 +49,13 @@
from megatron.utils import report_memory


def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print_rank_0('[' + string + '] datetime: {} '.format(time_str))


def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={}):
"""Main training program.
@@ -74,20 +86,35 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)

# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')

args = get_args()
timers = get_timers()

# Model, optimizer, and learning rate.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')

# Data stuff.
timers('train/valid/test data iterators').start()
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test data iterators').stop()
print_datetime('after dataloaders are built')

# Print setup timing.
print_rank_0('done with setups ...')
@@ -99,6 +126,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
print_datetime('after training is done')

if args.do_valid:
prefix = 'the end of training for val data'
@@ -132,13 +160,11 @@ def update_train_iters(args):
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples)
consumed_samples += get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
iterations += 1
# Reset
update_num_microbatches(0)
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
@@ -267,7 +293,15 @@ def setup_model_and_optimizer(model_provider_func):
lr_scheduler = get_learning_rate_scheduler(optimizer)

if args.load is not None:
timers = get_timers()
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('load checkpoint').stop()
timers.log(['load checkpoint'])
else:
args.iteration = 0

@@ -685,11 +719,22 @@ def add_to_logging(name):

# Tensorboard values.
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
writer.add_scalar('learning_rate-samples', learning_rate,
args.consumed_train_samples)
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
writer.add_scalar('batch_size-iterations', batch_size, iteration)
writer.add_scalar('batch_size-samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
writer.add_scalar(key, loss_dict[key], iteration)
writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
writer.add_scalar(key + '-samples', loss_dict[key],
args.consumed_train_samples)
if args.fp16:
writer.add_scalar('loss_scale', loss_scale, iteration)
writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
writer.add_scalar('loss_scale-samples', loss_scale,
args.consumed_train_samples)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
@@ -703,6 +748,8 @@ def add_to_logging(name):
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
@@ -732,6 +779,18 @@ def add_to_logging(name):
return report_memory_flag


def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('save checkpoint').stop()
timers.log(['save checkpoint'])


def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator):
"""Train the model function."""
@@ -748,6 +807,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration = args.iteration

timers('interval time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
@@ -777,9 +837,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
lr_scheduler)

# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
saved_checkpoint = True


# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -789,14 +853,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
valid_data_iterator, model,
iteration, False)

# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.cuda.IntTensor(
[train_time > args.exit_duration_in_mins])
torch.distributed.all_reduce(
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()

# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
'iteration {}'.format(rank, time_str, iteration))
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()


return iteration


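To close, a minimal single-process sketch of the exit-by-duration pattern added to train(): each rank computes its own elapsed time, the flags are combined with a MAX all-reduce so every rank agrees to stop as soon as any rank is over budget, and only then is a checkpoint saved and the program exited. The one-rank gloo group, the CPU tensor (the diff uses torch.cuda.IntTensor), and the 230-minute budget are assumptions so the sketch runs standalone.

```python
import os
import time

import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

train_start_time = time.time()
exit_duration_in_mins = 230  # assumed budget, e.g. just under a 4-hour job slot

def should_exit_now():
    elapsed_mins = (time.time() - train_start_time) / 60.0
    done = torch.tensor([elapsed_mins > exit_duration_in_mins], dtype=torch.int32)
    # MAX across ranks: if any rank is over budget, every rank sees done == 1,
    # so all of them take the same save-and-exit path.
    dist.all_reduce(done, op=dist.ReduceOp.MAX)
    return bool(done.item())

print('exit now?', should_exit_now())
dist.destroy_process_group()
```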
