Merge branch 'main' into _bf16
sdtblck committed May 14, 2021
2 parents f13e5c5 + 16c8acc commit b5f2701
Showing 51 changed files with 2,246 additions and 904 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -132,7 +132,13 @@ dmypy.json
wandb/

# data files
data/
data/**/*.idx
data/**/*.bin
data/**/*.json*
data/**/*.txt
data/**/*.gz
data/**/*.np*
data/**/*.npy
checkpoints/
.vscode/
*.pt
11 changes: 8 additions & 3 deletions configs/eleutherai_cluster.yml
@@ -1,9 +1,14 @@
# Data paths and options when using EleutherAI cluster
{
"data-path": "/mnt/ssd-cluster/data/enron/enron_text_document",
#"train-data-path": "/mnt/ssd-cluster/data/train/train_text_document",
#"test-data-path": "/mnt/ssd-cluster/data/test/test_text_document",
#"valid-data-path": "/mnt/ssd-cluster/data/valid/valid_text_document",
# or for weighted datasets:
# "train-data-paths": ["/mnt/ssd-cluster/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"],
# "test-data-paths": ["/mnt/ssd-cluster/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"],
# "valid-data-paths": ["/mnt/ssd-cluster/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"],
# "train-data-weights": [1., 2.],
# "test-data-weights": [2., 1.],
# "valid-data-weights": [0.5, 0.4],

"vocab-file": "/mnt/ssd-cluster/data/gpt2-vocab.json",
"merge-file": "/mnt/ssd-cluster/data/gpt2-merges.txt",
"save": "/mnt/ssd-cluster/checkpoints",
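The commented-out weighted-dataset options above pair each data path with a relative weight. As a rough illustration of what such weights mean for sampling, here is a minimal sketch (a hypothetical helper, not the repository's actual data pipeline), written against NumPy:

import numpy as np

def sample_dataset_indices(weights, num_samples, seed=0):
    """Draw dataset indices in proportion to the given weights,
    e.g. weights=[1., 2.] as in the commented config above."""
    weights = np.asarray(weights, dtype=np.float64)
    probs = weights / weights.sum()              # normalize to a distribution
    rng = np.random.default_rng(seed)
    return rng.choice(len(weights), size=num_samples, p=probs)

# With weights [1., 2.], roughly one third of samples come from the
# first dataset and two thirds from the second.
indices = sample_dataset_indices([1., 2.], num_samples=9)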
20 changes: 17 additions & 3 deletions configs/local_setup.yml
@@ -1,13 +1,27 @@
# Suggested data paths when using GPT-NeoX locally
{
"data-path": "data/enron/enron_text_document",
# "train-data-path": "data/train/train_text_document",
# "test-data-path": "data/test/test_text_document",
# "valid-data-path": "data/valid/valid_text_document",

# or for weighted datasets:
# "train-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"],
# "test-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"],
# "valid-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"],
# "train-data-weights": [1., 2.],
# "test-data-weights": [2., 1.],
# "valid-data-weights": [0.5, 0.4],

# If weight_by_num_documents is True, builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group.
# WARNING: setting this to True will override any user provided weights
# "weight_by_num_documents": false,
# "weighted_sampler_alpha": 0.3,

"vocab-file": "data/gpt2-vocab.json",
"merge-file": "data/gpt2-merges.txt",

"save": "checkpoints",
"load": "checkpoints",
"checkpoint_validation_with_forward_pass": False,

"tensorboard-dir": "tensorboard",
"log-dir": "logs",
"use_wandb": True,
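The weight_by_num_documents and weighted_sampler_alpha options commented out above describe deriving dataset weights from document counts rather than from user-provided values. A minimal sketch of that kind of alpha-smoothed multinomial weighting, assuming the exponentiated-probability scheme the comment suggests (the authoritative logic lives in the repository's data code):

import numpy as np

def weights_from_doc_counts(doc_counts, alpha=0.3):
    """Derive dataset weights from per-group document counts.
    Smaller alpha flattens the distribution, upweighting small groups
    relative to purely proportional sampling (alpha=1.0)."""
    counts = np.asarray(doc_counts, dtype=np.float64)
    probs = counts / counts.sum()        # multinomial over groups
    smoothed = probs ** alpha            # apply the sampling exponent
    return smoothed / smoothed.sum()     # renormalize to sum to 1

# Example: three groups with 1000, 100 and 10 documents.
print(weights_from_doc_counts([1000, 100, 10], alpha=0.3))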
2 changes: 0 additions & 2 deletions deepy.py
@@ -17,10 +17,8 @@
import os

import deepspeed
import requests
from deepspeed.launcher.runner import main


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

from megatron.neox_arguments import NeoXArgs
169 changes: 91 additions & 78 deletions megatron/checkpointing.py
@@ -31,35 +31,70 @@
from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort

from megatron.text_generation_utils import get_batch, forward_model

def check_checkpoint_args(neox_args, checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retreived frm checkpoint."""

def _compare(arg_name):
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(neox_args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the input argument value ({}).'.format(arg_name, checkpoint_value, args_value)
assert checkpoint_value == args_value, error_message
assert isinstance(checkpoint_args, dict), "args stored in checkpoint is a dict"
for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
args_value = getattr(neox_args, checkpoint_arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the currently set argument value ({}).'.format(checkpoint_arg_name, checkpoint_arg_value, args_value)
assert checkpoint_arg_value == args_value, error_message

def do_forward_pass(neox_args, model, inference=False):

# set to eval mode
model_was_in_train = model.training
model.eval()

# get context tokens
# always forward full batch size
context_tokens_tensor = torch.arange(2049).repeat((neox_args.train_micro_batch_size_per_gpu, 1)).cuda()

# forward
if inference:
tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens_tensor[:, :2048])
model_inputs = (tokens,
position_ids,
attention_mask,
torch.Tensor(),
)
logits, _ = forward_model(neox_args, model, model_inputs)
elif neox_args.is_pipe_parallel:
data_iterator = iter([{"text": context_tokens_tensor}])
_, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
else:
tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens_tensor[:, :2048])
logits = model((tokens, position_ids, attention_mask))

# reset to train mode, if model was in training before
if model_was_in_train:
model.train()

_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('model_parallel_size')
if logits is not None:
logits = logits.detach().cpu()[0] # just return first batch item (they are all equal)

return logits

def check_forward_pass(neox_args, model, checkpoint_logits, inference):
# do forward pass with loaded checkpoint
logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

# check
if logits is not None and checkpoint_logits is not None: # this could be the case for non-final pipeline stages
if not (logits == checkpoint_logits).all().item():
if mpu.get_data_parallel_rank() == 0:
print(" > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result")
assert torch.isclose(logits, checkpoint_logits).all().item(), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"

def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration,
release=False, mp_rank=None):
"""A unified checkpoint name."""
@@ -73,13 +108,6 @@ def get_checkpoint_name(checkpoints_path, iteration,
else mp_rank),
'model_optim_rng.pt')


def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file rescords the latest chckpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def delete_old_checkpoints(save_dir, n_to_keep):
if torch.distributed.get_rank() == 0:
ckpt_dir_regex = r'global_step[\d]*'
@@ -97,20 +125,35 @@ def delete_old_checkpoints(save_dir, n_to_keep):
except FileNotFoundError:
pass


def save_ds_checkpoint(iteration, model, neox_args):
"""Save a model checkpoint."""
sd = {'iteration': iteration}
sd = {
'iteration': iteration,
'args': {
'num_layers': neox_args.num_layers,
'hidden_size': neox_args.hidden_size,
'num_attention_heads': neox_args.num_attention_heads,
'max_position_embeddings': neox_args.max_position_embeddings,
'make_vocab_size_divisible_by': neox_args.make_vocab_size_divisible_by,
'padded_vocab_size': neox_args.padded_vocab_size,
'tokenizer_type': neox_args.tokenizer_type,
'model_parallel_size': neox_args.model_parallel_size
}
}
# rng states.
if not neox_args.no_save_rng:
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

if neox_args.checkpoint_validation_with_forward_pass:
logits = do_forward_pass(neox_args=neox_args, model=model)
sd['checkpoint_validation_logits'] = logits

model.save_checkpoint(neox_args.save, client_state=sd)


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint."""

@@ -119,14 +162,6 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
else:
raise ValueError('Must be using deepspeed to use neox')

# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(neox_args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))

# Wait so everyone is done (necessary)
torch.distributed.barrier()
if neox_args.keep_last_n_checkpoints is not None:
@@ -135,39 +170,9 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
# Wait so everyone is done (not necessary)
torch.distributed.barrier()


def load_checkpoint(neox_args, model, optimizer, lr_scheduler):
def load_checkpoint(neox_args, model, optimizer, lr_scheduler, inference=False):
"""Load a model checkpoint and return the iteration."""

# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(neox_args.load)

# If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0

# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
sys.exit()

assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)

if neox_args.deepspeed:
load_optim_and_scheduler = not neox_args.no_load_optim # TODO: These should be configured by separate args
checkpoint_name, state_dict = model.load_checkpoint(neox_args.load,
@@ -177,34 +182,42 @@ def load_checkpoint(neox_args, model, optimizer, lr_scheduler):
if checkpoint_name is None:
if mpu.get_data_parallel_rank() == 0:
print("Unable to load checkpoint.")
return iteration
return 0 # iteration 0, if no checkpoint was loaded
else:
raise ValueError('Must be using deepspeed to use neox')

# Set iteration.
if neox_args.finetune or release:
if neox_args.finetune:
iteration = 0
else:
try:
iteration = state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
'iteration from checkpoint {}, exiting'.format(
checkpoint_name))
sys.exit()
iteration = state_dict.get('iteration') or state_dict.get("total_iters") # total_iters backward compatible with older checkpoints
if iteration is None:
raise ValueError('Unable to load iteration from checkpoint {}, exiting'.format(checkpoint_name))

# Check arguments.
if 'args' in state_dict:
checkpoint_args = state_dict['args']
check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
print_rank_0(' > validated currently set args with arguments in the checkpoint ...')
else:
print_rank_0('could not find arguments in the checkpoint ...')
print_rank_0(' > could not find arguments in the checkpoint for validation...')

# Check loaded checkpoint with forward pass
if neox_args.checkpoint_validation_with_forward_pass:
if "checkpoint_validation_logits" in state_dict:
check_forward_pass(
neox_args=neox_args,
model=model,
checkpoint_logits=state_dict["checkpoint_validation_logits"],
inference=inference
)
print_rank_0(' > validated loaded checkpoint with forward pass ...')
else:
if mpu.get_data_parallel_rank() == 0:
print(' > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}'.format(checkpoint_name))

# rng states.
if not release and not neox_args.finetune and not neox_args.no_load_rng:
if not neox_args.finetune and not neox_args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
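To summarize the validation flow added above: at save time a fixed, deterministic batch is forwarded through the model and the resulting logits are stored in the checkpoint; at load time the same batch is forwarded again and compared against the stored logits. A self-contained toy sketch of that round trip (hypothetical model and file name, not the DeepSpeed-based path the real code uses):

import torch

def validation_logits(model, seq_len=16, batch_size=2):
    """Forward a fixed, deterministic input and return detached logits."""
    model.eval()
    inputs = torch.arange(seq_len).repeat(batch_size, 1).float()
    with torch.no_grad():
        logits = model(inputs)
    return logits.detach().cpu()[0]  # batch items are identical; keep one

model = torch.nn.Linear(16, 8)  # toy stand-in for the real network

# Save: stash reference logits alongside the weights.
state = {"model": model.state_dict(),
         "checkpoint_validation_logits": validation_logits(model)}
torch.save(state, "toy_checkpoint.pt")

# Load: recompute and compare, mirroring check_forward_pass above.
loaded = torch.load("toy_checkpoint.pt")
model.load_state_dict(loaded["model"])
fresh = validation_logits(model)
assert torch.isclose(fresh, loaded["checkpoint_validation_logits"]).all().item()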